// mithril_stm/signature_scheme/unique_schnorr_signature/jubjub/field_elements.rs

1use anyhow::{Context, anyhow};
2use ff::Field;
3use midnight_curves::{Fq as JubjubBase, Fr as JubjubScalar};
4use rand_core::{CryptoRng, RngCore};
5use sha2::{Digest, Sha256};
6use std::ops::{Add, Mul, Neg, Sub};
7
8use crate::StmError;
9use crate::{StmResult, signature_scheme::UniqueSchnorrSignatureError};
10
11/// Represents an element in the base field of the Jubjub curve
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Hash, PartialOrd, Ord)]
13pub struct BaseFieldElement(pub(crate) JubjubBase);
14
15impl BaseFieldElement {
16    /// Retrieves the multiplicative identity element of the base field
17    pub(crate) fn get_one() -> Self {
18        BaseFieldElement(JubjubBase::ONE)
19    }
20
21    #[cfg(all(test, feature = "future_snark"))]
22    // TODO: remove this allow dead_code directive when function is called or future_snark is activated
23    #[allow(dead_code)]
24    /// Generates a new random scalar field element
25    pub(crate) fn random(rng: &mut (impl RngCore + CryptoRng)) -> Self {
26        BaseFieldElement(JubjubBase::random(rng))
27    }
28
29    /// Converts the base field element to its byte representation in
30    /// little endian form
31    pub(crate) fn to_bytes(self) -> [u8; 32] {
32        self.0.to_bytes_le()
33    }
34
35    /// Constructs a base field element from its byte representation
36    pub(crate) fn from_bytes(bytes: &[u8]) -> StmResult<Self> {
37        let mut base_bytes = [0u8; 32];
38        base_bytes.copy_from_slice(
39            bytes
40                .get(..32)
41                .ok_or(UniqueSchnorrSignatureError::BaseFieldElementSerialization)?,
42        );
43
44        match JubjubBase::from_bytes_le(&base_bytes).into_option() {
45            Some(base_field_element) => Ok(Self(base_field_element)),
46            None => Err(anyhow!(
47                UniqueSchnorrSignatureError::BaseFieldElementSerialization
48            )),
49        }
50    }
51
52    /// Constructs a base field element from bytes by applying modulus reduction
53    /// The underlying JubjubBase conversion function used cannot fail
54    pub(crate) fn from_raw(bytes: &[u8; 32]) -> StmResult<Self> {
55        Ok(BaseFieldElement(JubjubBase::from_raw([
56            u64::from_le_bytes(bytes[0..8].try_into()?),
57            u64::from_le_bytes(bytes[8..16].try_into()?),
58            u64::from_le_bytes(bytes[16..24].try_into()?),
59            u64::from_le_bytes(bytes[24..32].try_into()?),
60        ])))
61    }
62}
63
64/// Try to convert an arbitrary slice of bytes to a BaseFieldElement by first
65/// hashing the bytes using Sha256 and then converting using modulus reduction
66impl TryFrom<&[u8]> for BaseFieldElement {
67    type Error = StmError;
68    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
69        let hashed_input: [u8; 32] = Sha256::digest(value).into();
70        BaseFieldElement::from_raw(&hashed_input)
71    }
72}
73
74impl From<u64> for BaseFieldElement {
75    /// Converts a `u64` integer to a base field element
76    fn from(integer: u64) -> Self {
77        BaseFieldElement(JubjubBase::from(integer))
78    }
79}
80
81impl Add for BaseFieldElement {
82    type Output = BaseFieldElement;
83
84    /// Adds two base field elements
85    fn add(self, other: BaseFieldElement) -> BaseFieldElement {
86        BaseFieldElement(self.0 + other.0)
87    }
88}
89
90impl Neg for BaseFieldElement {
91    type Output = BaseFieldElement;
92
93    /// Negates a base field element
94    fn neg(self) -> BaseFieldElement {
95        BaseFieldElement(-self.0)
96    }
97}
98
99impl Sub for &BaseFieldElement {
100    type Output = BaseFieldElement;
101
102    /// Subtracts one base field element from another
103    fn sub(self, other: &BaseFieldElement) -> BaseFieldElement {
104        BaseFieldElement(self.0 - other.0)
105    }
106}
107
108impl Mul for BaseFieldElement {
109    type Output = BaseFieldElement;
110
111    /// Multiplies two base field elements
112    fn mul(self, other: BaseFieldElement) -> BaseFieldElement {
113        BaseFieldElement(self.0 * other.0)
114    }
115}
116
117impl Mul for &BaseFieldElement {
118    type Output = BaseFieldElement;
119
120    /// Multiplies a base field element by another base field element
121    fn mul(self, other: &BaseFieldElement) -> BaseFieldElement {
122        BaseFieldElement(self.0 * other.0)
123    }
124}
125
126/// Represents an element in the scalar field of the Jubjub curve
127#[derive(Debug, Clone, Copy, PartialEq, Eq)]
128pub(crate) struct ScalarFieldElement(pub(crate) JubjubScalar);
129
130impl ScalarFieldElement {
131    /// Generates a new random scalar field element
132    pub(crate) fn new_random_scalar(rng: &mut (impl RngCore + CryptoRng)) -> Self {
133        ScalarFieldElement(JubjubScalar::random(rng))
134    }
135
136    /// Checks if the scalar field element is zero
137    pub(crate) fn is_zero(&self) -> bool {
138        if self.0 == JubjubScalar::zero() {
139            return true;
140        }
141        false
142    }
143
144    /// Tries to generate a new random non-zero scalar field element in 100 attempts
145    ///
146    /// Returns an error if unable to generate a non-zero scalar after 100 attempts
147    pub(crate) fn new_random_nonzero_scalar(
148        rng: &mut (impl RngCore + CryptoRng),
149    ) -> StmResult<Self> {
150        for _ in 0..100 {
151            let random_scalar = Self::new_random_scalar(rng);
152            if !random_scalar.is_zero() {
153                return Ok(random_scalar);
154            }
155        }
156        Err(anyhow!(UniqueSchnorrSignatureError::RandomScalarGeneration))
157    }
158
159    /// Converts the scalar field element to its byte representation
160    pub(crate) fn to_bytes(self) -> [u8; 32] {
161        self.0.to_bytes()
162    }
163
164    /// Constructs a scalar field element from its byte representation
165    pub(crate) fn from_bytes(bytes: &[u8]) -> StmResult<Self> {
166        let mut scalar_bytes = [0u8; 32];
167        scalar_bytes.copy_from_slice(
168            bytes
169                .get(..32)
170                .ok_or(UniqueSchnorrSignatureError::ScalarFieldElementSerialization)?,
171        );
172
173        match JubjubScalar::from_bytes(&scalar_bytes).into_option() {
174            Some(scalar_field_element) => Ok(Self(scalar_field_element)),
175            None => Err(anyhow!(
176                UniqueSchnorrSignatureError::ScalarFieldElementSerialization
177            )),
178        }
179    }
180
181    /// Constructs a scalar field element from its byte representation
182    /// while reducing modulo the scalar field modulus if necessary
183    pub(crate) fn from_raw(bytes: &[u8]) -> StmResult<Self> {
184        let mut scalar_bytes = [0u8; 32];
185        scalar_bytes.copy_from_slice(
186            bytes
187                .get(..32)
188                .ok_or(UniqueSchnorrSignatureError::ScalarFieldElementSerialization)?,
189        );
190
191        let mut bytes64 = [0u64; 4];
192        for i in 0..4 {
193            bytes64[i] =
194                u64::from_le_bytes(bytes[8 * i..8 * (i + 1)].try_into().with_context(|| {
195                    anyhow!(UniqueSchnorrSignatureError::ScalarFieldElementSerialization)
196                })?)
197        }
198
199        Ok(Self(JubjubScalar::from_raw(bytes64)))
200    }
201
202    /// Convert a base field element to a scalar
203    pub(crate) fn from_base_field(base_element: &BaseFieldElement) -> StmResult<Self> {
204        let base_element_bytes = base_element.0.to_bytes_le();
205        ScalarFieldElement::from_raw(&base_element_bytes)
206    }
207}
208
209impl Mul for ScalarFieldElement {
210    type Output = ScalarFieldElement;
211
212    /// Multiplies two scalar field elements
213    fn mul(self, other: ScalarFieldElement) -> ScalarFieldElement {
214        ScalarFieldElement(self.0 * other.0)
215    }
216}
217
218impl Sub for ScalarFieldElement {
219    type Output = ScalarFieldElement;
220
221    /// Subtracts one scalar field element from another
222    fn sub(self, other: ScalarFieldElement) -> ScalarFieldElement {
223        ScalarFieldElement(self.0 - other.0)
224    }
225}
226
#[cfg(test)]
mod tests {
    use rand_chacha::ChaCha20Rng;
    use rand_core::SeedableRng;

    use super::*;

    /// Golden test pinning the JSON encoding of a deterministically generated scalar.
    // NOTE(review): relies on serde (de)serialization for `ScalarFieldElement`,
    // which is not implemented in this file — confirm where it is provided.
    mod golden {
        use super::*;

        const GOLDEN_JSON: &str = r#"[126, 191, 239, 197, 88, 151, 248, 254, 187, 143, 86, 35, 29, 62, 90, 13, 196, 71, 234, 5, 90, 124, 205, 194, 51, 192, 228, 133, 25, 140, 157, 7]"#;

        /// Scalar generated from an all-zero ChaCha20 seed; expected to match `GOLDEN_JSON`.
        fn golden_value() -> ScalarFieldElement {
            let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
            ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap()
        }

        #[test]
        fn golden_conversions() {
            let value = serde_json::from_str(GOLDEN_JSON)
                .expect("This JSON deserialization should not fail");
            assert_eq!(golden_value(), value);

            let serialized =
                serde_json::to_string(&value).expect("This JSON serialization should not fail");
            let golden_serialized = serde_json::to_string(&golden_value())
                .expect("This JSON serialization should not fail");
            assert_eq!(golden_serialized, serialized);
        }
    }

    /// Tests for the `from_bytes` / `from_raw` byte conversions.
    mod bytes_conversion {
        use super::*;

        #[test]
        fn from_bytes_fails_if_value_too_high() {
            let bytes = [255; 32];

            let value = BaseFieldElement::from_bytes(&bytes);
            value.expect_err("Bytes conversion should fail because input is higher than modulus.");

            let value = ScalarFieldElement::from_bytes(&bytes);
            value.expect_err("Bytes conversion should fail because input is higher than modulus.");
        }

        #[cfg(feature = "future_snark")]
        #[test]
        fn from_raw_recover_element_correctly() {
            let mut rng = ChaCha20Rng::from_seed([3u8; 32]);
            let elem = BaseFieldElement::random(&mut rng);
            let elem_bytes = elem.to_bytes();

            let val1 = BaseFieldElement::from_bytes(&elem_bytes).unwrap();
            let val2 = BaseFieldElement::from_raw(&elem_bytes).unwrap();

            assert_eq!(val1, val2);
        }

        #[test]
        fn from_raw_succeed_for_max_value() {
            let bytes = [255; 32];

            let value = BaseFieldElement::from_bytes(&bytes);
            value.expect_err("Bytes conversion should fail because input is higher than modulus.");

            let value = BaseFieldElement::from_raw(&bytes);
            assert!(
                value.is_ok(),
                "The conversion should not fail when using from_raw."
            );
        }
    }

    /// Tests for the operator implementations on `BaseFieldElement`.
    mod base_field_arithmetic {
        use super::*;

        #[test]
        fn test_add() {
            let a = BaseFieldElement(JubjubBase::ONE);
            let b = BaseFieldElement(JubjubBase::ONE);
            let result = a + b;
            assert_eq!(result, BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE));
        }

        #[test]
        fn test_add_with_zero() {
            let a = BaseFieldElement(JubjubBase::ONE);
            let zero = BaseFieldElement(JubjubBase::ZERO);
            let result = a + zero;
            assert_eq!(result, a);
        }

        #[test]
        fn test_sub_references() {
            let a = BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE);
            let b = BaseFieldElement(JubjubBase::ONE);
            let result = &a - &b;
            assert_eq!(result, BaseFieldElement(JubjubBase::ONE));
        }

        #[test]
        fn test_sub_same_values() {
            let a = BaseFieldElement(JubjubBase::ONE);
            let b = BaseFieldElement(JubjubBase::ONE);
            let result = &a - &b;
            assert_eq!(result, BaseFieldElement(JubjubBase::ZERO));
        }

        #[test]
        fn test_mul_owned() {
            let a = BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE);
            let b = BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE);
            let result = a * b;
            let expected = JubjubBase::ONE + JubjubBase::ONE;
            assert_eq!(result, BaseFieldElement(expected * expected));
        }

        #[test]
        fn test_mul_with_one() {
            let a = BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE);
            let one = BaseFieldElement::get_one();
            let result = a * one;
            assert_eq!(result, a);
        }

        #[test]
        fn test_mul_with_zero() {
            let a = BaseFieldElement(JubjubBase::ONE);
            let zero = BaseFieldElement(JubjubBase::ZERO);
            let result = a * zero;
            assert_eq!(result, BaseFieldElement(JubjubBase::ZERO));
        }

        #[test]
        fn test_mul_references() {
            let a = BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE);
            let b = BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE + JubjubBase::ONE);
            // Multiply through references so that `impl Mul for &BaseFieldElement`
            // is the implementation under test (the owned impl is covered above).
            let result = &a * &b;
            let expected = (JubjubBase::ONE + JubjubBase::ONE)
                * (JubjubBase::ONE + JubjubBase::ONE + JubjubBase::ONE);
            assert_eq!(result, BaseFieldElement(expected));
        }

        #[test]
        fn test_chained_operations() {
            let a = BaseFieldElement(JubjubBase::ONE);
            let b = BaseFieldElement(JubjubBase::ONE);
            let c = BaseFieldElement(JubjubBase::ONE);
            let result = (a + b) * c;
            assert_eq!(result, BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE));
        }
    }

    /// Tests for the operator implementations on `ScalarFieldElement`.
    mod scalar_field_arithmetic {
        use super::*;

        #[test]
        fn test_mul() {
            let mut rng = ChaCha20Rng::from_seed([1u8; 32]);
            let a = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let b = ScalarFieldElement(JubjubScalar::one());
            let result = a * b;
            assert_eq!(result, a);
        }

        #[test]
        fn test_mul_with_zero() {
            let mut rng = ChaCha20Rng::from_seed([2u8; 32]);
            let a = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let zero = ScalarFieldElement(JubjubScalar::zero());
            let result = a * zero;
            assert!(result.is_zero());
        }

        #[test]
        fn test_mul_associativity() {
            let mut rng = ChaCha20Rng::from_seed([3u8; 32]);
            let a = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let b = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let c = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();

            let result1 = (a * b) * c;
            let result2 = a * (b * c);
            assert_eq!(result1, result2);
        }

        #[test]
        fn test_sub() {
            let mut rng = ChaCha20Rng::from_seed([4u8; 32]);
            let a = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let result = a - a;
            assert!(result.is_zero());
        }

        #[test]
        fn test_sub_with_zero() {
            let mut rng = ChaCha20Rng::from_seed([5u8; 32]);
            let a = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let zero = ScalarFieldElement(JubjubScalar::zero());
            let result = a - zero;
            assert_eq!(result, a);
        }

        #[test]
        fn test_sub_specific_values() {
            let two = ScalarFieldElement(JubjubScalar::one() + JubjubScalar::one());
            let one = ScalarFieldElement(JubjubScalar::one());
            let result = two - one;
            assert_eq!(result, ScalarFieldElement(JubjubScalar::one()));
        }

        #[test]
        fn test_combined_operations() {
            let mut rng = ChaCha20Rng::from_seed([6u8; 32]);
            let a = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let b = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let c = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();

            // Distributivity of multiplication over subtraction.
            let left = a * (b - c);
            let right = (a * b) - (a * c);
            assert_eq!(left, right);
        }
    }
}