1use super::stm::Stake;
3use crate::bls_multi_signature::{VerificationKey, VerificationKeyPoP};
4use crate::error::RegisterError;
5use crate::merkle_tree::{MTLeaf, MerkleTree};
6use blake2::digest::{Digest, FixedOutput};
7use std::collections::hash_map::Entry;
8use std::collections::HashMap;
9use std::sync::Arc;
10
11pub type RegParty = MTLeaf;
13
14#[derive(Clone, Debug, Default, PartialEq, Eq)]
18pub struct KeyReg {
19 keys: HashMap<VerificationKey, Stake>,
20}
21
22impl KeyReg {
23 pub fn init() -> Self {
25 Self::default()
26 }
27
28 pub fn register(&mut self, stake: Stake, pk: VerificationKeyPoP) -> Result<(), RegisterError> {
32 if let Entry::Vacant(e) = self.keys.entry(pk.vk) {
33 pk.check()?;
34 e.insert(stake);
35 return Ok(());
36 }
37 Err(RegisterError::KeyRegistered(Box::new(pk.vk)))
38 }
39
40 pub fn close<D>(self) -> ClosedKeyReg<D>
43 where
44 D: Digest + FixedOutput,
45 {
46 let mut total_stake: Stake = 0;
47 let mut reg_parties = self
48 .keys
49 .iter()
50 .map(|(&vk, &stake)| {
51 let (res, overflow) = total_stake.overflowing_add(stake);
52 if overflow {
53 panic!("Total stake overflow");
54 }
55 total_stake = res;
56 MTLeaf(vk, stake)
57 })
58 .collect::<Vec<RegParty>>();
59 reg_parties.sort();
60
61 ClosedKeyReg {
62 merkle_tree: Arc::new(MerkleTree::create(®_parties)),
63 reg_parties,
64 total_stake,
65 }
66 }
67}
68
69#[derive(Clone, Debug, PartialEq, Eq)]
72pub struct ClosedKeyReg<D: Digest> {
73 pub reg_parties: Vec<RegParty>,
75 pub total_stake: Stake,
77 pub merkle_tree: Arc<MerkleTree<D>>,
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bls_multi_signature::SigningKey;
    use blake2::{digest::consts::U32, Blake2b};
    use proptest::collection::vec;
    use proptest::prelude::*;
    use rand_chacha::ChaCha20Rng;
    use rand_core::SeedableRng;

    proptest! {
        // Registers a sequence of (possibly repeated, possibly tampered) keys and
        // checks every outcome of `register` against a reference map, then
        // verifies that closing the registration preserves exactly the accepted
        // (key, stake) pairs.
        #[test]
        fn test_keyreg(stake in vec(1..1u64 << 60, 2..=10),
                       nkeys in 2..10_usize,
                       fake_it in 0..4usize,
                       seed in any::<[u8;32]>()) {
            let mut rng = ChaCha20Rng::from_seed(seed);
            let mut registry = KeyReg::init();

            // `1..nkeys` yields `nkeys - 1` candidate keys (at least one).
            let mut candidates = Vec::with_capacity(nkeys - 1);
            for _ in 1..nkeys {
                let sk = SigningKey::gen(&mut rng);
                candidates.push(VerificationKeyPoP::from(&sk));
            }

            // Unrelated key whose `pop` is swapped into a candidate when
            // `fake_it == 0`.
            let fake_key = {
                let sk = SigningKey::gen(&mut rng);
                VerificationKeyPoP::from(&sk)
            };

            // Reference model of what the registry should contain.
            let mut expected = HashMap::new();

            for (i, &amount) in stake.iter().enumerate() {
                // Cycling through the candidates makes repeated registrations
                // possible whenever there are more stakes than keys.
                let mut pk = candidates[i % candidates.len()];

                if fake_it == 0 {
                    pk.pop = fake_key.pop;
                }

                match registry.register(amount, pk) {
                    Ok(()) => {
                        assert!(expected.insert(pk.vk, amount).is_none());
                    }
                    Err(RegisterError::KeyRegistered(dup)) => {
                        assert!(dup.as_ref() == &pk.vk);
                        assert!(expected.contains_key(&pk.vk));
                    }
                    Err(RegisterError::KeyInvalid(bad)) => {
                        assert_eq!(fake_it, 0);
                        assert!(bad.check().is_err());
                    }
                    // No other error (e.g. `SerializationError`) is expected here.
                    _ => unreachable!(),
                }
            }

            if !registry.keys.is_empty() {
                let closed = registry.close::<Blake2b<U32>>();
                let recovered: HashMap<_, _> = closed
                    .reg_parties
                    .iter()
                    .map(|leaf| (leaf.0, leaf.1))
                    .collect();
                assert!(recovered == expected);
            }
        }
    }
}