1use std::{
3 collections::{hash_map::Entry, HashMap},
4 sync::Arc,
5};
6
7use blake2::digest::{Digest, FixedOutput};
8
9use crate::bls_multi_signature::{VerificationKey, VerificationKeyPoP};
10use crate::error::RegisterError;
11use crate::merkle_tree::{MTLeaf, MerkleTree};
12use crate::Stake;
13
/// A registered party, represented directly as a Merkle-tree leaf.
///
/// `KeyReg::close` builds each value as `MTLeaf(vk, stake)`, i.e. a
/// verification key paired with the stake registered for it.
pub type RegParty = MTLeaf;
16
17#[derive(Clone, Debug, Default, PartialEq, Eq)]
21pub struct KeyReg {
22 keys: HashMap<VerificationKey, Stake>,
23}
24
25impl KeyReg {
26 pub fn init() -> Self {
28 Self::default()
29 }
30
31 pub fn register(&mut self, stake: Stake, pk: VerificationKeyPoP) -> Result<(), RegisterError> {
35 if let Entry::Vacant(e) = self.keys.entry(pk.vk) {
36 pk.check()?;
37 e.insert(stake);
38 return Ok(());
39 }
40 Err(RegisterError::KeyRegistered(Box::new(pk.vk)))
41 }
42
43 pub fn close<D>(self) -> ClosedKeyReg<D>
46 where
47 D: Digest + FixedOutput,
48 {
49 let mut total_stake: Stake = 0;
50 let mut reg_parties = self
51 .keys
52 .iter()
53 .map(|(&vk, &stake)| {
54 let (res, overflow) = total_stake.overflowing_add(stake);
55 if overflow {
56 panic!("Total stake overflow");
57 }
58 total_stake = res;
59 MTLeaf(vk, stake)
60 })
61 .collect::<Vec<RegParty>>();
62 reg_parties.sort();
63
64 ClosedKeyReg {
65 merkle_tree: Arc::new(MerkleTree::create(®_parties)),
66 reg_parties,
67 total_stake,
68 }
69 }
70}
71
72#[derive(Clone, Debug, PartialEq, Eq)]
75pub struct ClosedKeyReg<D: Digest> {
76 pub reg_parties: Vec<RegParty>,
78 pub total_stake: Stake,
80 pub merkle_tree: Arc<MerkleTree<D>>,
82}
83
#[cfg(test)]
mod tests {
    use blake2::{digest::consts::U32, Blake2b};
    use proptest::{collection::vec, prelude::*};
    use rand_chacha::ChaCha20Rng;
    use rand_core::SeedableRng;

    use crate::bls_multi_signature::SigningKey;

    use super::*;

    proptest! {
        // Registers a sequence of (possibly repeated, possibly tampered) keys and
        // checks that `KeyReg` accepts, rejects and finally reports exactly the
        // keys a plain `HashMap` model would.
        #[test]
        fn test_keyreg(stake in vec(1..1u64 << 60, 2..=10),
                       nkeys in 2..10_usize,
                       fake_it in 0..4usize,
                       seed in any::<[u8;32]>()) {
            let mut rng = ChaCha20Rng::from_seed(seed);
            let mut registry = KeyReg::init();

            // The range starts at 1, so `nkeys - 1` candidate keys are generated.
            let mut candidates = Vec::new();
            for _ in 1..nkeys {
                let sk = SigningKey::gen(&mut rng);
                candidates.push(VerificationKeyPoP::from(&sk));
            }

            // An unrelated key whose proof of possession is grafted onto the
            // candidates when `fake_it == 0`.
            let decoy = {
                let sk = SigningKey::gen(&mut rng);
                VerificationKeyPoP::from(&sk)
            };

            // Reference model of what the registration should contain.
            let mut expected = HashMap::new();

            for (idx, &amount) in stake.iter().enumerate() {
                // Cycle through the candidates so that keys can repeat.
                let mut candidate = candidates[idx % candidates.len()];

                if fake_it == 0 {
                    candidate.pop = decoy.pop;
                }

                match registry.register(amount, candidate) {
                    Ok(()) => {
                        assert!(expected.insert(candidate.vk, amount).is_none());
                    }
                    Err(RegisterError::KeyRegistered(duplicate)) => {
                        assert!(duplicate.as_ref() == &candidate.vk);
                        assert!(expected.contains_key(&candidate.vk));
                    }
                    Err(RegisterError::KeyInvalid(rejected)) => {
                        assert_eq!(fake_it, 0);
                        assert!(rejected.check().is_err());
                    }
                    // No other error is expected from `register` in this scenario.
                    Err(_) => unreachable!(),
                }
            }

            if !registry.keys.is_empty() {
                let closed = registry.close::<Blake2b<U32>>();
                let recovered: HashMap<_, _> = closed
                    .reg_parties
                    .iter()
                    .map(|leaf| (leaf.0, leaf.1))
                    .collect();
                assert!(recovered == expected);
            }
        }
    }
}