mithril_stm/circuits/halo2/off_circuit/
merkle_tree.rs1use ff::Field;
2
3use crate::circuits::halo2::hash::{HashCPU, PoseidonHash};
4use crate::circuits::halo2::off_circuit::error::MerkleTreeError;
5use crate::circuits::halo2::off_circuit::unique_signature::VerificationKey;
6use crate::circuits::halo2::types::{JubjubBase, Target};
7
8type F = JubjubBase;
9
10#[derive(Debug, Copy, Clone)]
11pub struct MTLeaf(pub VerificationKey, pub Target);
12
13impl MTLeaf {
14 pub fn to_field(&self) -> [F; 3] {
15 let mut elements = [F::ZERO; 3];
16 elements[0..2].copy_from_slice(&self.0.to_field());
17 elements[2] = self.1;
18 elements
19 }
20
21 pub fn to_bytes(&self) -> [u8; 96] {
22 let mut bytes = [0u8; 96];
23 bytes[0..64].copy_from_slice(&self.0.to_bytes());
24 bytes[64..96].copy_from_slice(&self.1.to_bytes_le());
25 bytes
26 }
27
28 pub fn from_bytes(bytes: &[u8]) -> Result<Self, MerkleTreeError> {
29 let bytes = bytes.get(0..96).ok_or(MerkleTreeError::SerializationError)?;
30 let target_bytes: [u8; 32] = bytes[64..96]
31 .try_into()
32 .map_err(|_| MerkleTreeError::SerializationError)?;
33 let vk = VerificationKey::from_bytes(&bytes[0..64])
34 .map_err(|_| MerkleTreeError::SerializationError)?;
35 let target = Target::from_bytes_le(&target_bytes)
36 .into_option()
37 .ok_or(MerkleTreeError::SerializationError)?;
38
39 Ok(Self(vk, target))
40 }
41}
42
/// Side on which a sibling hash is placed when it is combined with the
/// running digest of a Merkle path.
#[derive(Debug, Clone, Copy)]
pub enum Position {
    /// The sibling is hashed first, i.e. it is the left input.
    Left,
    /// The sibling is hashed second, i.e. it is the right input.
    Right,
}
49
50impl From<Position> for F {
51 fn from(value: Position) -> Self {
52 match value {
53 Position::Left => F::ZERO,
54 Position::Right => F::ONE,
55 }
56 }
57}
58
59#[derive(Clone, Debug)]
60pub struct MerklePath {
62 pub siblings: Vec<(Position, F)>,
67}
68
69impl MerklePath {
70 pub fn new(siblings: Vec<(Position, F)>) -> Self {
71 Self { siblings }
72 }
73
74 pub fn get_siblings(&self) -> &[(Position, F)] {
75 &self.siblings
76 }
77
78 pub fn compute_root(&self, leaf: MTLeaf) -> F {
81 let digest = PoseidonHash::hash(&leaf.to_field());
82
83 self.siblings.iter().fold(digest, |acc, x| match x.0 {
85 Position::Left => PoseidonHash::hash(&[x.1, acc]),
87 Position::Right => PoseidonHash::hash(&[acc, x.1]),
89 })
90 }
91}
92
93#[derive(Debug, Clone, PartialEq, Eq)]
94pub struct MerkleTree {
95 nodes: Vec<F>,
101 leaf_off: usize,
103 n: usize,
105}
106
/// Returns the index of the parent of node `i` in the flat tree layout.
///
/// # Panics
///
/// Panics when `i` is 0, because the root has no parent.
fn parent(i: usize) -> usize {
    assert!(i > 0, "The root node does not have a parent");
    let above = i - 1;
    above >> 1
}
111
/// Returns the index of the left child of node `i` in the flat tree layout.
fn left_child(i: usize) -> usize {
    let doubled = 2 * i;
    doubled + 1
}
115
/// Returns the index of the right child of node `i` in the flat tree layout.
fn right_child(i: usize) -> usize {
    let doubled = 2 * i;
    doubled + 2
}
119
/// Returns the index of the node that shares a parent with node `i`.
///
/// Odd indices are left children, so their sibling is the next index; even
/// indices are right children, so their sibling is the previous index.
///
/// # Panics
///
/// Panics when `i` is 0, because the root has no sibling.
fn sibling(i: usize) -> usize {
    assert!(i > 0, "The root node does not have a sibling");
    match i % 2 {
        1 => i + 1,
        _ => i - 1,
    }
}
127
128impl MerkleTree {
129 pub fn create(leaves: &[MTLeaf]) -> MerkleTree {
131 let n = leaves.len();
132 assert!(n > 0, "MerkleTree::create() called with no leaves");
133
134 let num_nodes = n + n.next_power_of_two() - 1;
135 let mut nodes = vec![F::ZERO; num_nodes];
136
137 for i in 0..leaves.len() {
138 nodes[num_nodes - n + i] = PoseidonHash::hash(&leaves[i].to_field());
139 }
140
141 let z = PoseidonHash::hash(&[F::ZERO]);
142 for i in (0..num_nodes - n).rev() {
143 let left = if left_child(i) < num_nodes {
144 nodes[left_child(i)]
145 } else {
146 z
147 };
148 let right = if right_child(i) < num_nodes {
149 nodes[right_child(i)]
150 } else {
151 z
152 };
153 nodes[i] = PoseidonHash::hash(&[left, right]);
154 }
155
156 Self {
157 nodes,
158 n,
159 leaf_off: num_nodes - n,
160 }
161 }
162
163 pub fn root(&self) -> F {
165 self.nodes[0]
166 }
167
168 fn idx_of_leaf(&self, i: usize) -> usize {
170 self.leaf_off + i
171 }
172
173 pub fn get_path(&self, i: usize) -> MerklePath {
174 assert!(
175 i < self.n,
176 "Proof index out of bounds: asked for {} out of {}",
177 i,
178 self.n
179 );
180 let mut idx = self.idx_of_leaf(i);
181 let z = PoseidonHash::hash(&[F::ZERO]);
182 let mut proof = Vec::new();
183
184 while idx > 0 {
185 let h = if sibling(idx) < self.nodes.len() {
186 self.nodes[sibling(idx)]
187 } else {
188 z
189 };
190 let pos = {
191 if (idx & 0b1) == 0 {
192 Position::Left
193 } else {
194 Position::Right
195 }
196 };
197 proof.push((pos, h));
198 idx = parent(idx);
199 }
200
201 MerklePath::new(proof)
202 }
203 pub fn to_merkle_tree_commitment(&self) -> MerkleTreeCommitment {
204 MerkleTreeCommitment::new(self.nodes[0], self.n as u32)
205 }
206}
207
208#[derive(Debug, Clone)]
209pub struct MerkleTreeCommitment {
210 merkle_root: F,
211 nr_leaves: u32,
212}
213
214impl MerkleTreeCommitment {
215 pub fn new(merkle_root: F, nr_leaves: u32) -> Self {
216 Self {
217 merkle_root,
218 nr_leaves,
219 }
220 }
221}
222
223impl From<MerkleTreeCommitment> for Vec<u8> {
224 fn from(mt_commit: MerkleTreeCommitment) -> Vec<u8> {
225 let mut bytes = Vec::new();
226 bytes.extend_from_slice(&mt_commit.merkle_root.to_bytes_le());
227 bytes.extend_from_slice(&mt_commit.nr_leaves.to_le_bytes());
228 bytes
229 }
230}
231
232impl TryFrom<&[u8]> for MerkleTreeCommitment {
233 type Error = &'static str;
234
235 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
236 if bytes.len() != 36 {
237 return Err("Invalid byte length for MerkleTreeCommitment");
238 }
239
240 let merkle_root = JubjubBase::from_bytes_le(bytes[0..32].try_into().unwrap()).unwrap();
241 let nr_leaves = u32::from_le_bytes(bytes[32..36].try_into().unwrap());
242
243 Ok(MerkleTreeCommitment {
244 merkle_root,
245 nr_leaves,
246 })
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::circuits::halo2::off_circuit::unique_signature::SigningKey;
    use rand_core::OsRng;

    /// Builds a leaf with a freshly generated verification key and `value`
    /// as its target.
    fn create_leaf(value: F) -> MTLeaf {
        let signing_key = SigningKey::generate(&mut OsRng);
        MTLeaf(VerificationKey::from(&signing_key), value)
    }

    /// Builds `count` leaves whose targets are `1, 2, ..., count`.
    fn sequential_leaves(count: u64) -> Vec<MTLeaf> {
        (1..=count)
            .map(|value| create_leaf(F::from(value)))
            .collect()
    }

    #[test]
    fn test_merkle_tree_creation() {
        let leaves = sequential_leaves(4);

        let tree = MerkleTree::create(&leaves);

        assert_eq!(tree.n, leaves.len(), "Number of leaves mismatch");
        assert!(
            tree.root() != F::ZERO,
            "Root should not be zero when tree is valid"
        );
    }

    #[test]
    fn test_merkle_path_generation() {
        let leaves = sequential_leaves(4);

        let tree = MerkleTree::create(&leaves);

        for index in 0..leaves.len() {
            let siblings = tree.get_path(index).siblings;
            assert!(
                !siblings.is_empty(),
                "Path should not be empty for any leaf"
            );
        }
    }

    #[test]
    fn test_merkle_path_verification() {
        // Six leaves is not a power of two, so the padding hash is exercised.
        let leaves = sequential_leaves(6);

        let tree = MerkleTree::create(&leaves);

        for (index, leaf) in leaves.iter().enumerate() {
            let computed_root = tree.get_path(index).compute_root(*leaf);
            assert_eq!(
                tree.root(),
                computed_root,
                "Computed root does not match the actual root"
            );
        }
    }
}