mithril_common/crypto_helper/
merkle_tree.rs

1use anyhow::{Context, anyhow};
2use blake2::{Blake2s256, Digest};
3use ckb_merkle_mountain_range::{
4    Error as MMRError, MMR, MMRStoreReadOps, MMRStoreWriteOps, Merge, MerkleProof,
5    Result as MMRResult,
6};
7use serde::{Deserialize, Serialize};
8use std::{
9    collections::{BTreeMap, HashMap},
10    fmt::Display,
11    ops::{Add, Deref},
12    sync::{Arc, RwLock},
13};
14
15use crate::{StdError, StdResult};
16
/// Alias for a byte vector (an owned sequence of bytes)
pub type Bytes = Vec<u8>;

/// Alias for the position of a leaf in a Merkle tree
pub type MKTreeLeafPosition = u64;
22
23/// A node of a Merkle tree
24#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash, Serialize, Deserialize)]
25pub struct MKTreeNode {
26    hash: Bytes,
27}
28
29impl MKTreeNode {
30    /// MKTreeNode factory
31    pub fn new(hash: Bytes) -> Self {
32        Self { hash }
33    }
34
35    /// Create a MKTreeNode from a hex representation
36    pub fn from_hex(hex: &str) -> StdResult<Self> {
37        let hash = hex::decode(hex)?;
38        Ok(Self { hash })
39    }
40
41    /// Create a hex representation of the MKTreeNode
42    pub fn to_hex(&self) -> String {
43        hex::encode(&self.hash)
44    }
45}
46
47impl Deref for MKTreeNode {
48    type Target = Bytes;
49
50    fn deref(&self) -> &Self::Target {
51        &self.hash
52    }
53}
54
55impl From<String> for MKTreeNode {
56    fn from(other: String) -> Self {
57        Self {
58            hash: other.as_str().into(),
59        }
60    }
61}
62
63impl From<&String> for MKTreeNode {
64    fn from(other: &String) -> Self {
65        Self {
66            hash: other.as_str().into(),
67        }
68    }
69}
70
71impl From<&str> for MKTreeNode {
72    fn from(other: &str) -> Self {
73        Self {
74            hash: other.as_bytes().to_vec(),
75        }
76    }
77}
78
79impl<S: MKTreeStorer> TryFrom<MKTree<S>> for MKTreeNode {
80    type Error = StdError;
81    fn try_from(other: MKTree<S>) -> Result<Self, Self::Error> {
82        other.compute_root()
83    }
84}
85
86impl<S: MKTreeStorer> TryFrom<&MKTree<S>> for MKTreeNode {
87    type Error = StdError;
88    fn try_from(other: &MKTree<S>) -> Result<Self, Self::Error> {
89        other.compute_root()
90    }
91}
92
93impl Display for MKTreeNode {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        write!(f, "{}", String::from_utf8_lossy(&self.hash))
96    }
97}
98
99impl Add for MKTreeNode {
100    type Output = MKTreeNode;
101
102    fn add(self, other: MKTreeNode) -> MKTreeNode {
103        &self + &other
104    }
105}
106
107impl Add for &MKTreeNode {
108    type Output = MKTreeNode;
109
110    fn add(self, other: &MKTreeNode) -> MKTreeNode {
111        let mut hasher = Blake2s256::new();
112        hasher.update(self.deref());
113        hasher.update(other.deref());
114        let hash_merge = hasher.finalize();
115        MKTreeNode::new(hash_merge.to_vec())
116    }
117}
118
119struct MergeMKTreeNode {}
120
121impl Merge for MergeMKTreeNode {
122    type Item = Arc<MKTreeNode>;
123
124    fn merge(lhs: &Self::Item, rhs: &Self::Item) -> MMRResult<Self::Item> {
125        Ok(Arc::new((**lhs).clone() + (**rhs).clone()))
126    }
127}
128
129/// A Merkle proof
130#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
131pub struct MKProof {
132    inner_root: Arc<MKTreeNode>,
133    inner_leaves: Vec<(MKTreeLeafPosition, Arc<MKTreeNode>)>,
134    inner_proof_size: u64,
135    inner_proof_items: Vec<Arc<MKTreeNode>>,
136}
137
138impl MKProof {
139    /// Return a reference to its merkle root.
140    pub fn root(&self) -> &MKTreeNode {
141        &self.inner_root
142    }
143
144    /// Verification of a Merkle proof
145    pub fn verify(&self) -> StdResult<()> {
146        MerkleProof::<Arc<MKTreeNode>, MergeMKTreeNode>::new(
147            self.inner_proof_size,
148            self.inner_proof_items.clone(),
149        )
150        .verify(self.inner_root.to_owned(), self.inner_leaves.to_owned())?
151        .then_some(())
152        .ok_or(anyhow!("Invalid MKProof"))
153    }
154
155    /// Check if the proof contains the given leaves
156    pub fn contains(&self, leaves: &[MKTreeNode]) -> StdResult<()> {
157        leaves
158            .iter()
159            .all(|leaf| self.inner_leaves.iter().any(|(_, l)| l.deref() == leaf))
160            .then_some(())
161            .ok_or(anyhow!("Leaves not found in the MKProof"))
162    }
163
164    /// List the leaves of the proof
165    pub fn leaves(&self) -> Vec<MKTreeNode> {
166        self.inner_leaves
167            .iter()
168            .map(|(_, l)| (**l).clone())
169            .collect::<Vec<_>>()
170    }
171
172    cfg_test_tools! {
173        /// Build a [MKProof] based on the given leaves (*Test only*).
174        pub fn from_leaves<T: Into<MKTreeNode> + Clone>(
175            leaves: &[T],
176        ) -> StdResult<MKProof> {
177            Self::from_subset_of_leaves(leaves, leaves)
178        }
179
180        /// Build a [MKProof] based on the given leaves (*Test only*).
181        pub fn from_subset_of_leaves<T: Into<MKTreeNode> + Clone>(
182            leaves: &[T],
183            leaves_to_verify: &[T],
184        ) -> StdResult<MKProof> {
185            let leaves = Self::list_to_mknode(leaves);
186            let leaves_to_verify =
187                Self::list_to_mknode(leaves_to_verify);
188
189            let mktree =
190                MKTree::<MKTreeStoreInMemory>::new(&leaves).with_context(|| "MKTree creation should not fail")?;
191            mktree.compute_proof(&leaves_to_verify)
192        }
193
194        fn list_to_mknode<T: Into<MKTreeNode> + Clone>(hashes: &[T]) -> Vec<MKTreeNode> {
195            hashes.iter().map(|h| h.clone().into()).collect()
196        }
197    }
198
199    /// Convert the proof to bytes
200    pub fn to_bytes(&self) -> StdResult<Bytes> {
201        bincode::serde::encode_to_vec(self, bincode::config::standard()).map_err(|e| e.into())
202    }
203
204    /// Convert the proof from bytes
205    pub fn from_bytes(bytes: &[u8]) -> StdResult<Self> {
206        let (res, _) =
207            bincode::serde::decode_from_slice::<Self, _>(bytes, bincode::config::standard())?;
208
209        Ok(res)
210    }
211}
212
213impl From<MKProof> for MKTreeNode {
214    fn from(other: MKProof) -> Self {
215        other.root().to_owned()
216    }
217}
218
219/// A Merkle tree store in memory
220#[derive(Clone)]
221pub struct MKTreeStoreInMemory {
222    inner_leaves: Arc<RwLock<HashMap<Arc<MKTreeNode>, MKTreeLeafPosition>>>,
223    inner_store: Arc<RwLock<HashMap<u64, Arc<MKTreeNode>>>>,
224}
225
226impl MKTreeStoreInMemory {
227    fn new() -> Self {
228        Self {
229            inner_leaves: Arc::new(RwLock::new(HashMap::new())),
230            inner_store: Arc::new(RwLock::new(HashMap::new())),
231        }
232    }
233}
234
235impl MKTreeLeafIndexer for MKTreeStoreInMemory {
236    fn set_leaf_position(&self, pos: MKTreeLeafPosition, node: Arc<MKTreeNode>) -> StdResult<()> {
237        let mut inner_leaves = self.inner_leaves.write().unwrap();
238        (*inner_leaves).insert(node, pos);
239
240        Ok(())
241    }
242
243    fn get_leaf_position(&self, node: &MKTreeNode) -> Option<MKTreeLeafPosition> {
244        let inner_leaves = self.inner_leaves.read().unwrap();
245        (*inner_leaves).get(node).cloned()
246    }
247
248    fn total_leaves(&self) -> usize {
249        let inner_leaves = self.inner_leaves.read().unwrap();
250        (*inner_leaves).len()
251    }
252
253    fn leaves(&self) -> Vec<MKTreeNode> {
254        let inner_leaves = self.inner_leaves.read().unwrap();
255        (*inner_leaves)
256            .iter()
257            .map(|(leaf, position)| (position, leaf))
258            .collect::<BTreeMap<_, _>>()
259            .into_values()
260            .map(|leaf| (**leaf).clone())
261            .collect()
262    }
263}
264
265impl MKTreeStorer for MKTreeStoreInMemory {
266    fn build() -> StdResult<Self> {
267        Ok(Self::new())
268    }
269
270    fn get_elem(&self, pos: u64) -> StdResult<Option<Arc<MKTreeNode>>> {
271        let inner_store = self.inner_store.read().unwrap();
272
273        Ok((*inner_store).get(&pos).cloned())
274    }
275
276    fn append(&self, pos: u64, elems: Vec<Arc<MKTreeNode>>) -> StdResult<()> {
277        let mut inner_store = self.inner_store.write().unwrap();
278        for (i, elem) in elems.into_iter().enumerate() {
279            (*inner_store).insert(pos + i as u64, elem);
280        }
281
282        Ok(())
283    }
284}
285
286/// The Merkle tree storer trait
287pub trait MKTreeStorer: Clone + Send + Sync + MKTreeLeafIndexer {
288    /// Try to create a new instance of the storer
289    fn build() -> StdResult<Self>;
290
291    /// Get the element at the given position
292    fn get_elem(&self, pos: u64) -> StdResult<Option<Arc<MKTreeNode>>>;
293
294    /// Append elements at the given position
295    fn append(&self, pos: u64, elems: Vec<Arc<MKTreeNode>>) -> StdResult<()>;
296}
297
298/// This struct exists only to implement for a [MkTreeStore] the [MMRStoreReadOps] and
299/// [MMRStoreWriteOps] from merkle_mountain_range crate without the need to reexport types
300/// from that crate.
301///
302/// Rust don't allow the following:
303/// ```ignore
304/// impl<S: MKTreeStorer> MMRStoreReadOps<Arc<MKTreeNode>> for S {}
305/// ```
306/// Since it disallows implementations of traits for arbitrary types which are not defined in
307/// the same crate as the trait itself (see [E0117](https://doc.rust-lang.org/error_codes/E0117.html)).
308struct MKTreeStore<S: MKTreeStorer> {
309    storer: Box<S>,
310}
311
312impl<S: MKTreeStorer> MKTreeStore<S> {
313    fn build() -> StdResult<Self> {
314        let storer = Box::new(S::build()?);
315        Ok(Self { storer })
316    }
317}
318
319impl<S: MKTreeStorer> MMRStoreReadOps<Arc<MKTreeNode>> for MKTreeStore<S> {
320    fn get_elem(&self, pos: u64) -> MMRResult<Option<Arc<MKTreeNode>>> {
321        self.storer
322            .get_elem(pos)
323            .map_err(|e| MMRError::StoreError(e.to_string()))
324    }
325}
326
327impl<S: MKTreeStorer> MMRStoreWriteOps<Arc<MKTreeNode>> for MKTreeStore<S> {
328    fn append(&mut self, pos: u64, elems: Vec<Arc<MKTreeNode>>) -> MMRResult<()> {
329        self.storer
330            .append(pos, elems)
331            .map_err(|e| MMRError::StoreError(e.to_string()))
332    }
333}
334
335impl<S: MKTreeStorer> MKTreeLeafIndexer for MKTreeStore<S> {
336    fn set_leaf_position(&self, pos: MKTreeLeafPosition, leaf: Arc<MKTreeNode>) -> StdResult<()> {
337        self.storer.set_leaf_position(pos, leaf)
338    }
339
340    fn get_leaf_position(&self, leaf: &MKTreeNode) -> Option<MKTreeLeafPosition> {
341        self.storer.get_leaf_position(leaf)
342    }
343
344    fn total_leaves(&self) -> usize {
345        self.storer.total_leaves()
346    }
347
348    fn leaves(&self) -> Vec<MKTreeNode> {
349        self.storer.leaves()
350    }
351}
352
353/// The Merkle tree leaves indexer trait
354pub trait MKTreeLeafIndexer {
355    /// Get the position of the leaf in the Merkle tree
356    fn set_leaf_position(&self, pos: MKTreeLeafPosition, leaf: Arc<MKTreeNode>) -> StdResult<()>;
357
358    /// Get the position of the leaf in the Merkle tree
359    fn get_leaf_position(&self, leaf: &MKTreeNode) -> Option<MKTreeLeafPosition>;
360
361    /// Number of leaves in the Merkle tree
362    fn total_leaves(&self) -> usize;
363
364    /// List of leaves with their positions in the Merkle tree
365    fn leaves(&self) -> Vec<MKTreeNode>;
366
367    /// Check if the Merkle tree contains the given leaf
368    fn contains_leaf(&self, leaf: &MKTreeNode) -> bool {
369        self.get_leaf_position(leaf).is_some()
370    }
371}
372
373/// A Merkle tree
374pub struct MKTree<S: MKTreeStorer> {
375    inner_tree: MMR<Arc<MKTreeNode>, MergeMKTreeNode, MKTreeStore<S>>,
376}
377
378impl<S: MKTreeStorer> MKTree<S> {
379    /// MKTree factory
380    pub fn new<T: Into<MKTreeNode> + Clone>(leaves: &[T]) -> StdResult<Self> {
381        let mut inner_tree = MMR::<_, _, _>::new(0, MKTreeStore::<S>::build()?);
382        for leaf in leaves {
383            let leaf = Arc::new(leaf.to_owned().into());
384            let inner_tree_position = inner_tree.push(leaf.clone())?;
385            inner_tree
386                .store()
387                .set_leaf_position(inner_tree_position, leaf.clone())?;
388        }
389        inner_tree.commit()?;
390
391        Ok(Self { inner_tree })
392    }
393
394    /// Append leaves to the Merkle tree
395    pub fn append<T: Into<MKTreeNode> + Clone>(&mut self, leaves: &[T]) -> StdResult<()> {
396        for leaf in leaves {
397            let leaf = Arc::new(leaf.to_owned().into());
398            let inner_tree_position = self.inner_tree.push(leaf.clone())?;
399            self.inner_tree
400                .store()
401                .set_leaf_position(inner_tree_position, leaf.clone())?;
402        }
403        self.inner_tree.commit()?;
404
405        Ok(())
406    }
407
408    /// Number of leaves in the Merkle tree
409    pub fn total_leaves(&self) -> usize {
410        self.inner_tree.store().total_leaves()
411    }
412
413    /// List of leaves with their positions in the Merkle tree
414    pub fn leaves(&self) -> Vec<MKTreeNode> {
415        self.inner_tree.store().leaves()
416    }
417
418    /// Check if the Merkle tree contains the given leaf
419    pub fn contains(&self, leaf: &MKTreeNode) -> bool {
420        self.inner_tree.store().contains_leaf(leaf)
421    }
422
423    /// Generate root of the Merkle tree
424    pub fn compute_root(&self) -> StdResult<MKTreeNode> {
425        Ok((*self
426            .inner_tree
427            .get_root()
428            .with_context(|| "Could not compute Merkle Tree root")?)
429        .clone())
430    }
431
432    /// Generate Merkle proof of memberships in the tree
433    pub fn compute_proof(&self, leaves: &[MKTreeNode]) -> StdResult<MKProof> {
434        let inner_leaves = leaves
435            .iter()
436            .map(|leaf| {
437                if let Some(leaf_position) = self.inner_tree.store().get_leaf_position(leaf) {
438                    Ok((leaf_position, Arc::new(leaf.to_owned())))
439                } else {
440                    Err(anyhow!("Leaf not found in the Merkle tree"))
441                }
442            })
443            .collect::<StdResult<Vec<_>>>()?;
444        let proof = self.inner_tree.gen_proof(
445            inner_leaves
446                .iter()
447                .map(|(leaf_position, _leaf)| *leaf_position)
448                .collect(),
449        )?;
450        Ok(MKProof {
451            inner_root: Arc::new(self.compute_root()?),
452            inner_leaves,
453            inner_proof_size: proof.mmr_size(),
454            inner_proof_items: proof.proof_items().to_vec(),
455        })
456    }
457}
458
459impl<S: MKTreeStorer> Clone for MKTree<S> {
460    fn clone(&self) -> Self {
461        // Cloning should never fail so unwrap is safe
462        Self::new(&self.leaves()).unwrap()
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    fn generate_leaves(total_leaves: usize) -> Vec<MKTreeNode> {
        (0..total_leaves).map(|i| format!("test-{i}").into()).collect()
    }

    #[test]
    fn test_golden_merkle_root() {
        let leaves = vec!["golden-1", "golden-2", "golden-3", "golden-4", "golden-5"];
        let mktree =
            MKTree::<MKTreeStoreInMemory>::new(&leaves).expect("MKTree creation should not fail");
        let mkroot = mktree.compute_root().expect("MKRoot generation should not fail");

        assert_eq!(
            "3bbced153528697ecde7345a22e50115306478353619411523e804f2323fd921",
            mkroot.to_hex()
        );
    }

    #[test]
    fn test_should_accept_valid_proof_generated_by_merkle_tree() {
        let leaves = generate_leaves(10);
        let leaves_to_verify = &[leaves[0].to_owned(), leaves[3].to_owned()];
        let proof =
            MKProof::from_leaves(leaves_to_verify).expect("MKProof generation should not fail");
        proof.verify().expect("The MKProof should be valid");
    }

    #[test]
    fn test_should_serialize_deserialize_proof() {
        let leaves = generate_leaves(10);
        let leaves_to_verify = &[leaves[0].to_owned(), leaves[3].to_owned()];
        let proof =
            MKProof::from_leaves(leaves_to_verify).expect("MKProof generation should not fail");

        let serialized_proof = proof.to_bytes().expect("Serialization should not fail");
        let deserialized_proof =
            MKProof::from_bytes(&serialized_proof).expect("Deserialization should not fail");
        assert_eq!(
            proof, deserialized_proof,
            "Deserialized proof should match the original"
        );
    }

    #[test]
    fn test_should_reject_invalid_proof_generated_by_merkle_tree() {
        let leaves = generate_leaves(10);
        let leaves_to_verify = &[leaves[0].to_owned(), leaves[3].to_owned()];
        let mut proof =
            MKProof::from_leaves(leaves_to_verify).expect("MKProof generation should not fail");
        proof.inner_root = Arc::new(leaves[1].to_owned());
        proof.verify().expect_err("The MKProof should be invalid");
    }

    #[test]
    fn test_should_list_leaves() {
        let leaves: Vec<MKTreeNode> = vec!["test-0".into(), "test-1".into(), "test-2".into()];
        let mktree =
            MKTree::<MKTreeStoreInMemory>::new(&leaves).expect("MKTree creation should not fail");
        let leaves_retrieved = mktree.leaves();

        assert_eq!(
            leaves.iter().collect::<Vec<_>>(),
            leaves_retrieved.iter().collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_should_clone_and_compute_same_root() {
        let leaves = generate_leaves(10);
        let mktree =
            MKTree::<MKTreeStoreInMemory>::new(&leaves).expect("MKTree creation should not fail");
        let mktree_clone = mktree.clone();

        assert_eq!(
            mktree.compute_root().unwrap(),
            mktree_clone.compute_root().unwrap(),
        );
    }

    #[test]
    fn test_should_support_append_leaves() {
        let leaves = generate_leaves(10);
        let leaves_creation = &leaves[..9];
        let leaves_to_append = &leaves[9..];
        let mut mktree = MKTree::<MKTreeStoreInMemory>::new(leaves_creation)
            .expect("MKTree creation should not fail");
        mktree
            .append(leaves_to_append)
            .expect("MKTree append leaves should not fail");

        assert_eq!(10, mktree.total_leaves());
    }

    #[test]
    fn test_should_compute_same_root_when_appending_or_creating_at_once() {
        let leaves = generate_leaves(10);
        let mktree_at_once =
            MKTree::<MKTreeStoreInMemory>::new(&leaves).expect("MKTree creation should not fail");
        let mut mktree_appended = MKTree::<MKTreeStoreInMemory>::new(&leaves[..4])
            .expect("MKTree creation should not fail");
        mktree_appended
            .append(&leaves[4..])
            .expect("MKTree append leaves should not fail");

        assert_eq!(
            mktree_at_once.compute_root().unwrap(),
            mktree_appended.compute_root().unwrap(),
        );

        let proof = mktree_appended
            .compute_proof(&[leaves[2].to_owned(), leaves[7].to_owned()])
            .expect("MKProof generation should not fail");
        proof.verify().expect("The MKProof should be valid");
    }

    #[test]
    fn test_should_fail_to_compute_proof_for_leaf_not_in_tree() {
        let leaves = generate_leaves(10);
        let mktree =
            MKTree::<MKTreeStoreInMemory>::new(&leaves).expect("MKTree creation should not fail");
        let unknown_leaf: MKTreeNode = "not-a-leaf".into();

        mktree
            .compute_proof(&[leaves[0].to_owned(), unknown_leaf])
            .expect_err("Computing a proof for an unknown leaf should fail");
    }

    #[test]
    fn test_proof_converts_into_its_root() {
        let leaves = generate_leaves(5);
        let mktree =
            MKTree::<MKTreeStoreInMemory>::new(&leaves).expect("MKTree creation should not fail");
        let proof = mktree
            .compute_proof(&leaves[1..3])
            .expect("MKProof generation should not fail");
        let expected_root = mktree.compute_root().unwrap();

        assert_eq!(&expected_root, proof.root());
        assert_eq!(expected_root, MKTreeNode::from(proof));
    }

    #[test]
    fn tree_contains_only_its_leaves() {
        let leaves = generate_leaves(3);
        let mktree =
            MKTree::<MKTreeStoreInMemory>::new(&leaves).expect("MKTree creation should not fail");

        assert!(mktree.contains(&leaves[1]));
        assert!(!mktree.contains(&"unknown".into()));
    }

    #[test]
    fn tree_node_from_to_string() {
        let expected_str = "my_string";
        let expected_string = expected_str.to_string();
        let node_str: MKTreeNode = expected_str.into();
        let node_string: MKTreeNode = expected_string.clone().into();

        assert_eq!(node_str.to_string(), expected_str);
        assert_eq!(node_string.to_string(), expected_string);
    }

    #[test]
    fn tree_node_from_hex_to_hex_roundtrip() {
        let node = MKTreeNode::new(vec![0x00, 0xab, 0xff]);

        assert_eq!("00abff", node.to_hex());
        assert_eq!(node, MKTreeNode::from_hex("00abff").unwrap());
        MKTreeNode::from_hex("not-hex").expect_err("Invalid hex should be rejected");
    }

    #[test]
    fn contains_leaves() {
        let mut leaves_to_verify = generate_leaves(10);
        let leaves_not_verified = leaves_to_verify.drain(3..6).collect::<Vec<_>>();
        let proof =
            MKProof::from_leaves(&leaves_to_verify).expect("MKProof generation should not fail");

        // contains everything
        proof.contains(&leaves_to_verify).unwrap();

        // contains subpart
        proof.contains(&leaves_to_verify[0..2]).unwrap();

        // don't contains all not verified
        proof.contains(&leaves_not_verified).unwrap_err();

        // don't contains subpart of not verified
        proof.contains(&leaves_not_verified[1..2]).unwrap_err();

        // fail if part verified and part unverified
        proof
            .contains(&[leaves_to_verify[2].to_owned(), leaves_not_verified[0].to_owned()])
            .unwrap_err();
    }

    #[test]
    fn list_leaves() {
        let leaves_to_verify = generate_leaves(10);
        let proof =
            MKProof::from_leaves(&leaves_to_verify).expect("MKProof generation should not fail");

        let proof_leaves = proof.leaves();
        assert_eq!(proof_leaves, leaves_to_verify);
    }
}