mithril_stm/merkle_tree/
tree.rs

1use crate::error::MerkleTreeError;
2use crate::merkle_tree::{
3    left_child, parent, right_child, sibling, BatchPath, MTLeaf, MerkleTreeCommitment,
4    MerkleTreeCommitmentBatchCompat, Path,
5};
6use blake2::digest::{Digest, FixedOutput};
7use serde::{Deserialize, Serialize};
8use std::marker::PhantomData;
9
/// Tree of hashes, providing a commitment of data and its ordering.
///
/// Instances are built from a list of leaves by `MerkleTree::create`, which hashes every leaf
/// with `D` and then fills in the inner nodes bottom-up.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MerkleTree<D: Digest> {
    /// The nodes are stored in an array heap:
    /// * `nodes[0]` is the root,
    /// * the parent of `nodes[i]` is `nodes[(i-1)/2]`
    /// * the children of `nodes[i]` are `{nodes[2i + 1], nodes[2i + 2]}`
    /// * All nodes have size `Output<D>::output_size()`, even leaves (which are hashed before committing them).
    nodes: Vec<Vec<u8>>,
    /// The leaves begin at `nodes[leaf_off]`.
    leaf_off: usize,
    /// Number of leaves committed in the merkle tree.
    n: usize,
    /// Phantom type to link the tree with its hasher
    hasher: PhantomData<D>,
}
26
27impl<D: Digest + FixedOutput> MerkleTree<D> {
28    /// Provided a non-empty list of leaves, `create` generates its corresponding `MerkleTree`.
29    pub fn create(leaves: &[MTLeaf]) -> MerkleTree<D> {
30        let n = leaves.len();
31        assert!(n > 0, "MerkleTree::create() called with no leaves");
32
33        let num_nodes = n + n.next_power_of_two() - 1;
34
35        let mut nodes = vec![vec![0u8]; num_nodes];
36
37        for i in 0..leaves.len() {
38            nodes[num_nodes - n + i] = D::digest(leaves[i].to_bytes()).to_vec();
39        }
40
41        for i in (0..num_nodes - n).rev() {
42            let z = D::digest([0u8]).to_vec();
43            let left = if left_child(i) < num_nodes {
44                &nodes[left_child(i)]
45            } else {
46                &z
47            };
48            let right = if right_child(i) < num_nodes {
49                &nodes[right_child(i)]
50            } else {
51                &z
52            };
53            nodes[i] = D::new()
54                .chain_update(left)
55                .chain_update(right)
56                .finalize()
57                .to_vec();
58        }
59
60        Self {
61            nodes,
62            n,
63            leaf_off: num_nodes - n,
64            hasher: PhantomData,
65        }
66    }
67
68    /// Get the root of the tree.
69    pub fn root(&self) -> &Vec<u8> {
70        &self.nodes[0]
71    }
72
73    /// Return the index of the leaf.
74    fn idx_of_leaf(&self, i: usize) -> usize {
75        self.leaf_off + i
76    }
77
78    /// Convert merkle tree to a batch compatible commitment.
79    /// This function simply returns the root and the number of leaves in the tree.
80    pub fn to_commitment_batch_compat(&self) -> MerkleTreeCommitmentBatchCompat<D> {
81        MerkleTreeCommitmentBatchCompat::new(self.nodes[0].clone(), self.n)
82    }
83
84    /// Get a path for a batch of leaves. The indices must be ordered. We use the Octopus algorithm to
85    /// avoid redundancy with nodes in the path. Let `x1, . . . , xk` be the indices of elements we
86    /// want to produce an opening for. The algorithm takes as input `x1, . . ., xk`, and  proceeds as follows:
87    /// 1. Initialise the proof vector, `proof = []`.
88    /// 2. Given an input vector `v = v1, . . .,vl`, if `v.len() == 1`, return `proof`, else, continue.
89    /// 3. Map each `vi` to the corresponding number of the leaf (by adding the offset).
90    /// 4. Initialise a new empty vector `p = []`. Next, iterate over each element `vi`
91    ///    a. Append the parent of `vi` to `p`
92    ///    b. Compute the sibling, `si` of `vi`
93    ///    c. If `si == v(i+1)` then do nothing, and skip step four for `v(i+1)`. Else append `si` to `proof`
94    /// 5. Iterate from step 2 with input vector `p`
95    ///
96    /// # Panics
97    /// If the indices provided are out of bounds (higher than the number of elements
98    /// committed in the `MerkleTree`) or are not ordered, the function fails.
99    // todo: Update doc.
100    pub fn get_batched_path(&self, indices: Vec<usize>) -> BatchPath<D>
101    where
102        D: FixedOutput,
103    {
104        assert!(
105            !indices.is_empty(),
106            "get_batched_path() called with no indices"
107        );
108        for i in &indices {
109            assert!(
110                i < &self.n,
111                "Proof index out of bounds: asked for {} out of {}",
112                i,
113                self.n
114            );
115        }
116
117        let mut ordered_indices: Vec<usize> = indices.clone();
118        ordered_indices.sort_unstable();
119
120        assert_eq!(ordered_indices, indices, "Indices should be ordered");
121
122        ordered_indices = ordered_indices
123            .into_iter()
124            .map(|i| self.idx_of_leaf(i))
125            .collect();
126
127        let mut idx = ordered_indices[0];
128        let mut proof = Vec::new();
129
130        while idx > 0 {
131            let mut new_indices = Vec::with_capacity(ordered_indices.len());
132            let mut i = 0;
133            idx = parent(idx);
134            while i < ordered_indices.len() {
135                new_indices.push(parent(ordered_indices[i]));
136                let sibling = sibling(ordered_indices[i]);
137                if i < ordered_indices.len() - 1 && ordered_indices[i + 1] == sibling {
138                    i += 1;
139                } else if sibling < self.nodes.len() {
140                    proof.push(self.nodes[sibling].clone());
141                }
142                i += 1;
143            }
144            ordered_indices.clone_from(&new_indices);
145        }
146
147        BatchPath::new(proof, indices)
148    }
149
150    /// Convert a `MerkleTree` into a byte string, containing $4 + n * S$ bytes where $n$ is the
151    /// number of nodes and $S$ the output size of the hash function.
152    /// # Layout
153    /// * Number of leaves committed in the Merkle Tree (as u64)
154    /// * All nodes of the merkle tree (starting with the root)
155    pub fn to_bytes(&self) -> Vec<u8> {
156        let mut result = Vec::with_capacity(8 + self.nodes.len() * <D as Digest>::output_size());
157        result.extend_from_slice(&u64::try_from(self.n).unwrap().to_be_bytes());
158        for node in self.nodes.iter() {
159            result.extend_from_slice(node);
160        }
161        result
162    }
163
164    /// Try to convert a byte string into a `MerkleTree`.
165    /// # Error
166    /// It returns error if conversion fails.
167    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MerkleTreeError<D>> {
168        let mut u64_bytes = [0u8; 8];
169        u64_bytes.copy_from_slice(&bytes[..8]);
170        let n = usize::try_from(u64::from_be_bytes(u64_bytes))
171            .map_err(|_| MerkleTreeError::SerializationError)?;
172        let num_nodes = n + n.next_power_of_two() - 1;
173        let mut nodes = Vec::with_capacity(num_nodes);
174        for i in 0..num_nodes {
175            nodes.push(
176                bytes[8 + i * <D as Digest>::output_size()
177                    ..8 + (i + 1) * <D as Digest>::output_size()]
178                    .to_vec(),
179            );
180        }
181        Ok(Self {
182            nodes,
183            leaf_off: num_nodes - n,
184            n,
185            hasher: PhantomData,
186        })
187    }
188
189    /// Convert merkle tree to a commitment. This function simply returns the root.
190    pub fn to_commitment(&self) -> MerkleTreeCommitment<D> {
191        MerkleTreeCommitment::new(self.nodes[0].clone()) // Use private constructor
192    }
193
194    /// Get a path (hashes of siblings of the path to the root node)
195    /// for the `i`th value stored in the tree.
196    /// Requires `i < self.n`
197    pub fn get_path(&self, i: usize) -> Path<D> {
198        assert!(
199            i < self.n,
200            "Proof index out of bounds: asked for {} out of {}",
201            i,
202            self.n
203        );
204        let mut idx = self.idx_of_leaf(i);
205        let mut proof = Vec::new();
206
207        while idx > 0 {
208            let h = if sibling(idx) < self.nodes.len() {
209                self.nodes[sibling(idx)].clone()
210            } else {
211                D::digest([0u8]).to_vec()
212            };
213            proof.push(h.clone());
214            idx = parent(idx);
215        }
216
217        Path::new(proof, i)
218    }
219}
220
221#[cfg(test)]
222mod tests {
223    use super::*;
224    use crate::multi_sig::VerificationKey;
225    use blake2::{digest::consts::U32, Blake2b};
226    use proptest::collection::vec;
227    use proptest::prelude::*;
228    use rand::{rng, seq::IteratorRandom};
229
230    fn pow2_plus1(h: usize) -> usize {
231        1 + 2_usize.pow(h as u32)
232    }
233
    // Strategy yielding a Merkle tree together with the leaves it was built from.
    // It draws a vector of random `u64` stakes whose length lies in `2..max_size`, and pairs
    // each of them with `VerificationKey::default()` to form the leaves.
    prop_compose! {
        fn arb_tree(max_size: u32)
                   (v in vec(any::<u64>(), 2..max_size as usize)) -> (MerkleTree<Blake2b<U32>>, Vec<MTLeaf>) {
            let pks = vec![VerificationKey::default(); v.len()];
            let leaves = pks.into_iter().zip(v.into_iter()).map(|(key, stake)| MTLeaf(key, stake)).collect::<Vec<MTLeaf>>();
             (MerkleTree::<Blake2b<U32>>::create(&leaves), leaves)
        }
    }
242
    proptest! {
        // Test the relation that t.get_path(i) is a valid
        // proof for i
        #![proptest_config(ProptestConfig::with_cases(100))]
        #[test]
        fn test_create_proof((t, values) in arb_tree(30)) {
            values.iter().enumerate().for_each(|(i, _v)| {
                let pf = t.get_path(i);
                assert!(t.to_commitment().check(&values[i], &pf).is_ok());
            })
        }

        // A path must still verify after a round trip through `to_bytes`/`from_bytes`
        // and through bincode.
        #[test]
        fn test_bytes_path((t, values) in arb_tree(30)) {
            values.iter().enumerate().for_each(|(i, _v)| {
                let pf = t.get_path(i);
                let bytes = pf.to_bytes();
                let deserialised = Path::from_bytes(&bytes).unwrap();
                assert!(t.to_commitment().check(&values[i], &deserialised).is_ok());

                let encoded = bincode::serialize(&pf).unwrap();
                let decoded: Path<Blake2b<U32>> = bincode::deserialize(&encoded).unwrap();
                assert!(t.to_commitment().check(&values[i], &decoded).is_ok());
            })
        }

        // The commitment decoded from bincode must carry the same root as the commitment
        // of a tree rebuilt from the same leaves.
        #[test]
        fn test_bytes_tree_commitment((t, values) in arb_tree(5)) {
            let encoded = bincode::serialize(&t.to_commitment()).unwrap();
            let decoded: MerkleTreeCommitment::<Blake2b<U32>> = bincode::deserialize(&encoded).unwrap();
            let tree_commitment = MerkleTree::<Blake2b<U32>>::create(&values).to_commitment();
            assert_eq!(tree_commitment.root, decoded.root);
        }

        // A tree must keep the same nodes after a round trip through `to_bytes`/`from_bytes`
        // and through bincode. Only the `nodes` field is compared.
        #[test]
        fn test_bytes_tree((t, values) in arb_tree(5)) {
            let bytes = t.to_bytes();
            let deserialised = MerkleTree::<Blake2b<U32>>::from_bytes(&bytes).unwrap();
            let tree = MerkleTree::<Blake2b<U32>>::create(&values);
            assert_eq!(tree.nodes, deserialised.nodes);

            let encoded = bincode::serialize(&t).unwrap();
            let decoded: MerkleTree::<Blake2b<U32>> = bincode::deserialize(&encoded).unwrap();
            assert_eq!(tree.nodes, decoded.nodes);
        }

        // The batch compatible commitment decoded from bincode must carry the same root and
        // the same number of leaves as the one of a tree rebuilt from the same leaves.
        #[test]
        fn test_bytes_tree_commitment_batch_compat((t, values) in arb_tree(5)) {
            let encoded = bincode::serialize(&t.to_commitment_batch_compat()).unwrap();
            let decoded: MerkleTreeCommitmentBatchCompat::<Blake2b<U32>> = bincode::deserialize(&encoded).unwrap();
            let tree_commitment = MerkleTree::<Blake2b<U32>>::create(&values).to_commitment_batch_compat();
            assert_eq!(tree_commitment.root, decoded.root);
            assert_eq!(tree_commitment.get_nr_leaves(), decoded.get_nr_leaves());

        }

    }
300
    prop_compose! {
        // Returns values with a randomly generated path.
        // A height `h` is drawn in `1..max_height` first; then a vector of random stakes whose
        // length lies in `2..pow2_plus1(h)` and `h` random byte strings of 16 bytes each.
        // The leaves all use `VerificationKey::default()` as their key.
        fn values_with_invalid_proof(max_height: usize)
                                    (h in 1..max_height)
                                    (v in vec(any::<u64>(), 2..pow2_plus1(h)),
                                     proof in vec(vec(any::<u8>(), 16), h)) -> (Vec<MTLeaf>, Vec<Vec<u8>>) {
            let pks = vec![VerificationKey::default(); v.len()];
            let leaves = pks.into_iter().zip(v.into_iter()).map(|(key, stake)| MTLeaf(key, stake)).collect::<Vec<MTLeaf>>();
            (leaves, proof)
        }
    }
312
    proptest! {
        // A path made of hashes of random bytes must be rejected. The tree is built from every
        // value but the first one, and the first value is the one checked against the path.
        #[test]
        fn test_create_invalid_proof(
            i in any::<usize>(),
            (values, proof) in values_with_invalid_proof(10)
        ) {
            let t = MerkleTree::<Blake2b<U32>>::create(&values[1..]);
            // Reduce the arbitrary `i` to a valid leaf position of the tree.
            let index = i % (values.len() - 1);
            let path_values = proof. iter().map(|x|  Blake2b::<U32>::digest(x).to_vec()).collect();
            let path = Path::new(path_values, index);
            assert!(t.to_commitment().check(&values[0], &path).is_err());
        }

        // A batch path made of hashes of random bytes must be rejected. The batch repeats a
        // single index (and the matching value) `values.len() / 2` times.
        #[test]
        fn test_create_invalid_batch_proof(
            i in any::<usize>(),
            (values, proof) in values_with_invalid_proof(10)
        ) {
            let t = MerkleTree::<Blake2b<U32>>::create(&values[1..]);
            let indices = vec![i % (values.len() - 1); values.len() / 2];
            let batch_values = vec![values[i % (values.len() - 1)]; values.len() / 2];
            let path = BatchPath{values: proof
                            .iter()
                            .map(|x|  Blake2b::<U32>::digest(x).to_vec())
                            .collect(),
                indices,
                hasher: PhantomData::<Blake2b<U32>>
                };
            assert!(t.to_commitment_batch_compat().check(&batch_values, &path).is_err());
        }
    }
344
    // Strategy yielding a Merkle tree, a batch of its leaves and the sorted positions of those
    // leaves. The tree is built like in `arb_tree`; the batch holds `size * 2 / 10 + 1` positions
    // picked with `choose_multiple`.
    // NOTE(review): the positions are picked with `rand::rng()` rather than with the values
    // generated by proptest, so presumably a failing case is not replayed with the same batch;
    // confirm.
    prop_compose! {
        fn arb_tree_arb_batch(max_size: u32)
                   (v in vec(any::<u64>(), 2..max_size as usize)) -> (MerkleTree<Blake2b<U32>>, Vec<MTLeaf>, Vec<usize>) {
            let mut rng = rng();
            let size = v.len();
            let pks = vec![VerificationKey::default(); size];
            let leaves = pks.into_iter().zip(v.into_iter()).map(|(key, stake)| MTLeaf(key, stake)).collect::<Vec<MTLeaf>>();

            let indices: Vec<usize> = (0..size).collect();
            let mut mt_list: Vec<usize> = indices.into_iter().choose_multiple(&mut rng, size * 2 / 10 + 1);
            // `get_batched_path` requires ordered indices.
            mt_list.sort_unstable();

            // Collect the leaves found at the chosen positions, in the same order.
            let mut batch_values: Vec<MTLeaf> = Vec::with_capacity(mt_list.len());
            for i in mt_list.iter() {
                batch_values.push(leaves[*i]);
            }

            (MerkleTree::<Blake2b<U32>>::create(&leaves), batch_values, mt_list)
        }
    }
365
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        // Test the relation that t.get_batched_path(indices) is a valid
        // proof for the leaves found at those indices
        #[test]
        fn test_create_batch_proof((t, batch_values, indices) in arb_tree_arb_batch(30)) {
            let batch_proof = t.get_batched_path(indices);
            assert!(t.to_commitment_batch_compat().check(&batch_values, &batch_proof).is_ok());
        }

        // A batch path must still verify after a round trip through `to_bytes`/`from_bytes`
        // and through bincode.
        #[test]
        fn test_bytes_batch_path((t, batch_values, indices) in arb_tree_arb_batch(30)) {
            let bp = t.get_batched_path(indices);

            let bytes = &bp.to_bytes();
            let deserialized = BatchPath::from_bytes(bytes).unwrap();
            assert!(t.to_commitment_batch_compat().check(&batch_values, &deserialized).is_ok());

            let encoded = bincode::serialize(&bp).unwrap();
            let decoded: BatchPath<Blake2b<U32>> = bincode::deserialize(&encoded).unwrap();
            assert!(t.to_commitment_batch_compat().check(&batch_values, &decoded).is_ok());
        }
    }
387}