mithril_stm/
key_registration.rs1use std::{
3 collections::{hash_map::Entry, HashMap},
4 sync::Arc,
5};
6
7use blake2::digest::{Digest, FixedOutput};
8
9use crate::bls_multi_signature::{BlsVerificationKey, BlsVerificationKeyProofOfPossession};
10use crate::error::RegisterError;
11use crate::merkle_tree::{MerkleTree, MerkleTreeLeaf};
12use crate::Stake;
13
/// A registered participant, stored as a Merkle tree leaf.
///
/// `KeyRegistration::close` builds each one as `MerkleTreeLeaf(vk, stake)`,
/// i.e. a verification key paired with its stake.
pub type RegisteredParty = MerkleTreeLeaf;
16
/// Open (still mutable) registry of participants.
///
/// Associates every registered BLS verification key with its stake. New
/// entries are added through `register`; `close` consumes the registry and
/// produces a `ClosedKeyRegistration`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct KeyRegistration {
    /// Stake recorded for each registered verification key.
    keys: HashMap<BlsVerificationKey, Stake>,
}
24
25impl KeyRegistration {
26 pub fn init() -> Self {
28 Self::default()
29 }
30
31 pub fn register(
35 &mut self,
36 stake: Stake,
37 pk: BlsVerificationKeyProofOfPossession,
38 ) -> Result<(), RegisterError> {
39 if let Entry::Vacant(e) = self.keys.entry(pk.vk) {
40 pk.check()?;
41 e.insert(stake);
42 return Ok(());
43 }
44 Err(RegisterError::KeyRegistered(Box::new(pk.vk)))
45 }
46
47 pub fn close<D>(self) -> ClosedKeyRegistration<D>
50 where
51 D: Digest + FixedOutput,
52 {
53 let mut total_stake: Stake = 0;
54 let mut reg_parties = self
55 .keys
56 .iter()
57 .map(|(&vk, &stake)| {
58 let (res, overflow) = total_stake.overflowing_add(stake);
59 if overflow {
60 panic!("Total stake overflow");
61 }
62 total_stake = res;
63 MerkleTreeLeaf(vk, stake)
64 })
65 .collect::<Vec<RegisteredParty>>();
66 reg_parties.sort();
67
68 ClosedKeyRegistration {
69 merkle_tree: Arc::new(MerkleTree::create(®_parties)),
70 reg_parties,
71 total_stake,
72 }
73 }
74}
75
/// Finalized registration produced by `KeyRegistration::close`.
///
/// Holds the registered parties, their combined stake and the Merkle tree
/// built over the parties with digest `D`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClosedKeyRegistration<D: Digest> {
    /// Registered parties, in the sorted order used to build `merkle_tree`.
    pub reg_parties: Vec<RegisteredParty>,
    /// Sum of the stake of every registered party.
    pub total_stake: Stake,
    /// Merkle tree created from `reg_parties`, shared behind an `Arc`.
    pub merkle_tree: Arc<MerkleTree<D>>,
}
87
#[cfg(test)]
mod tests {
    use blake2::{digest::consts::U32, Blake2b};
    use proptest::{collection::vec, prelude::*};
    use rand_chacha::ChaCha20Rng;
    use rand_core::SeedableRng;

    use crate::bls_multi_signature::BlsSigningKey;

    use super::*;

    proptest! {
        // Registers a random sequence of (possibly repeated, possibly tampered)
        // keys and checks that the registry mirrors a plain `HashMap` model,
        // both while open and after closing.
        #[test]
        fn test_keyreg(stake in vec(1..1u64 << 60, 2..=10),
                       nkeys in 2..10_usize,
                       fake_it in 0..4usize,
                       seed in any::<[u8;32]>()) {
            let mut rng = ChaCha20Rng::from_seed(seed);

            // Draws one fresh key with its proof of possession from the seeded RNG.
            let mut new_key = || {
                let sk = BlsSigningKey::generate(&mut rng);
                BlsVerificationKeyProofOfPossession::from(&sk)
            };

            // `nkeys - 1` candidate keys first, then one extra key whose proof
            // is used to tamper with the candidates when `fake_it == 0`.
            let gen_keys: Vec<_> = (1..nkeys).map(|_| new_key()).collect();
            let fake_key = new_key();

            let mut kr = KeyRegistration::init();
            // Reference model of what the registry is expected to contain.
            let mut expected = HashMap::new();

            for (idx, &amount) in stake.iter().enumerate() {
                // Cycle through the candidates so duplicates are attempted.
                let mut pk = gen_keys[idx % gen_keys.len()];

                if fake_it == 0 {
                    pk.pop = fake_key.pop;
                }

                match kr.register(amount, pk) {
                    Ok(()) => {
                        assert!(expected.insert(pk.vk, amount).is_none());
                    }
                    Err(RegisterError::KeyRegistered(dup)) => {
                        assert!(*dup == pk.vk);
                        assert!(expected.contains_key(&pk.vk));
                    }
                    Err(RegisterError::KeyInvalid(bad)) => {
                        assert_eq!(fake_it, 0);
                        assert!(bad.check().is_err());
                    }
                    // Serialization and any other error cannot come out of `register`.
                    Err(_) => unreachable!(),
                }
            }

            if !kr.keys.is_empty() {
                let closed = kr.close::<Blake2b<U32>>();
                let retrieved: HashMap<_, _> = closed
                    .reg_parties
                    .iter()
                    .map(|party| (party.0, party.1))
                    .collect();
                assert!(retrieved == expected);
            }
        }
    }
}