mithril_stm/protocol/
key_registration.rs1use std::{
3 collections::{HashMap, hash_map::Entry},
4 sync::Arc,
5};
6
7use anyhow::anyhow;
8use blake2::digest::{Digest, FixedOutput};
9
10use crate::{
11 RegisterError, Stake, StmResult,
12 membership_commitment::{MerkleTree, MerkleTreeLeaf},
13 signature_scheme::{BlsVerificationKey, BlsVerificationKeyProofOfPossession},
14};
15
/// A registered party: a verification key paired with its stake.
///
/// This is the same type as the membership Merkle tree leaf, so the parties produced by
/// [`KeyRegistration::close`] can be used directly as the tree's leaves.
pub type RegisteredParty = MerkleTreeLeaf;
18
19#[derive(Clone, Debug, Default, PartialEq, Eq)]
23pub struct KeyRegistration {
24 keys: HashMap<BlsVerificationKey, Stake>,
25}
26
27impl KeyRegistration {
28 pub fn init() -> Self {
30 Self::default()
31 }
32
33 pub fn register(
37 &mut self,
38 stake: Stake,
39 pk: BlsVerificationKeyProofOfPossession,
40 ) -> StmResult<()> {
41 if let Entry::Vacant(e) = self.keys.entry(pk.vk) {
42 pk.verify_proof_of_possession()
43 .map_err(|_| RegisterError::KeyInvalid(Box::new(pk)))?;
44 e.insert(stake);
45 return Ok(());
46 }
47 Err(anyhow!(RegisterError::KeyRegistered(Box::new(pk.vk))))
48 }
49
50 pub fn close<D>(self) -> ClosedKeyRegistration<D>
53 where
54 D: Digest + FixedOutput,
55 {
56 let mut total_stake: Stake = 0;
57 let mut reg_parties = self
58 .keys
59 .iter()
60 .map(|(&vk, &stake)| {
61 let (res, overflow) = total_stake.overflowing_add(stake);
62 if overflow {
63 panic!("Total stake overflow");
64 }
65 total_stake = res;
66 MerkleTreeLeaf(vk, stake)
67 })
68 .collect::<Vec<RegisteredParty>>();
69 reg_parties.sort();
70
71 ClosedKeyRegistration {
72 merkle_tree: Arc::new(MerkleTree::new(®_parties)),
73 reg_parties,
74 total_stake,
75 }
76 }
77}
78
79#[derive(Clone, Debug, PartialEq, Eq)]
82pub struct ClosedKeyRegistration<D: Digest> {
83 pub reg_parties: Vec<RegisteredParty>,
85 pub total_stake: Stake,
87 pub merkle_tree: Arc<MerkleTree<D>>,
89}
90
#[cfg(test)]
mod tests {
    use blake2::{Blake2b, digest::consts::U32};
    use proptest::{collection::vec, prelude::*};
    use rand_chacha::ChaCha20Rng;
    use rand_core::SeedableRng;

    use crate::signature_scheme::BlsSigningKey;

    use super::*;

    proptest! {
        // Registers a random sequence of (possibly repeated, possibly forged) keys, checks every
        // outcome against a reference `HashMap` model, then checks that closing the
        // registration keeps exactly the accepted parties, sorts them and sums their stakes.
        #[test]
        fn test_keyreg(stake in vec(1..1u64 << 60, 2..=10),
                       nkeys in 2..10_usize,
                       fake_it in 0..4usize,
                       seed in any::<[u8;32]>()) {
            let mut rng = ChaCha20Rng::from_seed(seed);
            let mut kr = KeyRegistration::init();

            // `nkeys - 1` (between 1 and 8) generated keys. Stakes are assigned round-robin,
            // so a key is registered again whenever there are more stakes than keys.
            let gen_keys = (1..nkeys).map(|_| {
                let sk = BlsSigningKey::generate(&mut rng);
                BlsVerificationKeyProofOfPossession::from(&sk)
            }).collect::<Vec<_>>();

            // Its proof of possession is grafted onto the generated keys to forge registrations.
            let fake_key = {
                let sk = BlsSigningKey::generate(&mut rng);
                BlsVerificationKeyProofOfPossession::from(&sk)
            };

            // Reference model of what the registration is expected to contain.
            let mut keys = HashMap::new();

            for (i, &stake) in stake.iter().enumerate() {
                let mut pk = gen_keys[i % gen_keys.len()];

                // When `fake_it == 0` (one value out of four), every key carries the proof of
                // possession of an unrelated key.
                if fake_it == 0 {
                    pk.pop = fake_key.pop;
                }

                match kr.register(stake, pk) {
                    Ok(()) => {
                        assert!(keys.insert(pk.vk, stake).is_none());
                    },
                    Err(error) => match error.downcast_ref::<RegisterError>() {
                        Some(RegisterError::KeyRegistered(pk1)) => {
                            assert_eq!(**pk1, pk.vk);
                            assert!(keys.contains_key(&pk.vk));
                        },
                        Some(RegisterError::KeyInvalid(a)) => {
                            assert_eq!(fake_it, 0);
                            assert!(a.verify_proof_of_possession().is_err());
                        },
                        _ => panic!("Unexpected error: {error}"),
                    }
                }
            }

            if !kr.keys.is_empty() {
                let closed = kr.close::<Blake2b<U32>>();
                let retrieved_keys = closed
                    .reg_parties
                    .iter()
                    .map(|r| (r.0, r.1))
                    .collect::<HashMap<_, _>>();
                assert_eq!(retrieved_keys, keys);
                // Collecting into a HashMap would hide duplicate leaves; check the count too.
                assert_eq!(closed.reg_parties.len(), keys.len());
                assert!(closed.reg_parties.windows(2).all(|w| w[0] <= w[1]));
                assert_eq!(closed.total_stake, keys.values().sum::<Stake>());
            }
        }
    }
}