mithril_common/signable_builder/
cardano_stake_distribution.rs1use anyhow::anyhow;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5
6use crate::{
7 crypto_helper::{MKTree, MKTreeNode, MKTreeStoreInMemory},
8 entities::{Epoch, ProtocolMessage, ProtocolMessagePartKey, StakeDistribution},
9 signable_builder::SignableBuilder,
10 StdResult,
11};
12
13#[cfg(test)]
14use mockall::automock;
15
16#[cfg_attr(test, automock)]
18#[async_trait]
19pub trait StakeDistributionRetriever: Send + Sync {
20 async fn retrieve(&self, epoch: Epoch) -> StdResult<Option<StakeDistribution>>;
22}
23
/// Pairing of a pool identifier (field `0`) with its stake (field `1`).
struct StakeDistributionEntry(String, u64);

impl StakeDistributionEntry {
    /// Builds an entry, copying `pool_id` into an owned `String`.
    pub fn new(pool_id: &str, stake: u64) -> Self {
        StakeDistributionEntry(String::from(pool_id), stake)
    }
}
31
32impl From<StakeDistributionEntry> for MKTreeNode {
33 fn from(entry: StakeDistributionEntry) -> Self {
34 MKTreeNode::new(format!("{}{}", entry.0, entry.1).into())
35 }
36}
37
38pub struct CardanoStakeDistributionSignableBuilder {
40 cardano_stake_distribution_retriever: Arc<dyn StakeDistributionRetriever>,
41}
42
43impl CardanoStakeDistributionSignableBuilder {
44 pub fn new(cardano_stake_distribution_retriever: Arc<dyn StakeDistributionRetriever>) -> Self {
46 Self {
47 cardano_stake_distribution_retriever,
48 }
49 }
50
51 pub fn compute_merkle_tree_from_stake_distribution(
53 pools_with_stake: StakeDistribution,
54 ) -> StdResult<MKTree<MKTreeStoreInMemory>> {
55 let leaves: Vec<MKTreeNode> = pools_with_stake
56 .iter()
57 .map(|(k, v)| StakeDistributionEntry::new(k, *v).into())
58 .collect();
59
60 MKTree::new(&leaves)
61 }
62}
63
64#[async_trait]
65impl SignableBuilder<Epoch> for CardanoStakeDistributionSignableBuilder {
66 async fn compute_protocol_message(&self, epoch: Epoch) -> StdResult<ProtocolMessage> {
67 let pools_with_stake = self
68 .cardano_stake_distribution_retriever
69 .retrieve(epoch.offset_to_cardano_stake_distribution_snapshot_epoch())
70 .await?.ok_or(anyhow!(
71 "CardanoStakeDistributionSignableBuilder could not find the stake distribution for epoch: '{epoch}'"
72 ))?;
73
74 let mk_tree = Self::compute_merkle_tree_from_stake_distribution(pools_with_stake)?;
75
76 let mut protocol_message = ProtocolMessage::new();
77 protocol_message.set_message_part(
78 ProtocolMessagePartKey::CardanoStakeDistributionEpoch,
79 epoch.to_string(),
80 );
81 protocol_message.set_message_part(
82 ProtocolMessagePartKey::CardanoStakeDistributionMerkleRoot,
83 mk_tree.compute_root()?.to_hex(),
84 );
85
86 Ok(protocol_message)
87 }
88}
89
#[cfg(test)]
mod tests {
    use mockall::predicate::eq;

    use crate::entities::ProtocolMessagePartKey;

    use super::*;

    /// Tells whether two stake distributions produce the same Merkle root.
    fn is_merkle_tree_equals(
        first_pools_with_stake: StakeDistribution,
        second_pools_with_stake: StakeDistribution,
    ) -> bool {
        let build_tree =
            CardanoStakeDistributionSignableBuilder::compute_merkle_tree_from_stake_distribution;
        let first_tree = build_tree(first_pools_with_stake).unwrap();
        let second_tree = build_tree(second_pools_with_stake).unwrap();

        first_tree.compute_root().unwrap() == second_tree.compute_root().unwrap()
    }

    #[test]
    fn compute_merkle_tree_equals() {
        // Identical content yields identical roots.
        let single_pool = StakeDistribution::from([("pool-123".to_string(), 100)]);
        assert!(is_merkle_tree_equals(single_pool.clone(), single_pool));

        // The order in which the entries are supplied must not influence the root.
        let pools_in_order = StakeDistribution::from([
            ("pool-123".to_string(), 100),
            ("pool-456".to_string(), 150),
        ]);
        let pools_reversed = StakeDistribution::from([
            ("pool-456".to_string(), 150),
            ("pool-123".to_string(), 100),
        ]);
        assert!(is_merkle_tree_equals(pools_in_order, pools_reversed));
    }

    #[test]
    fn compute_merkle_tree_not_equals() {
        let reference = StakeDistribution::from([("pool-123".to_string(), 100)]);

        // A different pool id changes the root.
        let other_pool_id = StakeDistribution::from([("pool-456".to_string(), 100)]);
        assert!(!is_merkle_tree_equals(reference.clone(), other_pool_id));

        // A different stake changes the root.
        let other_stake = StakeDistribution::from([("pool-123".to_string(), 999)]);
        assert!(!is_merkle_tree_equals(reference, other_stake));
    }

    #[tokio::test]
    async fn compute_protocol_message_returns_error_when_no_cardano_stake_distribution_found() {
        let mut retriever = MockStakeDistributionRetriever::new();
        retriever.expect_retrieve().return_once(move |_| Ok(None));
        let builder = CardanoStakeDistributionSignableBuilder::new(Arc::new(retriever));

        builder
            .compute_protocol_message(Epoch(1))
            .await
            .expect_err("Should return an error when no cardano stake distribution found");
    }

    #[tokio::test]
    async fn compute_protocol_message_returns_signable_and_retrieve_with_epoch_offset() {
        let signed_epoch = Epoch(1);
        // Epoch at which the mocked retriever expects to be queried for `signed_epoch`.
        let expected_retrieval_epoch = Epoch(3);
        let pools_with_stake = StakeDistribution::from([("pool-123".to_string(), 100)]);

        // Compute the expected root up front, before the distribution moves into the mock.
        let expected_merkle_root =
            CardanoStakeDistributionSignableBuilder::compute_merkle_tree_from_stake_distribution(
                pools_with_stake.clone(),
            )
            .unwrap()
            .compute_root()
            .unwrap()
            .to_hex();

        let mut retriever = MockStakeDistributionRetriever::new();
        retriever
            .expect_retrieve()
            .with(eq(expected_retrieval_epoch))
            .return_once(move |_| Ok(Some(pools_with_stake)));
        let builder = CardanoStakeDistributionSignableBuilder::new(Arc::new(retriever));

        let signable = builder
            .compute_protocol_message(signed_epoch)
            .await
            .unwrap();

        let mut signable_expected = ProtocolMessage::new();
        signable_expected.set_message_part(
            ProtocolMessagePartKey::CardanoStakeDistributionEpoch,
            signed_epoch.to_string(),
        );
        signable_expected.set_message_part(
            ProtocolMessagePartKey::CardanoStakeDistributionMerkleRoot,
            expected_merkle_root,
        );
        assert_eq!(signable_expected, signable);
    }
}