mithril_aggregator/database/repository/
stake_pool_store.rs1use std::ops::Not;
2use std::sync::Arc;
3
4use anyhow::Context;
5use async_trait::async_trait;
6
7use mithril_common::StdResult;
8use mithril_common::entities::{Epoch, StakeDistribution};
9use mithril_common::signable_builder::StakeDistributionRetriever;
10use mithril_persistence::sqlite::{ConnectionExtensions, SqliteConnection};
11use mithril_persistence::store::StakeStorer;
12
13use crate::database::query::{
14 DeleteStakePoolQuery, GetStakePoolQuery, InsertOrReplaceStakePoolQuery,
15};
16use crate::database::record::StakePool;
17use crate::services::EpochPruningTask;
18
19pub struct StakePoolStore {
21 connection: Arc<SqliteConnection>,
22
23 retention_limit: Option<u64>,
26}
27
28impl StakePoolStore {
29 pub fn new(connection: Arc<SqliteConnection>, retention_limit: Option<u64>) -> Self {
31 Self {
32 connection,
33 retention_limit,
34 }
35 }
36}
37
38#[async_trait]
39impl StakeStorer for StakePoolStore {
40 async fn save_stakes(
41 &self,
42 epoch: Epoch,
43 stakes: StakeDistribution,
44 ) -> StdResult<Option<StakeDistribution>> {
45 let pools: Vec<StakePool> = self
46 .connection
47 .fetch_collect(InsertOrReplaceStakePoolQuery::many(
48 stakes
49 .into_iter()
50 .map(|(pool_id, stake)| (pool_id, epoch, stake))
51 .collect(),
52 ))
53 .with_context(|| format!("persist stakes failure, epoch: {epoch}"))?;
54
55 Ok(Some(StakeDistribution::from_iter(
56 pools.into_iter().map(|p| (p.stake_pool_id, p.stake)),
57 )))
58 }
59
60 async fn get_stakes(&self, epoch: Epoch) -> StdResult<Option<StakeDistribution>> {
61 let cursor = self
62 .connection
63 .fetch(GetStakePoolQuery::by_epoch(epoch)?)
64 .with_context(|| format!("get stakes failure, epoch: {epoch}"))?;
65 let mut stake_distribution = StakeDistribution::new();
66
67 for stake_pool in cursor {
68 stake_distribution.insert(stake_pool.stake_pool_id, stake_pool.stake);
69 }
70
71 Ok(stake_distribution.is_empty().not().then_some(stake_distribution))
72 }
73}
74
75#[async_trait]
76impl StakeDistributionRetriever for StakePoolStore {
77 async fn retrieve(&self, epoch: Epoch) -> StdResult<Option<StakeDistribution>> {
78 self.get_stakes(epoch).await
79 }
80}
81
82#[async_trait]
83impl EpochPruningTask for StakePoolStore {
84 fn pruned_data(&self) -> &'static str {
85 "Stake pool"
86 }
87
88 async fn prune(&self, epoch: Epoch) -> StdResult<()> {
89 if let Some(threshold) = self.retention_limit {
90 self.connection.apply(DeleteStakePoolQuery::below_epoch_threshold(
91 epoch - threshold,
92 ))?;
93 }
94 Ok(())
95 }
96}
97
#[cfg(test)]
mod tests {
    use crate::database::test_helper::{insert_stake_pool, main_db_connection};

    use super::*;

    #[tokio::test]
    async fn prune_stake_pools_older_than_threshold() {
        let connection = main_db_connection().unwrap();
        const STAKE_POOL_PRUNE_EPOCH_THRESHOLD: u64 = 10;
        insert_stake_pool(&connection, &[1, 2]).unwrap();
        let store =
            StakePoolStore::new(Arc::new(connection), Some(STAKE_POOL_PRUNE_EPOCH_THRESHOLD));

        // Pruning at `2 + threshold` must drop epochs strictly below 2 and keep epoch 2.
        store
            .prune(Epoch(2) + STAKE_POOL_PRUNE_EPOCH_THRESHOLD)
            .await
            .unwrap();

        let epoch1_stakes = store.get_stakes(Epoch(1)).await.unwrap();
        let epoch2_stakes = store.get_stakes(Epoch(2)).await.unwrap();

        assert_eq!(
            None, epoch1_stakes,
            "Stakes at epoch 1 should have been pruned",
        );
        assert!(
            epoch2_stakes.is_some(),
            "Stakes at epoch 2 should still exist",
        );
    }

    #[tokio::test]
    async fn without_threshold_nothing_is_pruned() {
        let connection = main_db_connection().unwrap();
        insert_stake_pool(&connection, &[1, 2]).unwrap();
        let store = StakePoolStore::new(Arc::new(connection), None);

        store.prune(Epoch(100)).await.unwrap();

        let epoch1_stakes = store.get_stakes(Epoch(1)).await.unwrap();
        let epoch2_stakes = store.get_stakes(Epoch(2)).await.unwrap();

        assert!(
            epoch1_stakes.is_some(),
            "Stakes at epoch 1 should still exist",
        );
        assert!(
            epoch2_stakes.is_some(),
            "Stakes at epoch 2 should still exist",
        );
    }

    #[tokio::test]
    async fn retrieve_with_no_stakes_returns_none() {
        let connection = main_db_connection().unwrap();
        let store = StakePoolStore::new(Arc::new(connection), None);

        let result = store.retrieve(Epoch(1)).await.unwrap();

        assert!(result.is_none());
    }

    #[tokio::test]
    async fn retrieve_returns_stake_distribution() {
        let stake_distribution_to_retrieve =
            StakeDistribution::from([("pool-123".to_string(), 123)]);
        let connection = main_db_connection().unwrap();
        let store = StakePoolStore::new(Arc::new(connection), None);
        store
            .save_stakes(Epoch(1), stake_distribution_to_retrieve.clone())
            .await
            .unwrap();

        let stake_distribution = store.retrieve(Epoch(1)).await.unwrap();

        assert_eq!(stake_distribution, Some(stake_distribution_to_retrieve));
    }
}