mithril_stm/hash/
poseidon.rs

1use digest::generic_array::typenum::U32;
2use digest::{FixedOutput, HashMarker, Output, OutputSizeUser, Reset, Update};
3use midnight_circuits::{hash::poseidon::PoseidonChip, instructions::hash::HashCPU};
4use midnight_curves::Fq as JubjubBase;
5
/// Adapter that exposes the Midnight Poseidon hash through the `digest`
/// crate traits.
///
/// The Merkle tree implementation is generic over the `Digest` it uses, so
/// Poseidon has to be wrapped in this type to be plugged into it.
///
/// Note: `update` does not follow the usual streaming semantics of the
/// `digest` crate. Every call is zero-padded on its own to a multiple of
/// 32 bytes; the `Update` impl documents the consequences.
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub struct MidnightPoseidonDigest {
    /// Bytes accumulated by `update`, always stored as zero-padded 32-byte words.
    buffer: Vec<u8>,
}

impl MidnightPoseidonDigest {
    /// Returns a hasher with an empty input buffer (identical to `Default`).
    pub fn new() -> Self {
        Self::default()
    }
}
22
23impl Update for MidnightPoseidonDigest {
24    // The finalize function uses the from_raw method from JubjubBase
25    // but this function converts a value of exactly 256 bits to a value of 255 bits
26    // which means we can loose one bit of the input if not careful.
27    // This is why we need to make sure that the update function input
28    // always represent an element of JubjubBase before calling it.
29    // A potential way to make sure we don't loose a bit of information
30    // could be to only allow the update function to take a maximum of 32 bytes
31    // at a time. It would be a sort of check that makes sure we're not calling
32    // this function on "wrong" inputs somewhere in the code
33    //
34    // The function is also using a padding which deviates from the tradition usage
35    // of the digest in which update([1]).update([2]) gives the same result as
36    // update([1, 2]). We leave this functionality as is for now since this
37    // function is only used for the merkle tree and we prefer this behavior
38    fn update(&mut self, data: &[u8]) {
39        // Computes the next multiple of 32 as target length
40        let target_len = (data.len() + 31) & !31;
41        let mut padded_data = Vec::with_capacity(target_len);
42        padded_data.extend_from_slice(data);
43        // Pad the data with zeros on the right to match the target length
44        padded_data.resize(target_len, 0);
45        self.buffer.extend_from_slice(&padded_data);
46    }
47}
48
49impl OutputSizeUser for MidnightPoseidonDigest {
50    type OutputSize = U32;
51}
52
53impl FixedOutput for MidnightPoseidonDigest {
54    fn finalize_into(self, out: &mut Output<Self>) {
55        // The data is padded during the call to the update function
56        // so there should always be a multiple of 32 bytes in the buffer
57        // We are taking chunks of 32 u8 so it should be fine to unwrap
58        let poseidon_input = self
59            .buffer
60            .chunks_exact(32)
61            .map(|chunk| {
62                // The from_raw function performs a modular reduction so
63                // it will never fail even if the buffer value exceeds
64                // the value of the modulus
65                // Since we should only get inputs that represent JubjubBase
66                // elements, we could also switch to from_bytes_le
67                JubjubBase::from_raw([
68                    u64::from_le_bytes(chunk[0..8].try_into().unwrap()),
69                    u64::from_le_bytes(chunk[8..16].try_into().unwrap()),
70                    u64::from_le_bytes(chunk[16..24].try_into().unwrap()),
71                    u64::from_le_bytes(chunk[24..32].try_into().unwrap()),
72                ])
73            })
74            .collect::<Vec<JubjubBase>>();
75        let result: JubjubBase = PoseidonChip::<JubjubBase>::hash(&poseidon_input);
76        out.copy_from_slice(&result.to_bytes_le());
77    }
78}
79
80impl Reset for MidnightPoseidonDigest {
81    fn reset(&mut self) {
82        self.buffer.clear();
83    }
84}
85
86impl HashMarker for MidnightPoseidonDigest {}
87
#[cfg(test)]
mod tests {
    use digest::Digest;
    use midnight_circuits::{hash::poseidon::PoseidonChip, instructions::hash::HashCPU};
    use midnight_curves::Fq as JubjubBase;

    use super::MidnightPoseidonDigest;

    /// Builds a field element from 32 little-endian bytes exactly like
    /// `finalize_into` does (four u64 limbs passed to `from_raw`).
    fn elem_from_le_bytes(bytes: &[u8; 32]) -> JubjubBase {
        let limb = |i: usize| u64::from_le_bytes(bytes[8 * i..8 * i + 8].try_into().unwrap());
        JubjubBase::from_raw([limb(0), limb(1), limb(2), limb(3)])
    }

    /// Decodes a 32-byte digest output into the field element it encodes.
    fn digest_to_elem(output: &[u8]) -> JubjubBase {
        let bytes: [u8; 32] = output.try_into().expect("digest output must be 32 bytes");
        JubjubBase::from_bytes_le(&bytes).unwrap()
    }

    #[test]
    fn test_digest_impl_single_element() {
        let bytes = [0u8; 32];
        let elem = elem_from_le_bytes(&bytes);

        let digest_result = digest_to_elem(&MidnightPoseidonDigest::digest(bytes));

        assert_eq!(digest_result, PoseidonChip::<JubjubBase>::hash(&[elem]));
    }

    #[test]
    fn test_digest_impl_chain_update() {
        let bytes = [0u8; 32];
        let elem = elem_from_le_bytes(&bytes);

        let output = MidnightPoseidonDigest::new()
            .chain_update(bytes)
            .chain_update(bytes)
            .finalize();

        assert_eq!(
            digest_to_elem(&output),
            PoseidonChip::<JubjubBase>::hash(&[elem, elem])
        );
    }

    #[test]
    fn test_digest_impl_single_byte() {
        let byte = 2u8;
        let elem = JubjubBase::from(u64::from(byte));

        let digest_result = digest_to_elem(&MidnightPoseidonDigest::digest([byte]));

        assert_eq!(digest_result, PoseidonChip::<JubjubBase>::hash(&[elem]));
    }

    #[test]
    fn test_digest_impl_empty_byte_array() {
        let digest_result = digest_to_elem(&MidnightPoseidonDigest::digest([]));

        assert_eq!(digest_result, PoseidonChip::<JubjubBase>::hash(&[]));
    }

    #[test]
    fn test_digest_impl_input_not_multiple_32() {
        let bytes = [1u8; 48];
        // `update` zero-pads the trailing 16 bytes up to a full 32-byte word.
        let first: [u8; 32] = bytes[..32].try_into().unwrap();
        let mut second = [0u8; 32];
        second[..16].copy_from_slice(&bytes[32..]);

        let digest_result = digest_to_elem(&MidnightPoseidonDigest::digest(bytes));
        let poseidon_result = PoseidonChip::<JubjubBase>::hash(&[
            elem_from_le_bytes(&first),
            elem_from_le_bytes(&second),
        ]);

        assert_eq!(digest_result, poseidon_result);
    }

    #[test]
    fn test_digest_impl_chain_update_order() {
        let one = JubjubBase::from(1u64);
        let two = JubjubBase::from(2u64);
        let three = JubjubBase::from(3u64);

        let output = MidnightPoseidonDigest::new()
            .chain_update([1u8])
            .chain_update([3u8])
            .chain_update([2u8])
            .finalize();

        assert_eq!(
            digest_to_elem(&output),
            PoseidonChip::<JubjubBase>::hash(&[one, three, two])
        );
    }

    #[test]
    fn test_trailing_zero_bytes_are_indistinguishable_from_padding() {
        // Each update is zero-padded to 32 bytes, so explicit trailing zeros
        // inside the same word do not change the digest (by design).
        let short = MidnightPoseidonDigest::digest([1u8]);
        let padded = MidnightPoseidonDigest::digest([1u8, 0, 0]);

        assert_eq!(short, padded);
    }

    #[test]
    fn test_collision_for_large_values() {
        // `from_raw` reduces modulo the field modulus, so the 32-byte words
        // `1` and `modulus + 1` map to the same element and must collide.
        let mut one = [0u8; 32];
        one[0] = 1;
        let modulus_plus_one: [u8; 32] = [
            2, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8,
            216, 57, 51, 72, 125, 157, 41, 83, 167, 237, 115,
        ];

        let digest_one =
            digest_to_elem(&MidnightPoseidonDigest::new().chain_update(one).finalize());
        let digest_mod = digest_to_elem(
            &MidnightPoseidonDigest::new().chain_update(modulus_plus_one).finalize(),
        );

        assert_eq!(
            digest_one, digest_mod,
            "expected the hashes of 1 and modulus + 1 to collide (from_raw reduces modulo p)"
        );
    }

    mod golden_tests {
        use super::*;

        const GOLDEN_BYTES: [u8; 32] = [
            110, 103, 7, 180, 60, 102, 100, 65, 91, 212, 214, 109, 138, 43, 27, 222, 2, 206, 234,
            218, 176, 114, 103, 100, 18, 121, 123, 177, 36, 188, 37, 95,
        ];

        fn golden_value() -> JubjubBase {
            let output = MidnightPoseidonDigest::new()
                .chain_update([1u8])
                .chain_update([3u8])
                .chain_update([2u8])
                .finalize();

            digest_to_elem(&output)
        }

        #[test]
        fn golden_test_chain_update() {
            let value = JubjubBase::from_bytes_le(&GOLDEN_BYTES).unwrap();

            assert_eq!(golden_value(), value);
        }
    }
}