mithril_common/crypto_helper/
merkle_tree.rs

1use anyhow::{anyhow, Context};
2use blake2::{Blake2s256, Digest};
3use ckb_merkle_mountain_range::{
4    Error as MMRError, MMRStoreReadOps, MMRStoreWriteOps, Merge, MerkleProof, Result as MMRResult,
5    MMR,
6};
7use serde::{Deserialize, Serialize};
8use std::{
9    collections::{BTreeMap, HashMap},
10    fmt::Display,
11    ops::{Add, Deref},
12    sync::{Arc, RwLock},
13};
14
15use crate::{StdError, StdResult};
16
/// Alias for a vector of bytes (e.g. the raw content of a [MKTreeNode])
pub type Bytes = Vec<u8>;

/// Alias for the position of a leaf in a Merkle tree, as returned when the leaf is pushed into
/// the tree
pub type MKTreeLeafPosition = u64;
22
23/// A node of a Merkle tree
24#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash, Serialize, Deserialize)]
25pub struct MKTreeNode {
26    hash: Bytes,
27}
28
29impl MKTreeNode {
30    /// MKTreeNode factory
31    pub fn new(hash: Bytes) -> Self {
32        Self { hash }
33    }
34
35    /// Create a MKTreeNode from a hex representation
36    pub fn from_hex(hex: &str) -> StdResult<Self> {
37        let hash = hex::decode(hex)?;
38        Ok(Self { hash })
39    }
40
41    /// Create a hex representation of the MKTreeNode
42    pub fn to_hex(&self) -> String {
43        hex::encode(&self.hash)
44    }
45}
46
47impl Deref for MKTreeNode {
48    type Target = Bytes;
49
50    fn deref(&self) -> &Self::Target {
51        &self.hash
52    }
53}
54
55impl From<String> for MKTreeNode {
56    fn from(other: String) -> Self {
57        Self {
58            hash: other.as_str().into(),
59        }
60    }
61}
62
63impl From<&String> for MKTreeNode {
64    fn from(other: &String) -> Self {
65        Self {
66            hash: other.as_str().into(),
67        }
68    }
69}
70
71impl From<&str> for MKTreeNode {
72    fn from(other: &str) -> Self {
73        Self {
74            hash: other.as_bytes().to_vec(),
75        }
76    }
77}
78
79impl<S: MKTreeStorer> TryFrom<MKTree<S>> for MKTreeNode {
80    type Error = StdError;
81    fn try_from(other: MKTree<S>) -> Result<Self, Self::Error> {
82        other.compute_root()
83    }
84}
85
86impl<S: MKTreeStorer> TryFrom<&MKTree<S>> for MKTreeNode {
87    type Error = StdError;
88    fn try_from(other: &MKTree<S>) -> Result<Self, Self::Error> {
89        other.compute_root()
90    }
91}
92
93impl Display for MKTreeNode {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        write!(f, "{}", String::from_utf8_lossy(&self.hash))
96    }
97}
98
99impl Add for MKTreeNode {
100    type Output = MKTreeNode;
101
102    fn add(self, other: MKTreeNode) -> MKTreeNode {
103        &self + &other
104    }
105}
106
107impl Add for &MKTreeNode {
108    type Output = MKTreeNode;
109
110    fn add(self, other: &MKTreeNode) -> MKTreeNode {
111        let mut hasher = Blake2s256::new();
112        hasher.update(self.deref());
113        hasher.update(other.deref());
114        let hash_merge = hasher.finalize();
115        MKTreeNode::new(hash_merge.to_vec())
116    }
117}
118
119struct MergeMKTreeNode {}
120
121impl Merge for MergeMKTreeNode {
122    type Item = Arc<MKTreeNode>;
123
124    fn merge(lhs: &Self::Item, rhs: &Self::Item) -> MMRResult<Self::Item> {
125        Ok(Arc::new((**lhs).clone() + (**rhs).clone()))
126    }
127}
128
129/// A Merkle proof
130#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
131pub struct MKProof {
132    inner_root: Arc<MKTreeNode>,
133    inner_leaves: Vec<(MKTreeLeafPosition, Arc<MKTreeNode>)>,
134    inner_proof_size: u64,
135    inner_proof_items: Vec<Arc<MKTreeNode>>,
136}
137
138impl MKProof {
139    /// Return a reference to its merkle root.
140    pub fn root(&self) -> &MKTreeNode {
141        &self.inner_root
142    }
143
144    /// Verification of a Merkle proof
145    pub fn verify(&self) -> StdResult<()> {
146        MerkleProof::<Arc<MKTreeNode>, MergeMKTreeNode>::new(
147            self.inner_proof_size,
148            self.inner_proof_items.clone(),
149        )
150        .verify(self.inner_root.to_owned(), self.inner_leaves.to_owned())?
151        .then_some(())
152        .ok_or(anyhow!("Invalid MKProof"))
153    }
154
155    /// Check if the proof contains the given leaves
156    pub fn contains(&self, leaves: &[MKTreeNode]) -> StdResult<()> {
157        leaves
158            .iter()
159            .all(|leaf| self.inner_leaves.iter().any(|(_, l)| l.deref() == leaf))
160            .then_some(())
161            .ok_or(anyhow!("Leaves not found in the MKProof"))
162    }
163
164    /// List the leaves of the proof
165    pub fn leaves(&self) -> Vec<MKTreeNode> {
166        self.inner_leaves
167            .iter()
168            .map(|(_, l)| (**l).clone())
169            .collect::<Vec<_>>()
170    }
171
172    cfg_test_tools! {
173        /// Build a [MKProof] based on the given leaves (*Test only*).
174        pub fn from_leaves<T: Into<MKTreeNode> + Clone>(
175            leaves: &[T],
176        ) -> StdResult<MKProof> {
177            Self::from_subset_of_leaves(leaves, leaves)
178        }
179
180        /// Build a [MKProof] based on the given leaves (*Test only*).
181        pub fn from_subset_of_leaves<T: Into<MKTreeNode> + Clone>(
182            leaves: &[T],
183            leaves_to_verify: &[T],
184        ) -> StdResult<MKProof> {
185            let leaves = Self::list_to_mknode(leaves);
186            let leaves_to_verify =
187                Self::list_to_mknode(leaves_to_verify);
188
189            let mktree =
190                MKTree::<MKTreeStoreInMemory>::new(&leaves).with_context(|| "MKTree creation should not fail")?;
191            mktree.compute_proof(&leaves_to_verify)
192        }
193
194        fn list_to_mknode<T: Into<MKTreeNode> + Clone>(hashes: &[T]) -> Vec<MKTreeNode> {
195            hashes.iter().map(|h| h.clone().into()).collect()
196        }
197    }
198}
199
200impl From<MKProof> for MKTreeNode {
201    fn from(other: MKProof) -> Self {
202        other.root().to_owned()
203    }
204}
205
206/// A Merkle tree store in memory
207#[derive(Clone)]
208pub struct MKTreeStoreInMemory {
209    inner_leaves: Arc<RwLock<HashMap<Arc<MKTreeNode>, MKTreeLeafPosition>>>,
210    inner_store: Arc<RwLock<HashMap<u64, Arc<MKTreeNode>>>>,
211}
212
213impl MKTreeStoreInMemory {
214    fn new() -> Self {
215        Self {
216            inner_leaves: Arc::new(RwLock::new(HashMap::new())),
217            inner_store: Arc::new(RwLock::new(HashMap::new())),
218        }
219    }
220}
221
222impl MKTreeLeafIndexer for MKTreeStoreInMemory {
223    fn set_leaf_position(&self, pos: MKTreeLeafPosition, node: Arc<MKTreeNode>) -> StdResult<()> {
224        let mut inner_leaves = self.inner_leaves.write().unwrap();
225        (*inner_leaves).insert(node, pos);
226
227        Ok(())
228    }
229
230    fn get_leaf_position(&self, node: &MKTreeNode) -> Option<MKTreeLeafPosition> {
231        let inner_leaves = self.inner_leaves.read().unwrap();
232        (*inner_leaves).get(node).cloned()
233    }
234
235    fn total_leaves(&self) -> usize {
236        let inner_leaves = self.inner_leaves.read().unwrap();
237        (*inner_leaves).len()
238    }
239
240    fn leaves(&self) -> Vec<MKTreeNode> {
241        let inner_leaves = self.inner_leaves.read().unwrap();
242        (*inner_leaves)
243            .iter()
244            .map(|(leaf, position)| (position, leaf))
245            .collect::<BTreeMap<_, _>>()
246            .into_values()
247            .map(|leaf| (**leaf).clone())
248            .collect()
249    }
250}
251
252impl MKTreeStorer for MKTreeStoreInMemory {
253    fn build() -> StdResult<Self> {
254        Ok(Self::new())
255    }
256
257    fn get_elem(&self, pos: u64) -> StdResult<Option<Arc<MKTreeNode>>> {
258        let inner_store = self.inner_store.read().unwrap();
259
260        Ok((*inner_store).get(&pos).cloned())
261    }
262
263    fn append(&self, pos: u64, elems: Vec<Arc<MKTreeNode>>) -> StdResult<()> {
264        let mut inner_store = self.inner_store.write().unwrap();
265        for (i, elem) in elems.into_iter().enumerate() {
266            (*inner_store).insert(pos + i as u64, elem);
267        }
268
269        Ok(())
270    }
271}
272
273/// The Merkle tree storer trait
274pub trait MKTreeStorer: Clone + Send + Sync + MKTreeLeafIndexer {
275    /// Try to create a new instance of the storer
276    fn build() -> StdResult<Self>;
277
278    /// Get the element at the given position
279    fn get_elem(&self, pos: u64) -> StdResult<Option<Arc<MKTreeNode>>>;
280
281    /// Append elements at the given position
282    fn append(&self, pos: u64, elems: Vec<Arc<MKTreeNode>>) -> StdResult<()>;
283}
284
285/// This struct exists only to implement for a [MkTreeStore] the [MMRStoreReadOps] and
286/// [MMRStoreWriteOps] from merkle_mountain_range crate without the need to reexport types
287/// from that crate.
288///
289/// Rust don't allow the following:
290/// ```ignore
291/// impl<S: MKTreeStorer> MMRStoreReadOps<Arc<MKTreeNode>> for S {}
292/// ```
293/// Since it disallows implementations of traits for arbitrary types which are not defined in
294/// the same crate as the trait itself (see [E0117](https://doc.rust-lang.org/error_codes/E0117.html)).
295struct MKTreeStore<S: MKTreeStorer> {
296    storer: Box<S>,
297}
298
299impl<S: MKTreeStorer> MKTreeStore<S> {
300    fn build() -> StdResult<Self> {
301        let storer = Box::new(S::build()?);
302        Ok(Self { storer })
303    }
304}
305
306impl<S: MKTreeStorer> MMRStoreReadOps<Arc<MKTreeNode>> for MKTreeStore<S> {
307    fn get_elem(&self, pos: u64) -> MMRResult<Option<Arc<MKTreeNode>>> {
308        self.storer
309            .get_elem(pos)
310            .map_err(|e| MMRError::StoreError(e.to_string()))
311    }
312}
313
314impl<S: MKTreeStorer> MMRStoreWriteOps<Arc<MKTreeNode>> for MKTreeStore<S> {
315    fn append(&mut self, pos: u64, elems: Vec<Arc<MKTreeNode>>) -> MMRResult<()> {
316        self.storer
317            .append(pos, elems)
318            .map_err(|e| MMRError::StoreError(e.to_string()))
319    }
320}
321
322impl<S: MKTreeStorer> MKTreeLeafIndexer for MKTreeStore<S> {
323    fn set_leaf_position(&self, pos: MKTreeLeafPosition, leaf: Arc<MKTreeNode>) -> StdResult<()> {
324        self.storer.set_leaf_position(pos, leaf)
325    }
326
327    fn get_leaf_position(&self, leaf: &MKTreeNode) -> Option<MKTreeLeafPosition> {
328        self.storer.get_leaf_position(leaf)
329    }
330
331    fn total_leaves(&self) -> usize {
332        self.storer.total_leaves()
333    }
334
335    fn leaves(&self) -> Vec<MKTreeNode> {
336        self.storer.leaves()
337    }
338}
339
340/// The Merkle tree leaves indexer trait
341pub trait MKTreeLeafIndexer {
342    /// Get the position of the leaf in the Merkle tree
343    fn set_leaf_position(&self, pos: MKTreeLeafPosition, leaf: Arc<MKTreeNode>) -> StdResult<()>;
344
345    /// Get the position of the leaf in the Merkle tree
346    fn get_leaf_position(&self, leaf: &MKTreeNode) -> Option<MKTreeLeafPosition>;
347
348    /// Number of leaves in the Merkle tree
349    fn total_leaves(&self) -> usize;
350
351    /// List of leaves with their positions in the Merkle tree
352    fn leaves(&self) -> Vec<MKTreeNode>;
353
354    /// Check if the Merkle tree contains the given leaf
355    fn contains_leaf(&self, leaf: &MKTreeNode) -> bool {
356        self.get_leaf_position(leaf).is_some()
357    }
358}
359
360/// A Merkle tree
361pub struct MKTree<S: MKTreeStorer> {
362    inner_tree: MMR<Arc<MKTreeNode>, MergeMKTreeNode, MKTreeStore<S>>,
363}
364
365impl<S: MKTreeStorer> MKTree<S> {
366    /// MKTree factory
367    pub fn new<T: Into<MKTreeNode> + Clone>(leaves: &[T]) -> StdResult<Self> {
368        let mut inner_tree = MMR::<_, _, _>::new(0, MKTreeStore::<S>::build()?);
369        for leaf in leaves {
370            let leaf = Arc::new(leaf.to_owned().into());
371            let inner_tree_position = inner_tree.push(leaf.clone())?;
372            inner_tree
373                .store()
374                .set_leaf_position(inner_tree_position, leaf.clone())?;
375        }
376        inner_tree.commit()?;
377
378        Ok(Self { inner_tree })
379    }
380
381    /// Append leaves to the Merkle tree
382    pub fn append<T: Into<MKTreeNode> + Clone>(&mut self, leaves: &[T]) -> StdResult<()> {
383        for leaf in leaves {
384            let leaf = Arc::new(leaf.to_owned().into());
385            let inner_tree_position = self.inner_tree.push(leaf.clone())?;
386            self.inner_tree
387                .store()
388                .set_leaf_position(inner_tree_position, leaf.clone())?;
389        }
390        self.inner_tree.commit()?;
391
392        Ok(())
393    }
394
395    /// Number of leaves in the Merkle tree
396    pub fn total_leaves(&self) -> usize {
397        self.inner_tree.store().total_leaves()
398    }
399
400    /// List of leaves with their positions in the Merkle tree
401    pub fn leaves(&self) -> Vec<MKTreeNode> {
402        self.inner_tree.store().leaves()
403    }
404
405    /// Check if the Merkle tree contains the given leaf
406    pub fn contains(&self, leaf: &MKTreeNode) -> bool {
407        self.inner_tree.store().contains_leaf(leaf)
408    }
409
410    /// Generate root of the Merkle tree
411    pub fn compute_root(&self) -> StdResult<MKTreeNode> {
412        Ok((*self
413            .inner_tree
414            .get_root()
415            .with_context(|| "Could not compute Merkle Tree root")?)
416        .clone())
417    }
418
419    /// Generate Merkle proof of memberships in the tree
420    pub fn compute_proof(&self, leaves: &[MKTreeNode]) -> StdResult<MKProof> {
421        let inner_leaves = leaves
422            .iter()
423            .map(|leaf| {
424                if let Some(leaf_position) = self.inner_tree.store().get_leaf_position(leaf) {
425                    Ok((leaf_position, Arc::new(leaf.to_owned())))
426                } else {
427                    Err(anyhow!("Leaf not found in the Merkle tree"))
428                }
429            })
430            .collect::<StdResult<Vec<_>>>()?;
431        let proof = self.inner_tree.gen_proof(
432            inner_leaves
433                .iter()
434                .map(|(leaf_position, _leaf)| *leaf_position)
435                .collect(),
436        )?;
437        Ok(MKProof {
438            inner_root: Arc::new(self.compute_root()?),
439            inner_leaves,
440            inner_proof_size: proof.mmr_size(),
441            inner_proof_items: proof.proof_items().to_vec(),
442        })
443    }
444}
445
446impl<S: MKTreeStorer> Clone for MKTree<S> {
447    fn clone(&self) -> Self {
448        // Cloning should never fail so unwrap is safe
449        Self::new(&self.leaves()).unwrap()
450    }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    fn generate_leaves(total_leaves: usize) -> Vec<MKTreeNode> {
        (0..total_leaves)
            .map(|i| format!("test-{i}").into())
            .collect()
    }

    #[test]
    fn test_golden_merkle_root() {
        let leaves = ["golden-1", "golden-2", "golden-3", "golden-4", "golden-5"];
        let mktree =
            MKTree::<MKTreeStoreInMemory>::new(&leaves).expect("MKTree creation should not fail");
        let mkroot = mktree
            .compute_root()
            .expect("MKRoot generation should not fail");

        assert_eq!(
            "3bbced153528697ecde7345a22e50115306478353619411523e804f2323fd921",
            mkroot.to_hex()
        );
    }

    #[test]
    fn test_should_accept_valid_proof_generated_by_merkle_tree() {
        let leaves = generate_leaves(10);
        let leaves_to_verify = &[leaves[0].to_owned(), leaves[3].to_owned()];
        let proof =
            MKProof::from_leaves(leaves_to_verify).expect("MKProof generation should not fail");
        proof.verify().expect("The MKProof should be valid");
    }

    #[test]
    fn test_should_accept_valid_proof_for_subset_of_leaves() {
        let leaves = generate_leaves(10);
        let leaves_to_verify = &[leaves[0].to_owned(), leaves[3].to_owned()];
        let proof = MKProof::from_subset_of_leaves(&leaves, leaves_to_verify)
            .expect("MKProof generation should not fail");
        proof.verify().expect("The MKProof should be valid");
        assert_eq!(proof.leaves(), leaves_to_verify.to_vec());
    }

    #[test]
    fn test_should_reject_invalid_proof_generated_by_merkle_tree() {
        let leaves = generate_leaves(10);
        let leaves_to_verify = &[leaves[0].to_owned(), leaves[3].to_owned()];
        let mut proof =
            MKProof::from_leaves(leaves_to_verify).expect("MKProof generation should not fail");
        proof.inner_root = Arc::new(leaves[1].to_owned());
        proof.verify().expect_err("The MKProof should be invalid");
    }

    #[test]
    fn test_should_reject_tampered_proof_for_subset_of_leaves() {
        let leaves = generate_leaves(10);
        let leaves_to_verify = &[leaves[0].to_owned(), leaves[3].to_owned()];
        let mut proof = MKProof::from_subset_of_leaves(&leaves, leaves_to_verify)
            .expect("MKProof generation should not fail");
        proof.inner_root = Arc::new(leaves[1].to_owned());
        proof.verify().expect_err("The MKProof should be invalid");
    }

    #[test]
    fn test_should_fail_to_compute_proof_for_unknown_leaf() {
        let leaves = generate_leaves(10);
        let mktree =
            MKTree::<MKTreeStoreInMemory>::new(&leaves).expect("MKTree creation should not fail");
        mktree
            .compute_proof(&[MKTreeNode::from("not-a-leaf")])
            .expect_err("A proof for a leaf not in the tree should not be computed");
    }

    #[test]
    fn test_should_list_leaves() {
        let leaves: Vec<MKTreeNode> = vec!["test-0".into(), "test-1".into(), "test-2".into()];
        let mktree =
            MKTree::<MKTreeStoreInMemory>::new(&leaves).expect("MKTree creation should not fail");

        assert_eq!(leaves, mktree.leaves());
    }

    #[test]
    fn test_should_clone_and_compute_same_root() {
        let leaves = generate_leaves(10);
        let mktree =
            MKTree::<MKTreeStoreInMemory>::new(&leaves).expect("MKTree creation should not fail");
        let mktree_clone = mktree.clone();

        assert_eq!(
            mktree.compute_root().unwrap(),
            mktree_clone.compute_root().unwrap(),
        );
    }

    #[test]
    fn test_should_support_append_leaves() {
        let leaves = generate_leaves(10);
        let leaves_creation = &leaves[..9];
        let leaves_to_append = &leaves[9..];
        let mut mktree = MKTree::<MKTreeStoreInMemory>::new(leaves_creation)
            .expect("MKTree creation should not fail");
        mktree
            .append(leaves_to_append)
            .expect("MKTree append leaves should not fail");
        let mktree_all_at_once =
            MKTree::<MKTreeStoreInMemory>::new(&leaves).expect("MKTree creation should not fail");

        assert_eq!(10, mktree.total_leaves());
        assert_eq!(
            mktree_all_at_once.compute_root().unwrap(),
            mktree.compute_root().unwrap()
        );
    }

    #[test]
    fn tree_node_from_to_string() {
        let expected_str = "my_string";
        let expected_string = expected_str.to_string();
        let node_str: MKTreeNode = expected_str.into();
        let node_string: MKTreeNode = expected_string.clone().into();
        let node_string_ref: MKTreeNode = (&expected_string).into();

        assert_eq!(node_str.to_string(), expected_str);
        assert_eq!(node_string.to_string(), expected_string);
        assert_eq!(node_str, node_string);
        assert_eq!(node_str, node_string_ref);
    }

    #[test]
    fn tree_node_hex_round_trip() {
        let node = MKTreeNode::new(vec![0x00, 0xab, 0xff]);

        assert_eq!("00abff", node.to_hex());
        assert_eq!(node, MKTreeNode::from_hex(&node.to_hex()).unwrap());
        MKTreeNode::from_hex("not-hex").expect_err("Invalid hex should be rejected");
    }

    #[test]
    fn merge_is_consistent_with_add() {
        let left: MKTreeNode = "left".into();
        let right: MKTreeNode = "right".into();
        let merged = MergeMKTreeNode::merge(&Arc::new(left.clone()), &Arc::new(right.clone()))
            .expect("Merge should not fail");

        assert_eq!(left + right, *merged);
    }

    #[test]
    fn contains_leaves() {
        let mut leaves_to_verify = generate_leaves(10);
        let leaves_not_verified = leaves_to_verify.drain(3..6).collect::<Vec<_>>();
        let proof =
            MKProof::from_leaves(&leaves_to_verify).expect("MKProof generation should not fail");

        // contains everything
        proof.contains(&leaves_to_verify).unwrap();

        // contains a subpart
        proof.contains(&leaves_to_verify[0..2]).unwrap();

        // doesn't contain any of the non-verified leaves
        proof.contains(&leaves_not_verified).unwrap_err();

        // doesn't contain a subpart of the non-verified leaves
        proof.contains(&leaves_not_verified[1..2]).unwrap_err();

        // fails if some leaves are verified and some are not
        proof
            .contains(&[
                leaves_to_verify[2].to_owned(),
                leaves_not_verified[0].to_owned(),
            ])
            .unwrap_err();
    }

    #[test]
    fn list_leaves() {
        let leaves_to_verify = generate_leaves(10);
        let proof =
            MKProof::from_leaves(&leaves_to_verify).expect("MKProof generation should not fail");

        let proof_leaves = proof.leaves();
        assert_eq!(proof_leaves, leaves_to_verify);
    }
}