mithril_aggregator/database/repository/stake_pool_store.rs

use std::ops::Not;
use std::sync::Arc;

use anyhow::Context;
use async_trait::async_trait;

use mithril_common::entities::{Epoch, StakeDistribution};
use mithril_common::signable_builder::StakeDistributionRetriever;
use mithril_common::StdResult;
use mithril_persistence::sqlite::{ConnectionExtensions, SqliteConnection};
use mithril_persistence::store::StakeStorer;

use crate::database::query::{
    DeleteStakePoolQuery, GetStakePoolQuery, InsertOrReplaceStakePoolQuery,
};
use crate::database::record::StakePool;
use crate::services::EpochPruningTask;

/// Service to deal with stake pools (read & write).
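///
/// # Example
///
/// A minimal usage sketch (not a verbatim snippet from this crate): it assumes an
/// already-built [SqliteConnection] and that `StakePoolStore` is in scope, and is
/// marked `ignore` since it only illustrates the read/write API.
///
/// ```ignore
/// use std::sync::Arc;
///
/// use mithril_common::entities::{Epoch, StakeDistribution};
/// use mithril_common::StdResult;
/// use mithril_persistence::store::StakeStorer;
///
/// async fn save_then_read(connection: Arc<SqliteConnection>) -> StdResult<()> {
///     // Keep stake distributions for the last two epochs only (hypothetical retention).
///     let store = StakePoolStore::new(connection, Some(2));
///
///     let stakes = StakeDistribution::from([("pool-1".to_string(), 100)]);
///     store.save_stakes(Epoch(42), stakes).await?;
///     assert!(store.get_stakes(Epoch(42)).await?.is_some());
///
///     Ok(())
/// }
/// ```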
pub struct StakePoolStore {
    connection: Arc<SqliteConnection>,

    /// Number of epochs to keep: records older than this threshold are removed when
    /// [EpochPruningTask::prune] is run.
    retention_limit: Option<u64>,
}

impl StakePoolStore {
    /// Create a new `StakePoolStore`.
    pub fn new(connection: Arc<SqliteConnection>, retention_limit: Option<u64>) -> Self {
        Self {
            connection,
            retention_limit,
        }
    }
}

#[async_trait]
impl StakeStorer for StakePoolStore {
    async fn save_stakes(
        &self,
        epoch: Epoch,
        stakes: StakeDistribution,
    ) -> StdResult<Option<StakeDistribution>> {
        let pools: Vec<StakePool> = self
            .connection
            .fetch_collect(InsertOrReplaceStakePoolQuery::many(
                stakes
                    .into_iter()
                    .map(|(pool_id, stake)| (pool_id, epoch, stake))
                    .collect(),
            ))
            .with_context(|| format!("persist stakes failure, epoch: {epoch}"))?;

        Ok(Some(StakeDistribution::from_iter(
            pools.into_iter().map(|p| (p.stake_pool_id, p.stake)),
        )))
    }

    async fn get_stakes(&self, epoch: Epoch) -> StdResult<Option<StakeDistribution>> {
        let cursor = self
            .connection
            .fetch(GetStakePoolQuery::by_epoch(epoch)?)
            .with_context(|| format!("get stakes failure, epoch: {epoch}"))?;
        let mut stake_distribution = StakeDistribution::new();

        for stake_pool in cursor {
            stake_distribution.insert(stake_pool.stake_pool_id, stake_pool.stake);
        }

        // Return `None` when no stake pool was recorded for this epoch.
        Ok(stake_distribution
            .is_empty()
            .not()
            .then_some(stake_distribution))
    }
}

#[async_trait]
impl StakeDistributionRetriever for StakePoolStore {
    async fn retrieve(&self, epoch: Epoch) -> StdResult<Option<StakeDistribution>> {
        self.get_stakes(epoch).await
    }
}

#[async_trait]
impl EpochPruningTask for StakePoolStore {
    fn pruned_data(&self) -> &'static str {
        "Stake pool"
    }

    async fn prune(&self, epoch: Epoch) -> StdResult<()> {
        if let Some(threshold) = self.retention_limit {
            // Delete stake pools recorded for epochs below `epoch - threshold`,
            // e.g. with a retention limit of 10, pruning at epoch 12 removes epochs 0 and 1.
            self.connection
                .apply(DeleteStakePoolQuery::below_epoch_threshold(
                    epoch - threshold,
                ))?;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use crate::database::test_helper::{insert_stake_pool, main_db_connection};

    use super::*;

    #[tokio::test]
    async fn prune_stake_pools_older_than_threshold() {
        const STAKE_POOL_PRUNE_EPOCH_THRESHOLD: u64 = 10;

        let connection = main_db_connection().unwrap();
        insert_stake_pool(&connection, &[1, 2]).unwrap();
        let store =
            StakePoolStore::new(Arc::new(connection), Some(STAKE_POOL_PRUNE_EPOCH_THRESHOLD));

        store
            .prune(Epoch(2) + STAKE_POOL_PRUNE_EPOCH_THRESHOLD)
            .await
            .unwrap();

        let epoch1_stakes = store.get_stakes(Epoch(1)).await.unwrap();
        let epoch2_stakes = store.get_stakes(Epoch(2)).await.unwrap();

        assert_eq!(
            None, epoch1_stakes,
            "Stakes at epoch 1 should have been pruned",
        );
        assert!(
            epoch2_stakes.is_some(),
            "Stakes at epoch 2 should still exist",
        );
    }

    #[tokio::test]
    async fn without_threshold_nothing_is_pruned() {
        let connection = main_db_connection().unwrap();
        insert_stake_pool(&connection, &[1, 2]).unwrap();
        let store = StakePoolStore::new(Arc::new(connection), None);

        store.prune(Epoch(100)).await.unwrap();

        let epoch1_stakes = store.get_stakes(Epoch(1)).await.unwrap();
        let epoch2_stakes = store.get_stakes(Epoch(2)).await.unwrap();

        assert!(
            epoch1_stakes.is_some(),
            "Stakes at epoch 1 should not have been pruned",
        );
        assert!(
            epoch2_stakes.is_some(),
            "Stakes at epoch 2 should still exist",
        );
    }

    #[tokio::test]
    async fn retrieve_with_no_stakes_returns_none() {
        let connection = main_db_connection().unwrap();
        let store = StakePoolStore::new(Arc::new(connection), None);

        let result = store.retrieve(Epoch(1)).await.unwrap();

        assert!(result.is_none());
    }

    #[tokio::test]
    async fn retrieve_returns_stake_distribution() {
        let stake_distribution_to_retrieve =
            StakeDistribution::from([("pool-123".to_string(), 123)]);
        let connection = main_db_connection().unwrap();
        let store = StakePoolStore::new(Arc::new(connection), None);
        store
            .save_stakes(Epoch(1), stake_distribution_to_retrieve.clone())
            .await
            .unwrap();

        let stake_distribution = store.retrieve(Epoch(1)).await.unwrap();

        assert_eq!(stake_distribution, Some(stake_distribution_to_retrieve));
    }
}