mithril_common/signable_builder/cardano_stake_distribution.rs

1use anyhow::anyhow;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5
6use crate::{
7    crypto_helper::{MKTree, MKTreeNode, MKTreeStoreInMemory},
8    entities::{Epoch, ProtocolMessage, ProtocolMessagePartKey, StakeDistribution},
9    signable_builder::SignableBuilder,
10    StdResult,
11};
12
13#[cfg(test)]
14use mockall::automock;
15
16/// Stake Distribution Retriever
17#[cfg_attr(test, automock)]
18#[async_trait]
19pub trait StakeDistributionRetriever: Send + Sync {
20    /// Retrieve the [StakeDistribution] for a given epoch
21    async fn retrieve(&self, epoch: Epoch) -> StdResult<Option<StakeDistribution>>;
22}
23
/// One `(pool id, stake)` pair of a stake distribution, used as the source of a Merkle tree leaf.
struct StakeDistributionEntry(String, u64);

impl StakeDistributionEntry {
    /// Creates an entry, copying `pool_id` into an owned `String` next to its `stake`.
    pub fn new(pool_id: &str, stake: u64) -> Self {
        let owned_pool_id = String::from(pool_id);
        Self(owned_pool_id, stake)
    }
}
31
32impl From<StakeDistributionEntry> for MKTreeNode {
33    fn from(entry: StakeDistributionEntry) -> Self {
34        MKTreeNode::new(format!("{}{}", entry.0, entry.1).into())
35    }
36}
37
38/// A [CardanoStakeDistributionSignableBuilder] builder
39pub struct CardanoStakeDistributionSignableBuilder {
40    cardano_stake_distribution_retriever: Arc<dyn StakeDistributionRetriever>,
41}
42
43impl CardanoStakeDistributionSignableBuilder {
44    /// Constructor
45    pub fn new(cardano_stake_distribution_retriever: Arc<dyn StakeDistributionRetriever>) -> Self {
46        Self {
47            cardano_stake_distribution_retriever,
48        }
49    }
50
51    /// Compute the Merkle tree of a given [StakeDistribution]
52    pub fn compute_merkle_tree_from_stake_distribution(
53        pools_with_stake: StakeDistribution,
54    ) -> StdResult<MKTree<MKTreeStoreInMemory>> {
55        let leaves: Vec<MKTreeNode> = pools_with_stake
56            .iter()
57            .map(|(k, v)| StakeDistributionEntry::new(k, *v).into())
58            .collect();
59
60        MKTree::new(&leaves)
61    }
62}
63
64#[async_trait]
65impl SignableBuilder<Epoch> for CardanoStakeDistributionSignableBuilder {
66    async fn compute_protocol_message(&self, epoch: Epoch) -> StdResult<ProtocolMessage> {
67        let pools_with_stake = self
68            .cardano_stake_distribution_retriever
69            .retrieve(epoch.offset_to_cardano_stake_distribution_snapshot_epoch())
70            .await?.ok_or(anyhow!(
71                "CardanoStakeDistributionSignableBuilder could not find the stake distribution for epoch: '{epoch}'"
72            ))?;
73
74        let mk_tree = Self::compute_merkle_tree_from_stake_distribution(pools_with_stake)?;
75
76        let mut protocol_message = ProtocolMessage::new();
77        protocol_message.set_message_part(
78            ProtocolMessagePartKey::CardanoStakeDistributionEpoch,
79            epoch.to_string(),
80        );
81        protocol_message.set_message_part(
82            ProtocolMessagePartKey::CardanoStakeDistributionMerkleRoot,
83            mk_tree.compute_root()?.to_hex(),
84        );
85
86        Ok(protocol_message)
87    }
88}
89
#[cfg(test)]
mod tests {
    use mockall::predicate::eq;

    // `super::*` also brings in the parent's imports (`ProtocolMessagePartKey`, `Arc`, ...).
    use super::*;

    /// Hex-encoded Merkle root of the tree built from `pools_with_stake`.
    fn compute_merkle_root_hex(pools_with_stake: StakeDistribution) -> String {
        CardanoStakeDistributionSignableBuilder::compute_merkle_tree_from_stake_distribution(
            pools_with_stake,
        )
        .unwrap()
        .compute_root()
        .unwrap()
        .to_hex()
    }

    #[test]
    fn compute_merkle_tree_equals() {
        assert_eq!(
            compute_merkle_root_hex(StakeDistribution::from([("pool-123".to_string(), 100)])),
            compute_merkle_root_hex(StakeDistribution::from([("pool-123".to_string(), 100)])),
        );

        // The insertion order of the pools must not affect the root.
        assert_eq!(
            compute_merkle_root_hex(StakeDistribution::from([
                ("pool-123".to_string(), 100),
                ("pool-456".to_string(), 150)
            ])),
            compute_merkle_root_hex(StakeDistribution::from([
                ("pool-456".to_string(), 150),
                ("pool-123".to_string(), 100)
            ])),
        );
    }

    #[test]
    fn compute_merkle_tree_not_equals() {
        assert_ne!(
            compute_merkle_root_hex(StakeDistribution::from([("pool-123".to_string(), 100)])),
            compute_merkle_root_hex(StakeDistribution::from([("pool-456".to_string(), 100)])),
        );

        assert_ne!(
            compute_merkle_root_hex(StakeDistribution::from([("pool-123".to_string(), 100)])),
            compute_merkle_root_hex(StakeDistribution::from([("pool-123".to_string(), 999)])),
        );
    }

    #[tokio::test]
    async fn compute_protocol_message_returns_error_when_no_cardano_stake_distribution_found() {
        let epoch = Epoch(1);

        let mut cardano_stake_distribution_retriever = MockStakeDistributionRetriever::new();
        cardano_stake_distribution_retriever
            .expect_retrieve()
            .return_once(move |_| Ok(None));
        let cardano_stake_distribution_signable_builder =
            CardanoStakeDistributionSignableBuilder::new(Arc::new(
                cardano_stake_distribution_retriever,
            ));

        cardano_stake_distribution_signable_builder
            .compute_protocol_message(epoch)
            .await
            .expect_err("Should return an error when no cardano stake distribution found");
    }

    #[tokio::test]
    async fn compute_protocol_message_returns_signable_and_retrieve_with_epoch_offset() {
        let epoch = Epoch(1);
        let epoch_to_retrieve = Epoch(3);
        let stake_distribution = StakeDistribution::from([("pool-123".to_string(), 100)]);
        let stake_distribution_clone = stake_distribution.clone();

        let mut pools_with_stake_retriever = MockStakeDistributionRetriever::new();
        pools_with_stake_retriever
            .expect_retrieve()
            .with(eq(epoch_to_retrieve))
            .return_once(move |_| Ok(Some(stake_distribution)));
        let cardano_stake_distribution_signable_builder =
            CardanoStakeDistributionSignableBuilder::new(Arc::new(pools_with_stake_retriever));

        let signable = cardano_stake_distribution_signable_builder
            .compute_protocol_message(epoch)
            .await
            .unwrap();

        let mut signable_expected = ProtocolMessage::new();
        signable_expected.set_message_part(
            ProtocolMessagePartKey::CardanoStakeDistributionEpoch,
            epoch.to_string(),
        );
        signable_expected.set_message_part(
            ProtocolMessagePartKey::CardanoStakeDistributionMerkleRoot,
            compute_merkle_root_hex(stake_distribution_clone),
        );
        assert_eq!(signable_expected, signable);
    }
}