// mithril_stm/signature_scheme/unique_schnorr_signature/jubjub/field_elements.rs
use anyhow::{Context, anyhow};
use ff::Field;
use midnight_curves::{Fq as JubjubBase, Fr as JubjubScalar};
use rand_core::{CryptoRng, RngCore};
use sha2::{Digest, Sha256};
use std::ops::{Add, Mul, Neg, Sub};

use crate::StmError;
use crate::{StmResult, signature_scheme::UniqueSchnorrSignatureError};
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Hash, PartialOrd, Ord)]
13pub struct BaseFieldElement(pub(crate) JubjubBase);
14
15impl BaseFieldElement {
16 pub(crate) fn get_one() -> Self {
18 BaseFieldElement(JubjubBase::ONE)
19 }
20
21 #[cfg(all(test, feature = "future_snark"))]
22 #[allow(dead_code)]
24 pub(crate) fn random(rng: &mut (impl RngCore + CryptoRng)) -> Self {
26 BaseFieldElement(JubjubBase::random(rng))
27 }
28
29 pub(crate) fn to_bytes(self) -> [u8; 32] {
32 self.0.to_bytes_le()
33 }
34
35 pub(crate) fn from_bytes(bytes: &[u8]) -> StmResult<Self> {
37 let mut base_bytes = [0u8; 32];
38 base_bytes.copy_from_slice(
39 bytes
40 .get(..32)
41 .ok_or(UniqueSchnorrSignatureError::BaseFieldElementSerialization)?,
42 );
43
44 match JubjubBase::from_bytes_le(&base_bytes).into_option() {
45 Some(base_field_element) => Ok(Self(base_field_element)),
46 None => Err(anyhow!(
47 UniqueSchnorrSignatureError::BaseFieldElementSerialization
48 )),
49 }
50 }
51
52 pub(crate) fn from_raw(bytes: &[u8; 32]) -> StmResult<Self> {
55 Ok(BaseFieldElement(JubjubBase::from_raw([
56 u64::from_le_bytes(bytes[0..8].try_into()?),
57 u64::from_le_bytes(bytes[8..16].try_into()?),
58 u64::from_le_bytes(bytes[16..24].try_into()?),
59 u64::from_le_bytes(bytes[24..32].try_into()?),
60 ])))
61 }
62}
63
64impl TryFrom<&[u8]> for BaseFieldElement {
67 type Error = StmError;
68 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
69 let hashed_input: [u8; 32] = Sha256::digest(value).into();
70 BaseFieldElement::from_raw(&hashed_input)
71 }
72}
73
74impl From<u64> for BaseFieldElement {
75 fn from(integer: u64) -> Self {
77 BaseFieldElement(JubjubBase::from(integer))
78 }
79}
80
81impl Add for BaseFieldElement {
82 type Output = BaseFieldElement;
83
84 fn add(self, other: BaseFieldElement) -> BaseFieldElement {
86 BaseFieldElement(self.0 + other.0)
87 }
88}
89
90impl Neg for BaseFieldElement {
91 type Output = BaseFieldElement;
92
93 fn neg(self) -> BaseFieldElement {
95 BaseFieldElement(-self.0)
96 }
97}
98
99impl Sub for &BaseFieldElement {
100 type Output = BaseFieldElement;
101
102 fn sub(self, other: &BaseFieldElement) -> BaseFieldElement {
104 BaseFieldElement(self.0 - other.0)
105 }
106}
107
108impl Mul for BaseFieldElement {
109 type Output = BaseFieldElement;
110
111 fn mul(self, other: BaseFieldElement) -> BaseFieldElement {
113 BaseFieldElement(self.0 * other.0)
114 }
115}
116
117impl Mul for &BaseFieldElement {
118 type Output = BaseFieldElement;
119
120 fn mul(self, other: &BaseFieldElement) -> BaseFieldElement {
122 BaseFieldElement(self.0 * other.0)
123 }
124}
125
126#[derive(Debug, Clone, Copy, PartialEq, Eq)]
128pub(crate) struct ScalarFieldElement(pub(crate) JubjubScalar);
129
130impl ScalarFieldElement {
131 pub(crate) fn new_random_scalar(rng: &mut (impl RngCore + CryptoRng)) -> Self {
133 ScalarFieldElement(JubjubScalar::random(rng))
134 }
135
136 pub(crate) fn is_zero(&self) -> bool {
138 if self.0 == JubjubScalar::zero() {
139 return true;
140 }
141 false
142 }
143
144 pub(crate) fn new_random_nonzero_scalar(
148 rng: &mut (impl RngCore + CryptoRng),
149 ) -> StmResult<Self> {
150 for _ in 0..100 {
151 let random_scalar = Self::new_random_scalar(rng);
152 if !random_scalar.is_zero() {
153 return Ok(random_scalar);
154 }
155 }
156 Err(anyhow!(UniqueSchnorrSignatureError::RandomScalarGeneration))
157 }
158
159 pub(crate) fn to_bytes(self) -> [u8; 32] {
161 self.0.to_bytes()
162 }
163
164 pub(crate) fn from_bytes(bytes: &[u8]) -> StmResult<Self> {
166 let mut scalar_bytes = [0u8; 32];
167 scalar_bytes.copy_from_slice(
168 bytes
169 .get(..32)
170 .ok_or(UniqueSchnorrSignatureError::ScalarFieldElementSerialization)?,
171 );
172
173 match JubjubScalar::from_bytes(&scalar_bytes).into_option() {
174 Some(scalar_field_element) => Ok(Self(scalar_field_element)),
175 None => Err(anyhow!(
176 UniqueSchnorrSignatureError::ScalarFieldElementSerialization
177 )),
178 }
179 }
180
181 pub(crate) fn from_raw(bytes: &[u8]) -> StmResult<Self> {
184 let mut scalar_bytes = [0u8; 32];
185 scalar_bytes.copy_from_slice(
186 bytes
187 .get(..32)
188 .ok_or(UniqueSchnorrSignatureError::ScalarFieldElementSerialization)?,
189 );
190
191 let mut bytes64 = [0u64; 4];
192 for i in 0..4 {
193 bytes64[i] =
194 u64::from_le_bytes(bytes[8 * i..8 * (i + 1)].try_into().with_context(|| {
195 anyhow!(UniqueSchnorrSignatureError::ScalarFieldElementSerialization)
196 })?)
197 }
198
199 Ok(Self(JubjubScalar::from_raw(bytes64)))
200 }
201
202 pub(crate) fn from_base_field(base_element: &BaseFieldElement) -> StmResult<Self> {
204 let base_element_bytes = base_element.0.to_bytes_le();
205 ScalarFieldElement::from_raw(&base_element_bytes)
206 }
207}
208
209impl Mul for ScalarFieldElement {
210 type Output = ScalarFieldElement;
211
212 fn mul(self, other: ScalarFieldElement) -> ScalarFieldElement {
214 ScalarFieldElement(self.0 * other.0)
215 }
216}
217
218impl Sub for ScalarFieldElement {
219 type Output = ScalarFieldElement;
220
221 fn sub(self, other: ScalarFieldElement) -> ScalarFieldElement {
223 ScalarFieldElement(self.0 - other.0)
224 }
225}
226
#[cfg(test)]
mod tests {
    use rand_chacha::ChaCha20Rng;
    use rand_core::SeedableRng;

    use super::*;

    mod golden {
        use super::*;

        const GOLDEN_JSON: &str = r#"[126, 191, 239, 197, 88, 151, 248, 254, 187, 143, 86, 35, 29, 62, 90, 13, 196, 71, 234, 5, 90, 124, 205, 194, 51, 192, 228, 133, 25, 140, 157, 7]"#;

        fn golden_value() -> ScalarFieldElement {
            let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
            ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap()
        }

        #[test]
        fn golden_conversions() {
            let value = serde_json::from_str(GOLDEN_JSON)
                .expect("This JSON deserialization should not fail");
            assert_eq!(golden_value(), value);

            let serialized =
                serde_json::to_string(&value).expect("This JSON serialization should not fail");
            let golden_serialized = serde_json::to_string(&golden_value())
                .expect("This JSON serialization should not fail");
            assert_eq!(golden_serialized, serialized);
        }
    }

    mod bytes_conversion {
        use super::*;

        #[test]
        fn from_bytes_fails_if_value_too_high() {
            let bytes = [255; 32];

            let value = BaseFieldElement::from_bytes(&bytes);
            value.expect_err("Bytes conversion should fail because input is higher than modulus.");

            let value = ScalarFieldElement::from_bytes(&bytes);
            value.expect_err("Bytes conversion should fail because input is higher than modulus.");
        }

        #[test]
        fn conversions_fail_if_input_too_short() {
            let bytes = [0u8; 31];

            BaseFieldElement::from_bytes(&bytes)
                .expect_err("Bytes conversion should fail because input is shorter than 32 bytes.");
            ScalarFieldElement::from_bytes(&bytes)
                .expect_err("Bytes conversion should fail because input is shorter than 32 bytes.");
            ScalarFieldElement::from_raw(&bytes)
                .expect_err("Raw conversion should fail because input is shorter than 32 bytes.");
        }

        #[test]
        fn to_bytes_from_bytes_roundtrip() {
            let base = BaseFieldElement::from(123_456_789u64);
            assert_eq!(BaseFieldElement::from_bytes(&base.to_bytes()).unwrap(), base);

            let mut rng = ChaCha20Rng::from_seed([7u8; 32]);
            let scalar = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            assert_eq!(
                ScalarFieldElement::from_bytes(&scalar.to_bytes()).unwrap(),
                scalar
            );
        }

        #[cfg(feature = "future_snark")]
        #[test]
        fn from_raw_recover_element_correctly() {
            let mut rng = ChaCha20Rng::from_seed([3u8; 32]);
            let elem = BaseFieldElement::random(&mut rng);
            let elem_bytes = elem.to_bytes();

            let val1 = BaseFieldElement::from_bytes(&elem_bytes).unwrap();
            let val2 = BaseFieldElement::from_raw(&elem_bytes).unwrap();

            assert_eq!(val1, val2);
        }

        #[test]
        fn from_raw_succeed_for_max_value() {
            let bytes = [255; 32];

            let value = BaseFieldElement::from_bytes(&bytes);
            value.expect_err("Bytes conversion should fail because input is higher than modulus.");

            let value = BaseFieldElement::from_raw(&bytes);
            assert!(
                value.is_ok(),
                "The conversion should not fail when using from_raw."
            );
        }

        #[test]
        fn scalar_from_raw_succeed_for_max_value() {
            let bytes = [255; 32];

            let value = ScalarFieldElement::from_bytes(&bytes);
            value.expect_err("Bytes conversion should fail because input is higher than modulus.");

            let value = ScalarFieldElement::from_raw(&bytes);
            assert!(
                value.is_ok(),
                "The conversion should not fail when using from_raw."
            );
        }

        #[test]
        fn from_base_field_maps_one_to_one() {
            let one = BaseFieldElement::get_one();
            let scalar = ScalarFieldElement::from_base_field(&one).unwrap();
            assert_eq!(scalar, ScalarFieldElement(JubjubScalar::one()));
        }

        #[test]
        fn try_from_hashes_input_with_sha256() {
            let input: &[u8] = b"mithril";
            let digest: [u8; 32] = Sha256::digest(input).into();

            let hashed = BaseFieldElement::try_from(input).unwrap();
            assert_eq!(hashed, BaseFieldElement::from_raw(&digest).unwrap());
        }
    }

    mod base_field_arithmetic {
        use super::*;

        #[test]
        fn test_add() {
            let a = BaseFieldElement(JubjubBase::ONE);
            let b = BaseFieldElement(JubjubBase::ONE);
            let result = a + b;
            assert_eq!(result, BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE));
        }

        #[test]
        fn test_add_with_zero() {
            let a = BaseFieldElement(JubjubBase::ONE);
            let zero = BaseFieldElement(JubjubBase::ZERO);
            let result = a + zero;
            assert_eq!(result, a);
        }

        #[test]
        fn test_neg() {
            let a = BaseFieldElement::from(42u64);
            let zero = BaseFieldElement(JubjubBase::ZERO);
            assert_eq!(a + (-a), zero);
            assert_eq!(-zero, zero);
        }

        #[test]
        fn test_sub_references() {
            let a = BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE);
            let b = BaseFieldElement(JubjubBase::ONE);
            let result = &a - &b;
            assert_eq!(result, BaseFieldElement(JubjubBase::ONE));
        }

        #[test]
        fn test_sub_same_values() {
            let a = BaseFieldElement(JubjubBase::ONE);
            let b = BaseFieldElement(JubjubBase::ONE);
            let result = &a - &b;
            assert_eq!(result, BaseFieldElement(JubjubBase::ZERO));
        }

        #[test]
        fn test_mul_owned() {
            let a = BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE);
            let b = BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE);
            let result = a * b;
            let expected = JubjubBase::ONE + JubjubBase::ONE;
            assert_eq!(result, BaseFieldElement(expected * expected));
        }

        #[test]
        fn test_mul_with_one() {
            let a = BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE);
            let one = BaseFieldElement::get_one();
            let result = a * one;
            assert_eq!(result, a);
        }

        #[test]
        fn test_mul_with_zero() {
            let a = BaseFieldElement(JubjubBase::ONE);
            let zero = BaseFieldElement(JubjubBase::ZERO);
            let result = a * zero;
            assert_eq!(result, BaseFieldElement(JubjubBase::ZERO));
        }

        #[test]
        fn test_mul_references() {
            let a = BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE);
            let b = BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE + JubjubBase::ONE);
            // Exercise the `Mul for &BaseFieldElement` impl, not the owned one.
            let result = &a * &b;
            let expected = (JubjubBase::ONE + JubjubBase::ONE)
                * (JubjubBase::ONE + JubjubBase::ONE + JubjubBase::ONE);
            assert_eq!(result, BaseFieldElement(expected));
        }

        #[test]
        fn test_chained_operations() {
            let a = BaseFieldElement(JubjubBase::ONE);
            let b = BaseFieldElement(JubjubBase::ONE);
            let c = BaseFieldElement(JubjubBase::ONE);
            let result = (a + b) * c;
            assert_eq!(result, BaseFieldElement(JubjubBase::ONE + JubjubBase::ONE));
        }
    }

    mod scalar_field_arithmetic {
        use super::*;

        #[test]
        fn test_mul() {
            let mut rng = ChaCha20Rng::from_seed([1u8; 32]);
            let a = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let b = ScalarFieldElement(JubjubScalar::one());
            let result = a * b;
            assert_eq!(result, a);
        }

        #[test]
        fn test_mul_with_zero() {
            let mut rng = ChaCha20Rng::from_seed([2u8; 32]);
            let a = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let zero = ScalarFieldElement(JubjubScalar::zero());
            let result = a * zero;
            assert!(result.is_zero());
        }

        #[test]
        fn test_mul_associativity() {
            let mut rng = ChaCha20Rng::from_seed([3u8; 32]);
            let a = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let b = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let c = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();

            let result1 = (a * b) * c;
            let result2 = a * (b * c);
            assert_eq!(result1, result2);
        }

        #[test]
        fn test_sub() {
            let mut rng = ChaCha20Rng::from_seed([4u8; 32]);
            let a = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let result = a - a;
            assert!(result.is_zero());
        }

        #[test]
        fn test_sub_with_zero() {
            let mut rng = ChaCha20Rng::from_seed([5u8; 32]);
            let a = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let zero = ScalarFieldElement(JubjubScalar::zero());
            let result = a - zero;
            assert_eq!(result, a);
        }

        #[test]
        fn test_sub_specific_values() {
            let two = ScalarFieldElement(JubjubScalar::one() + JubjubScalar::one());
            let one = ScalarFieldElement(JubjubScalar::one());
            let result = two - one;
            assert_eq!(result, ScalarFieldElement(JubjubScalar::one()));
        }

        #[test]
        fn test_combined_operations() {
            let mut rng = ChaCha20Rng::from_seed([6u8; 32]);
            let a = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let b = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
            let c = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();

            let left = a * (b - c);
            let right = (a * b) - (a * c);
            assert_eq!(left, right);
        }
    }
}