mithril_stm/circuits/halo2/off_circuit/
merkle_tree.rs

1use ff::Field;
2
3use crate::circuits::halo2::hash::{HashCPU, PoseidonHash};
4use crate::circuits::halo2::off_circuit::error::MerkleTreeError;
5use crate::circuits::halo2::off_circuit::unique_signature::VerificationKey;
6use crate::circuits::halo2::types::{JubjubBase, Target};
7
8type F = JubjubBase;
9
10#[derive(Debug, Copy, Clone)]
11pub struct MTLeaf(pub VerificationKey, pub Target);
12
13impl MTLeaf {
14    pub fn to_field(&self) -> [F; 3] {
15        let mut elements = [F::ZERO; 3];
16        elements[0..2].copy_from_slice(&self.0.to_field());
17        elements[2] = self.1;
18        elements
19    }
20
21    pub fn to_bytes(&self) -> [u8; 96] {
22        let mut bytes = [0u8; 96];
23        bytes[0..64].copy_from_slice(&self.0.to_bytes());
24        bytes[64..96].copy_from_slice(&self.1.to_bytes_le());
25        bytes
26    }
27
28    pub fn from_bytes(bytes: &[u8]) -> Result<Self, MerkleTreeError> {
29        let bytes = bytes.get(0..96).ok_or(MerkleTreeError::SerializationError)?;
30        let target_bytes: [u8; 32] = bytes[64..96]
31            .try_into()
32            .map_err(|_| MerkleTreeError::SerializationError)?;
33        let vk = VerificationKey::from_bytes(&bytes[0..64])
34            .map_err(|_| MerkleTreeError::SerializationError)?;
35        let target = Target::from_bytes_le(&target_bytes)
36            .into_option()
37            .ok_or(MerkleTreeError::SerializationError)?;
38
39        Ok(Self(vk, target))
40    }
41}
42
43#[derive(Clone, Copy, Debug)]
44// The position of the sibling in the tree.
45pub enum Position {
46    Left,
47    Right,
48}
49
50impl From<Position> for F {
51    fn from(value: Position) -> Self {
52        match value {
53            Position::Left => F::ZERO,
54            Position::Right => F::ONE,
55        }
56    }
57}
58
59#[derive(Clone, Debug)]
60// Struct defining the witness of the MT proof.
61pub struct MerklePath {
62    // Sibling nodes corresponding to a field value F representing some
63    // hash and whether the position is left or right.
64    // if position == Position::Left, then sibling is on the left
65    // if position == Position::Right, then sibling is on the right
66    pub siblings: Vec<(Position, F)>,
67}
68
69impl MerklePath {
70    pub fn new(siblings: Vec<(Position, F)>) -> Self {
71        Self { siblings }
72    }
73
74    pub fn get_siblings(&self) -> &[(Position, F)] {
75        &self.siblings
76    }
77
78    // Function to compute (off circuit) the Merkle tree root given the leaf and the
79    // sibling nodes.
80    pub fn compute_root(&self, leaf: MTLeaf) -> F {
81        let digest = PoseidonHash::hash(&leaf.to_field());
82
83        // Compute the Merkle root.
84        self.siblings.iter().fold(digest, |acc, x| match x.0 {
85            // if sibling is on the left => hash(sibling, node)
86            Position::Left => PoseidonHash::hash(&[x.1, acc]),
87            // if sibling is on the right => hash(node, sibling)
88            Position::Right => PoseidonHash::hash(&[acc, x.1]),
89        })
90    }
91}
92
93#[derive(Debug, Clone, PartialEq, Eq)]
94pub struct MerkleTree {
95    /// The nodes are stored in an array heap:
96    /// * `nodes[0]` is the root,
97    /// * the parent of `nodes[i]` is `nodes[(i-1)/2]`
98    /// * the children of `nodes[i]` are `{nodes[2i + 1], nodes[2i + 2]}`
99    /// * All nodes have size `Output<D>::output_size()`, even leafs (which are hashed before committing them).
100    nodes: Vec<F>,
101    /// The leaves begin at `nodes[leaf_off]`.
102    leaf_off: usize,
103    /// Number of leaves cached in the merkle tree.
104    n: usize,
105}
106
/// Returns the heap index of the parent of node `i`.
///
/// # Panics
/// Panics if `i == 0`, because the root has no parent.
fn parent(i: usize) -> usize {
    assert!(i > 0, "The root node does not have a parent");
    let above = i - 1;
    above >> 1
}
111
/// Returns the heap index of the left child of node `i`.
fn left_child(i: usize) -> usize {
    let doubled = i * 2;
    doubled + 1
}
115
/// Returns the heap index of the right child of node `i`.
fn right_child(i: usize) -> usize {
    // 2i + 2, written as 2 * (i + 1).
    2 * (i + 1)
}
119
/// Returns the heap index of the sibling of node `i`.
///
/// In the heap layout a left child sits at an odd index and its right sibling
/// at the following even index. The heap is assumed to be complete.
///
/// # Panics
/// Panics if `i == 0`, because the root has no sibling.
fn sibling(i: usize) -> usize {
    assert!(i > 0, "The root node does not have a sibling");
    let is_right_child = i & 1 == 0;
    if is_right_child {
        i - 1
    } else {
        i + 1
    }
}
127
128impl MerkleTree {
129    /// Provided a non-empty list of leaves, `create` generates its corresponding `MerkleTree`.
130    pub fn create(leaves: &[MTLeaf]) -> MerkleTree {
131        let n = leaves.len();
132        assert!(n > 0, "MerkleTree::create() called with no leaves");
133
134        let num_nodes = n + n.next_power_of_two() - 1;
135        let mut nodes = vec![F::ZERO; num_nodes];
136
137        for i in 0..leaves.len() {
138            nodes[num_nodes - n + i] = PoseidonHash::hash(&leaves[i].to_field());
139        }
140
141        let z = PoseidonHash::hash(&[F::ZERO]);
142        for i in (0..num_nodes - n).rev() {
143            let left = if left_child(i) < num_nodes {
144                nodes[left_child(i)]
145            } else {
146                z
147            };
148            let right = if right_child(i) < num_nodes {
149                nodes[right_child(i)]
150            } else {
151                z
152            };
153            nodes[i] = PoseidonHash::hash(&[left, right]);
154        }
155
156        Self {
157            nodes,
158            n,
159            leaf_off: num_nodes - n,
160        }
161    }
162
163    /// Get the root of the tree.
164    pub fn root(&self) -> F {
165        self.nodes[0]
166    }
167
168    /// Return the index of the leaf.
169    fn idx_of_leaf(&self, i: usize) -> usize {
170        self.leaf_off + i
171    }
172
173    pub fn get_path(&self, i: usize) -> MerklePath {
174        assert!(
175            i < self.n,
176            "Proof index out of bounds: asked for {} out of {}",
177            i,
178            self.n
179        );
180        let mut idx = self.idx_of_leaf(i);
181        let z = PoseidonHash::hash(&[F::ZERO]);
182        let mut proof = Vec::new();
183
184        while idx > 0 {
185            let h = if sibling(idx) < self.nodes.len() {
186                self.nodes[sibling(idx)]
187            } else {
188                z
189            };
190            let pos = {
191                if (idx & 0b1) == 0 {
192                    Position::Left
193                } else {
194                    Position::Right
195                }
196            };
197            proof.push((pos, h));
198            idx = parent(idx);
199        }
200
201        MerklePath::new(proof)
202    }
203    pub fn to_merkle_tree_commitment(&self) -> MerkleTreeCommitment {
204        MerkleTreeCommitment::new(self.nodes[0], self.n as u32)
205    }
206}
207
208#[derive(Debug, Clone)]
209pub struct MerkleTreeCommitment {
210    merkle_root: F,
211    nr_leaves: u32,
212}
213
214impl MerkleTreeCommitment {
215    pub fn new(merkle_root: F, nr_leaves: u32) -> Self {
216        Self {
217            merkle_root,
218            nr_leaves,
219        }
220    }
221}
222
223impl From<MerkleTreeCommitment> for Vec<u8> {
224    fn from(mt_commit: MerkleTreeCommitment) -> Vec<u8> {
225        let mut bytes = Vec::new();
226        bytes.extend_from_slice(&mt_commit.merkle_root.to_bytes_le());
227        bytes.extend_from_slice(&mt_commit.nr_leaves.to_le_bytes());
228        bytes
229    }
230}
231
232impl TryFrom<&[u8]> for MerkleTreeCommitment {
233    type Error = &'static str;
234
235    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
236        if bytes.len() != 36 {
237            return Err("Invalid byte length for MerkleTreeCommitment");
238        }
239
240        let merkle_root = JubjubBase::from_bytes_le(bytes[0..32].try_into().unwrap()).unwrap();
241        let nr_leaves = u32::from_le_bytes(bytes[32..36].try_into().unwrap());
242
243        Ok(MerkleTreeCommitment {
244            merkle_root,
245            nr_leaves,
246        })
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::circuits::halo2::off_circuit::unique_signature::SigningKey;
    use rand_core::OsRng;

    /// Builds a leaf from a freshly generated verification key and the given target value.
    fn create_leaf(value: F) -> MTLeaf {
        let sk = SigningKey::generate(&mut OsRng);
        MTLeaf(VerificationKey::from(&sk), value)
    }

    /// Builds `count` leaves whose targets are `1, 2, ..., count`.
    fn sample_leaves(count: u64) -> Vec<MTLeaf> {
        (1..=count)
            .map(|target| create_leaf(F::from(target)))
            .collect()
    }

    #[test]
    fn test_merkle_tree_creation() {
        let leaves = sample_leaves(4);
        let tree = MerkleTree::create(&leaves);

        assert_eq!(tree.n, leaves.len(), "Number of leaves mismatch");
        assert!(
            tree.root() != F::ZERO,
            "Root should not be zero when tree is valid"
        );
    }

    #[test]
    fn test_merkle_path_generation() {
        let leaves = sample_leaves(4);
        let tree = MerkleTree::create(&leaves);

        for index in 0..leaves.len() {
            let path = tree.get_path(index);
            assert!(
                !path.siblings.is_empty(),
                "Path should not be empty for any leaf"
            );
        }
    }

    #[test]
    fn test_merkle_path_verification() {
        // Six leaves: not a power of two, so padding nodes are exercised.
        let leaves = sample_leaves(6);
        let tree = MerkleTree::create(&leaves);

        for (index, leaf) in leaves.iter().enumerate() {
            let computed_root = tree.get_path(index).compute_root(*leaf);
            assert_eq!(
                tree.root(),
                computed_root,
                "Computed root does not match the actual root"
            );
        }
    }
}