// mithril_aggregator/database/repository/stake_pool_store.rs

1use std::ops::Not;
2use std::sync::Arc;
3
4use anyhow::Context;
5use async_trait::async_trait;
6
7use mithril_common::StdResult;
8use mithril_common::entities::{Epoch, StakeDistribution};
9use mithril_common::signable_builder::StakeDistributionRetriever;
10use mithril_persistence::sqlite::{ConnectionExtensions, SqliteConnection};
11use mithril_persistence::store::StakeStorer;
12
13use crate::database::query::{
14    DeleteStakePoolQuery, GetStakePoolQuery, InsertOrReplaceStakePoolQuery,
15};
16use crate::database::record::StakePool;
17use crate::services::EpochPruningTask;
18
19/// Service to deal with stake pools (read & write).
20pub struct StakePoolStore {
21    connection: Arc<SqliteConnection>,
22
23    /// Number of epochs before previous records will be pruned at the next call to
24    /// [save_protocol_parameters][StakePoolStore::save_stakes].
25    retention_limit: Option<u64>,
26}
27
28impl StakePoolStore {
29    /// Create a new StakePool service
30    pub fn new(connection: Arc<SqliteConnection>, retention_limit: Option<u64>) -> Self {
31        Self {
32            connection,
33            retention_limit,
34        }
35    }
36}
37
38#[async_trait]
39impl StakeStorer for StakePoolStore {
40    async fn save_stakes(
41        &self,
42        epoch: Epoch,
43        stakes: StakeDistribution,
44    ) -> StdResult<Option<StakeDistribution>> {
45        let pools: Vec<StakePool> = self
46            .connection
47            .fetch_collect(InsertOrReplaceStakePoolQuery::many(
48                stakes
49                    .into_iter()
50                    .map(|(pool_id, stake)| (pool_id, epoch, stake))
51                    .collect(),
52            ))
53            .with_context(|| format!("persist stakes failure, epoch: {epoch}"))?;
54
55        Ok(Some(StakeDistribution::from_iter(
56            pools.into_iter().map(|p| (p.stake_pool_id, p.stake)),
57        )))
58    }
59
60    async fn get_stakes(&self, epoch: Epoch) -> StdResult<Option<StakeDistribution>> {
61        let cursor = self
62            .connection
63            .fetch(GetStakePoolQuery::by_epoch(epoch)?)
64            .with_context(|| format!("get stakes failure, epoch: {epoch}"))?;
65        let mut stake_distribution = StakeDistribution::new();
66
67        for stake_pool in cursor {
68            stake_distribution.insert(stake_pool.stake_pool_id, stake_pool.stake);
69        }
70
71        Ok(stake_distribution.is_empty().not().then_some(stake_distribution))
72    }
73}
74
75#[async_trait]
76impl StakeDistributionRetriever for StakePoolStore {
77    async fn retrieve(&self, epoch: Epoch) -> StdResult<Option<StakeDistribution>> {
78        self.get_stakes(epoch).await
79    }
80}
81
82#[async_trait]
83impl EpochPruningTask for StakePoolStore {
84    fn pruned_data(&self) -> &'static str {
85        "Stake pool"
86    }
87
88    async fn prune(&self, epoch: Epoch) -> StdResult<()> {
89        if let Some(threshold) = self.retention_limit {
90            self.connection.apply(DeleteStakePoolQuery::below_epoch_threshold(
91                epoch - threshold,
92            ))?;
93        }
94        Ok(())
95    }
96}
97
#[cfg(test)]
mod tests {
    use crate::database::test_helper::{insert_stake_pool, main_db_connection};

    use super::*;

    /// Records below `prune epoch - retention limit` are deleted; the others are kept.
    #[tokio::test]
    async fn prune_stake_pools_older_than_threshold() {
        let connection = main_db_connection().unwrap();
        const STAKE_POOL_PRUNE_EPOCH_THRESHOLD: u64 = 10;
        insert_stake_pool(&connection, &[1, 2]).unwrap();
        let store =
            StakePoolStore::new(Arc::new(connection), Some(STAKE_POOL_PRUNE_EPOCH_THRESHOLD));

        store
            .prune(Epoch(2) + STAKE_POOL_PRUNE_EPOCH_THRESHOLD)
            .await
            .unwrap();

        let epoch1_stakes = store.get_stakes(Epoch(1)).await.unwrap();
        let epoch2_stakes = store.get_stakes(Epoch(2)).await.unwrap();

        assert_eq!(
            None, epoch1_stakes,
            "Stakes at epoch 1 should have been pruned",
        );
        assert!(
            epoch2_stakes.is_some(),
            "Stakes at epoch 2 should still exist",
        );
    }

    /// Without a retention limit, pruning must leave every record untouched.
    #[tokio::test]
    async fn without_threshold_nothing_is_pruned() {
        let connection = main_db_connection().unwrap();
        insert_stake_pool(&connection, &[1, 2]).unwrap();
        let store = StakePoolStore::new(Arc::new(connection), None);

        store.prune(Epoch(100)).await.unwrap();

        let epoch1_stakes = store.get_stakes(Epoch(1)).await.unwrap();
        let epoch2_stakes = store.get_stakes(Epoch(2)).await.unwrap();

        assert!(
            epoch1_stakes.is_some(),
            "Stakes at epoch 1 should still exist",
        );
        assert!(
            epoch2_stakes.is_some(),
            "Stakes at epoch 2 should still exist",
        );
    }

    #[tokio::test]
    async fn retrieve_with_no_stakes_returns_none() {
        let connection = main_db_connection().unwrap();
        let store = StakePoolStore::new(Arc::new(connection), None);

        let result = store.retrieve(Epoch(1)).await.unwrap();

        assert!(result.is_none());
    }

    #[tokio::test]
    async fn retrieve_returns_stake_distribution() {
        let stake_distribution_to_retrieve =
            StakeDistribution::from([("pool-123".to_string(), 123)]);
        let connection = main_db_connection().unwrap();
        let store = StakePoolStore::new(Arc::new(connection), None);
        store
            .save_stakes(Epoch(1), stake_distribution_to_retrieve.clone())
            .await
            .unwrap();

        let stake_distribution = store.retrieve(Epoch(1)).await.unwrap();

        assert_eq!(stake_distribution, Some(stake_distribution_to_retrieve));
    }
}