use super::stm::Stake;
use crate::error::RegisterError;
use crate::merkle_tree::{MTLeaf, MerkleTree};
use crate::multi_sig::{VerificationKey, VerificationKeyPoP};
use blake2::digest::{Digest, FixedOutput};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;
/// A registered party, stored as a Merkle tree leaf.
/// `KeyReg::close` builds each one as `MTLeaf(vk, stake)`, i.e. a verification
/// key paired with the stake registered for it.
pub type RegParty = MTLeaf;
/// Open key registration: collects verification keys together with the stake
/// registered for each of them.
///
/// Entries are added through `KeyReg::register`; `KeyReg::close` consumes the
/// registration and turns it into a `ClosedKeyReg`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct KeyReg {
// Maps each registered verification key to its stake. Being a map, it holds
// at most one stake per key.
keys: HashMap<VerificationKey, Stake>,
}
/// Result of closing a `KeyReg`: the registered parties, their total stake and
/// a Merkle tree built over them, using `D` as the tree's digest.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClosedKeyReg<D: Digest> {
/// Registered parties (key and stake), sorted by `KeyReg::close` before the
/// Merkle tree is created from them.
pub reg_parties: Vec<RegParty>,
/// Sum of the stake of every registered party.
pub total_stake: Stake,
/// Merkle tree created from `reg_parties`, shared behind an `Arc`.
pub merkle_tree: Arc<MerkleTree<D>>,
}
impl KeyReg {
    /// Creates an empty key registration with no registered keys.
    pub fn init() -> Self {
        Self {
            keys: HashMap::new(),
        }
    }

    /// Registers the verification key contained in `pk` with the given `stake`.
    ///
    /// The key is only inserted when it is not yet present and `pk.check()`
    /// succeeds.
    ///
    /// # Errors
    ///
    /// * `RegisterError::KeyRegistered` if the verification key is already
    ///   registered. This is decided before `pk.check()` is evaluated, so a
    ///   duplicate key is reported as such whatever the outcome of the check
    ///   would be.
    /// * `RegisterError::KeyInvalid` if the key is new but `pk.check()` fails.
    ///   Nothing is inserted in that case.
    pub fn register(&mut self, stake: Stake, pk: VerificationKeyPoP) -> Result<(), RegisterError> {
        match self.keys.entry(pk.vk) {
            Entry::Occupied(_) => Err(RegisterError::KeyRegistered(Box::new(pk.vk))),
            Entry::Vacant(slot) => {
                if pk.check().is_ok() {
                    slot.insert(stake);
                    Ok(())
                } else {
                    Err(RegisterError::KeyInvalid(Box::new(pk)))
                }
            }
        }
    }

    /// Finalizes the registration, consuming it.
    ///
    /// Every registered `(key, stake)` pair becomes an `MTLeaf`; the leaves are
    /// sorted and a Merkle tree using digest `D` is created from the sorted
    /// list. The returned `ClosedKeyReg` also carries the sum of all stakes.
    ///
    /// # Panics
    ///
    /// Panics with `"Total stake overflow"` if the sum of the registered
    /// stakes does not fit in `Stake`.
    pub fn close<D>(self) -> ClosedKeyReg<D>
    where
        D: Digest + FixedOutput,
    {
        let mut total_stake: Stake = 0;
        let mut reg_parties = self
            .keys
            .into_iter()
            .map(|(vk, stake)| {
                total_stake = total_stake
                    .checked_add(stake)
                    .expect("Total stake overflow");
                MTLeaf(vk, stake)
            })
            .collect::<Vec<RegParty>>();
        // The map's iteration order is unspecified; sorting puts the leaves in
        // an order that depends only on the registered set.
        reg_parties.sort();

        ClosedKeyReg {
            // Built first: it borrows `reg_parties`, which is moved into the
            // struct by the next field.
            merkle_tree: Arc::new(MerkleTree::create(&reg_parties)),
            reg_parties,
            total_stake,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multi_sig::SigningKey;
    use blake2::{digest::consts::U32, Blake2b};
    use proptest::collection::vec;
    use proptest::prelude::*;
    use rand_chacha::ChaCha20Rng;
    use rand_core::SeedableRng;

    proptest! {
        // Registers a stream of (possibly repeated, possibly tampered) keys and
        // checks that every outcome of `register` matches a locally tracked
        // model, then that closing the registration yields exactly the
        // accepted (key, stake) pairs.
        #[test]
        fn test_keyreg(stake in vec(1..1u64 << 60, 2..=10),
                       nkeys in 2..10_usize,
                       fake_it in 0..4usize,
                       seed in any::<[u8;32]>()) {
            let mut rng = ChaCha20Rng::from_seed(seed);
            let mut registry = KeyReg::init();

            // Draws a fresh signing key from the seeded RNG and derives its
            // verification key with proof of possession.
            let mut fresh_key = || VerificationKeyPoP::from(&SigningKey::gen(&mut rng));

            // `nkeys - 1` candidate keys first, then one extra key whose proof
            // of possession is used to tamper with the candidates.
            let candidates = (1..nkeys).map(|_| fresh_key()).collect::<Vec<_>>();
            let fake_key = fresh_key();

            // Model of what the registry is expected to contain.
            let mut expected = HashMap::new();

            // Walk the candidates round-robin, one per stake value, so keys
            // repeat whenever there are more stakes than candidates.
            for (&amount, candidate) in stake.iter().zip(candidates.iter().cycle()) {
                let mut pk = *candidate;
                if fake_it == 0 {
                    // Swap in a proof of possession belonging to another key.
                    pk.pop = fake_key.pop;
                }

                match registry.register(amount, pk) {
                    Ok(()) => {
                        assert!(expected.insert(pk.vk, amount).is_none());
                    }
                    Err(RegisterError::KeyRegistered(duplicate)) => {
                        assert!(duplicate.as_ref() == &pk.vk);
                        assert!(expected.contains_key(&pk.vk));
                    }
                    Err(RegisterError::KeyInvalid(rejected)) => {
                        assert_eq!(fake_it, 0);
                        assert!(rejected.check().is_err());
                    }
                    // `SerializationError` and any other variant must never be
                    // produced by `register`.
                    Err(_) => unreachable!(),
                }
            }

            if !registry.keys.is_empty() {
                let closed = registry.close::<Blake2b<U32>>();
                let retrieved = closed
                    .reg_parties
                    .iter()
                    .map(|leaf| (leaf.0, leaf.1))
                    .collect::<HashMap<_, _>>();
                assert!(retrieved == expected);
            }
        }
    }
}