1use crate::error::MerkleTreeError;
2use crate::merkle_tree::{
3 left_child, parent, right_child, sibling, BatchPath, MTLeaf, MerkleTreeCommitment,
4 MerkleTreeCommitmentBatchCompat, Path,
5};
6use blake2::digest::{Digest, FixedOutput};
7use serde::{Deserialize, Serialize};
8use std::marker::PhantomData;
9
/// Tree of hashes, providing a commitment to a list of leaves and their order.
///
/// The tree is stored as a flat array of nodes; see `create` for the layout.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MerkleTree<D: Digest> {
    // Flat array of all node hashes. `nodes[0]` is the root (see `root()`);
    // the last `n` entries are the hashed leaves (see `create`).
    nodes: Vec<Vec<u8>>,
    // Index of the first leaf within `nodes` (equals `nodes.len() - n`).
    leaf_off: usize,
    // Number of leaves committed to.
    n: usize,
    // Zero-sized marker tying the tree to its digest algorithm `D`.
    hasher: PhantomData<D>,
}
26
27impl<D: Digest + FixedOutput> MerkleTree<D> {
28 pub fn create(leaves: &[MTLeaf]) -> MerkleTree<D> {
30 let n = leaves.len();
31 assert!(n > 0, "MerkleTree::create() called with no leaves");
32
33 let num_nodes = n + n.next_power_of_two() - 1;
34
35 let mut nodes = vec![vec![0u8]; num_nodes];
36
37 for i in 0..leaves.len() {
38 nodes[num_nodes - n + i] = D::digest(leaves[i].to_bytes()).to_vec();
39 }
40
41 for i in (0..num_nodes - n).rev() {
42 let z = D::digest([0u8]).to_vec();
43 let left = if left_child(i) < num_nodes {
44 &nodes[left_child(i)]
45 } else {
46 &z
47 };
48 let right = if right_child(i) < num_nodes {
49 &nodes[right_child(i)]
50 } else {
51 &z
52 };
53 nodes[i] = D::new()
54 .chain_update(left)
55 .chain_update(right)
56 .finalize()
57 .to_vec();
58 }
59
60 Self {
61 nodes,
62 n,
63 leaf_off: num_nodes - n,
64 hasher: PhantomData,
65 }
66 }
67
68 pub fn root(&self) -> &Vec<u8> {
70 &self.nodes[0]
71 }
72
73 fn idx_of_leaf(&self, i: usize) -> usize {
75 self.leaf_off + i
76 }
77
78 pub fn to_commitment_batch_compat(&self) -> MerkleTreeCommitmentBatchCompat<D> {
81 MerkleTreeCommitmentBatchCompat::new(self.nodes[0].clone(), self.n)
82 }
83
    /// Batched membership proof for several leaves at once.
    ///
    /// `indices` must be non-empty, sorted in ascending order, and each entry
    /// must be less than the number of leaves. The returned proof holds the
    /// sibling hashes needed to recompute the root, shared across the
    /// requested leaves (a sibling that is itself a requested/derived node is
    /// omitted, since the verifier can recompute it).
    ///
    /// # Panics
    /// Panics if `indices` is empty, unsorted, or contains an index `>= n`.
    pub fn get_batched_path(&self, indices: Vec<usize>) -> BatchPath<D>
    where
        D: FixedOutput,
    {
        assert!(
            !indices.is_empty(),
            "get_batched_path() called with no indices"
        );
        for i in &indices {
            assert!(
                i < &self.n,
                "Proof index out of bounds: asked for {} out of {}",
                i,
                self.n
            );
        }

        // The caller must supply sorted indices; verify rather than silently fix,
        // since the proof layout depends on this order.
        let mut ordered_indices: Vec<usize> = indices.clone();
        ordered_indices.sort_unstable();

        assert_eq!(ordered_indices, indices, "Indices should be ordered");

        // Translate leaf numbers into positions within the `nodes` array.
        ordered_indices = ordered_indices
            .into_iter()
            .map(|i| self.idx_of_leaf(i))
            .collect();

        // `idx` tracks the first (leftmost) node of the current level; the
        // outer loop terminates once it has been lifted to the root (index 0).
        let mut idx = ordered_indices[0];
        let mut proof = Vec::new();

        while idx > 0 {
            let mut new_indices = Vec::with_capacity(ordered_indices.len());
            let mut i = 0;
            idx = parent(idx);
            // For every tracked node on this level, emit its sibling hash —
            // unless the sibling is the next tracked node, in which case the
            // verifier already knows it and we skip both in one step.
            while i < ordered_indices.len() {
                new_indices.push(parent(ordered_indices[i]));
                let sibling = sibling(ordered_indices[i]);
                if i < ordered_indices.len() - 1 && ordered_indices[i + 1] == sibling {
                    i += 1;
                } else if sibling < self.nodes.len() {
                    // Siblings beyond the array bound are padding; the verifier
                    // substitutes the canonical empty-node hash for them.
                    proof.push(self.nodes[sibling].clone());
                }
                i += 1;
            }
            // Move one level up: the parents become the tracked nodes.
            ordered_indices.clone_from(&new_indices);
        }

        BatchPath::new(proof, indices)
    }
149
150 pub fn to_bytes(&self) -> Vec<u8> {
156 let mut result = Vec::with_capacity(8 + self.nodes.len() * <D as Digest>::output_size());
157 result.extend_from_slice(&u64::try_from(self.n).unwrap().to_be_bytes());
158 for node in self.nodes.iter() {
159 result.extend_from_slice(node);
160 }
161 result
162 }
163
164 pub fn from_bytes(bytes: &[u8]) -> Result<Self, MerkleTreeError<D>> {
168 let mut u64_bytes = [0u8; 8];
169 u64_bytes.copy_from_slice(&bytes[..8]);
170 let n = usize::try_from(u64::from_be_bytes(u64_bytes))
171 .map_err(|_| MerkleTreeError::SerializationError)?;
172 let num_nodes = n + n.next_power_of_two() - 1;
173 let mut nodes = Vec::with_capacity(num_nodes);
174 for i in 0..num_nodes {
175 nodes.push(
176 bytes[8 + i * <D as Digest>::output_size()
177 ..8 + (i + 1) * <D as Digest>::output_size()]
178 .to_vec(),
179 );
180 }
181 Ok(Self {
182 nodes,
183 leaf_off: num_nodes - n,
184 n,
185 hasher: PhantomData,
186 })
187 }
188
189 pub fn to_commitment(&self) -> MerkleTreeCommitment<D> {
191 MerkleTreeCommitment::new(self.nodes[0].clone()) }
193
194 pub fn get_path(&self, i: usize) -> Path<D> {
198 assert!(
199 i < self.n,
200 "Proof index out of bounds: asked for {} out of {}",
201 i,
202 self.n
203 );
204 let mut idx = self.idx_of_leaf(i);
205 let mut proof = Vec::new();
206
207 while idx > 0 {
208 let h = if sibling(idx) < self.nodes.len() {
209 self.nodes[sibling(idx)].clone()
210 } else {
211 D::digest([0u8]).to_vec()
212 };
213 proof.push(h.clone());
214 idx = parent(idx);
215 }
216
217 Path::new(proof, i)
218 }
219}
220
221#[cfg(test)]
222mod tests {
223 use super::*;
224 use crate::multi_sig::VerificationKey;
225 use blake2::{digest::consts::U32, Blake2b};
226 use proptest::collection::vec;
227 use proptest::prelude::*;
228 use rand::{rng, seq::IteratorRandom};
229
230 fn pow2_plus1(h: usize) -> usize {
231 1 + 2_usize.pow(h as u32)
232 }
233
    prop_compose! {
        // Strategy: a tree of 2..max_size leaves (default verification keys,
        // random u64 stakes), returned together with the leaves it was built from.
        fn arb_tree(max_size: u32)
                   (v in vec(any::<u64>(), 2..max_size as usize)) -> (MerkleTree<Blake2b<U32>>, Vec<MTLeaf>) {
            let pks = vec![VerificationKey::default(); v.len()];
            let leaves = pks.into_iter().zip(v.into_iter()).map(|(key, stake)| MTLeaf(key, stake)).collect::<Vec<MTLeaf>>();
            (MerkleTree::<Blake2b<U32>>::create(&leaves), leaves)
        }
    }
242
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Every leaf's individual proof must verify against the commitment.
        #[test]
        fn test_create_proof((t, values) in arb_tree(30)) {
            values.iter().enumerate().for_each(|(i, _v)| {
                let pf = t.get_path(i);
                assert!(t.to_commitment().check(&values[i], &pf).is_ok());
            })
        }

        // Paths survive both the manual byte codec and a bincode round-trip.
        #[test]
        fn test_bytes_path((t, values) in arb_tree(30)) {
            values.iter().enumerate().for_each(|(i, _v)| {
                let pf = t.get_path(i);
                let bytes = pf.to_bytes();
                let deserialised = Path::from_bytes(&bytes).unwrap();
                assert!(t.to_commitment().check(&values[i], &deserialised).is_ok());

                let encoded = bincode::serialize(&pf).unwrap();
                let decoded: Path<Blake2b<U32>> = bincode::deserialize(&encoded).unwrap();
                assert!(t.to_commitment().check(&values[i], &decoded).is_ok());
            })
        }

        // The plain commitment round-trips through bincode with the root intact.
        #[test]
        fn test_bytes_tree_commitment((t, values) in arb_tree(5)) {
            let encoded = bincode::serialize(&t.to_commitment()).unwrap();
            let decoded: MerkleTreeCommitment::<Blake2b<U32>> = bincode::deserialize(&encoded).unwrap();
            let tree_commitment = MerkleTree::<Blake2b<U32>>::create(&values).to_commitment();
            assert_eq!(tree_commitment.root, decoded.root);
        }

        // Whole trees round-trip through both codecs, node for node.
        #[test]
        fn test_bytes_tree((t, values) in arb_tree(5)) {
            let bytes = t.to_bytes();
            let deserialised = MerkleTree::<Blake2b<U32>>::from_bytes(&bytes).unwrap();
            let tree = MerkleTree::<Blake2b<U32>>::create(&values);
            assert_eq!(tree.nodes, deserialised.nodes);

            let encoded = bincode::serialize(&t).unwrap();
            let decoded: MerkleTree::<Blake2b<U32>> = bincode::deserialize(&encoded).unwrap();
            assert_eq!(tree.nodes, decoded.nodes);
        }

        // The batch-compatible commitment round-trips with root and leaf count intact.
        #[test]
        fn test_bytes_tree_commitment_batch_compat((t, values) in arb_tree(5)) {
            let encoded = bincode::serialize(&t.to_commitment_batch_compat()).unwrap();
            let decoded: MerkleTreeCommitmentBatchCompat::<Blake2b<U32>> = bincode::deserialize(&encoded).unwrap();
            let tree_commitment = MerkleTree::<Blake2b<U32>>::create(&values).to_commitment_batch_compat();
            assert_eq!(tree_commitment.root, decoded.root);
            assert_eq!(tree_commitment.get_nr_leaves(), decoded.get_nr_leaves());

        }

    }
300
    prop_compose! {
        // Strategy: leaves plus a structurally plausible but bogus proof of
        // height h (random 16-byte values), for exercising failure paths.
        fn values_with_invalid_proof(max_height: usize)
                                    (h in 1..max_height)
                                    (v in vec(any::<u64>(), 2..pow2_plus1(h)),
                                     proof in vec(vec(any::<u8>(), 16), h)) -> (Vec<MTLeaf>, Vec<Vec<u8>>) {
            let pks = vec![VerificationKey::default(); v.len()];
            let leaves = pks.into_iter().zip(v.into_iter()).map(|(key, stake)| MTLeaf(key, stake)).collect::<Vec<MTLeaf>>();
            (leaves, proof)
        }
    }
312
    proptest! {
        // A proof built from random bytes must not verify a leaf absent from
        // the tree (the tree is built from values[1..], checked against values[0]).
        #[test]
        fn test_create_invalid_proof(
            i in any::<usize>(),
            (values, proof) in values_with_invalid_proof(10)
        ) {
            let t = MerkleTree::<Blake2b<U32>>::create(&values[1..]);
            let index = i % (values.len() - 1);
            let path_values = proof.iter().map(|x| Blake2b::<U32>::digest(x).to_vec()).collect();
            let path = Path::new(path_values, index);
            assert!(t.to_commitment().check(&values[0], &path).is_err());
        }

        // Same negative property for the batched proof variant.
        #[test]
        fn test_create_invalid_batch_proof(
            i in any::<usize>(),
            (values, proof) in values_with_invalid_proof(10)
        ) {
            let t = MerkleTree::<Blake2b<U32>>::create(&values[1..]);
            let indices = vec![i % (values.len() - 1); values.len() / 2];
            let batch_values = vec![values[i % (values.len() - 1)]; values.len() / 2];
            // Construct the BatchPath directly to bypass get_batched_path's asserts.
            let path = BatchPath{values: proof
                .iter()
                .map(|x| Blake2b::<U32>::digest(x).to_vec())
                .collect(),
                indices,
                hasher: PhantomData::<Blake2b<U32>>
            };
            assert!(t.to_commitment_batch_compat().check(&batch_values, &path).is_err());
        }
    }
344
    prop_compose! {
        // Strategy: a tree plus a random sorted subset (roughly 20%, at least
        // one) of its leaves and their indices, for the batched-proof tests.
        fn arb_tree_arb_batch(max_size: u32)
                              (v in vec(any::<u64>(), 2..max_size as usize)) -> (MerkleTree<Blake2b<U32>>, Vec<MTLeaf>, Vec<usize>) {
            let mut rng = rng();
            let size = v.len();
            let pks = vec![VerificationKey::default(); size];
            let leaves = pks.into_iter().zip(v.into_iter()).map(|(key, stake)| MTLeaf(key, stake)).collect::<Vec<MTLeaf>>();

            // Pick the subset, then sort: get_batched_path requires ordered indices.
            let indices: Vec<usize> = (0..size).collect();
            let mut mt_list: Vec<usize> = indices.into_iter().choose_multiple(&mut rng, size * 2 / 10 + 1);
            mt_list.sort_unstable();

            let mut batch_values: Vec<MTLeaf> = Vec::with_capacity(mt_list.len());
            for i in mt_list.iter() {
                batch_values.push(leaves[*i]);
            }

            (MerkleTree::<Blake2b<U32>>::create(&leaves), batch_values, mt_list)
        }
    }
365
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Batched proofs verify for any sorted subset of leaves.
        #[test]
        fn test_create_batch_proof((t, batch_values, indices) in arb_tree_arb_batch(30)) {
            let batch_proof = t.get_batched_path(indices);
            assert!(t.to_commitment_batch_compat().check(&batch_values, &batch_proof).is_ok());
        }

        // Batched proofs survive both the manual byte codec and bincode.
        #[test]
        fn test_bytes_batch_path((t, batch_values, indices) in arb_tree_arb_batch(30)) {
            let bp = t.get_batched_path(indices);

            let bytes = &bp.to_bytes();
            let deserialized = BatchPath::from_bytes(bytes).unwrap();
            assert!(t.to_commitment_batch_compat().check(&batch_values, &deserialized).is_ok());

            let encoded = bincode::serialize(&bp).unwrap();
            let decoded: BatchPath<Blake2b<U32>> = bincode::deserialize(&encoded).unwrap();
            assert!(t.to_commitment_batch_compat().check(&batch_values, &decoded).is_ok());
        }
    }
387}