mithril_aggregator/database/repository/
stake_pool_store.rs1use std::ops::Not;
2use std::sync::Arc;
3
4use anyhow::Context;
5use async_trait::async_trait;
6
7use mithril_common::entities::{Epoch, StakeDistribution};
8use mithril_common::signable_builder::StakeDistributionRetriever;
9use mithril_common::StdResult;
10use mithril_persistence::sqlite::{ConnectionExtensions, SqliteConnection};
11use mithril_persistence::store::StakeStorer;
12
13use crate::database::query::{
14 DeleteStakePoolQuery, GetStakePoolQuery, InsertOrReplaceStakePoolQuery,
15};
16use crate::database::record::StakePool;
17use crate::services::EpochPruningTask;
18
19pub struct StakePoolStore {
21 connection: Arc<SqliteConnection>,
22
23 retention_limit: Option<u64>,
26}
27
28impl StakePoolStore {
29 pub fn new(connection: Arc<SqliteConnection>, retention_limit: Option<u64>) -> Self {
31 Self {
32 connection,
33 retention_limit,
34 }
35 }
36}
37
38#[async_trait]
39impl StakeStorer for StakePoolStore {
40 async fn save_stakes(
41 &self,
42 epoch: Epoch,
43 stakes: StakeDistribution,
44 ) -> StdResult<Option<StakeDistribution>> {
45 let pools: Vec<StakePool> = self
46 .connection
47 .fetch_collect(InsertOrReplaceStakePoolQuery::many(
48 stakes
49 .into_iter()
50 .map(|(pool_id, stake)| (pool_id, epoch, stake))
51 .collect(),
52 ))
53 .with_context(|| format!("persist stakes failure, epoch: {epoch}"))?;
54
55 Ok(Some(StakeDistribution::from_iter(
56 pools.into_iter().map(|p| (p.stake_pool_id, p.stake)),
57 )))
58 }
59
60 async fn get_stakes(&self, epoch: Epoch) -> StdResult<Option<StakeDistribution>> {
61 let cursor = self
62 .connection
63 .fetch(GetStakePoolQuery::by_epoch(epoch)?)
64 .with_context(|| format!("get stakes failure, epoch: {epoch}"))?;
65 let mut stake_distribution = StakeDistribution::new();
66
67 for stake_pool in cursor {
68 stake_distribution.insert(stake_pool.stake_pool_id, stake_pool.stake);
69 }
70
71 Ok(stake_distribution
72 .is_empty()
73 .not()
74 .then_some(stake_distribution))
75 }
76}
77
78#[async_trait]
79impl StakeDistributionRetriever for StakePoolStore {
80 async fn retrieve(&self, epoch: Epoch) -> StdResult<Option<StakeDistribution>> {
81 self.get_stakes(epoch).await
82 }
83}
84
85#[async_trait]
86impl EpochPruningTask for StakePoolStore {
87 fn pruned_data(&self) -> &'static str {
88 "Stake pool"
89 }
90
91 async fn prune(&self, epoch: Epoch) -> StdResult<()> {
92 if let Some(threshold) = self.retention_limit {
93 self.connection
94 .apply(DeleteStakePoolQuery::below_epoch_threshold(
95 epoch - threshold,
96 ))?;
97 }
98 Ok(())
99 }
100}
101
#[cfg(test)]
mod tests {
    use crate::database::test_helper::{insert_stake_pool, main_db_connection};

    use super::*;

    /// Pruning at `threshold + 2` must remove epoch 1 but keep epoch 2.
    #[tokio::test]
    async fn prune_stake_pools_older_than_threshold() {
        let connection = main_db_connection().unwrap();
        const STAKE_POOL_PRUNE_EPOCH_THRESHOLD: u64 = 10;
        insert_stake_pool(&connection, &[1, 2]).unwrap();
        let store =
            StakePoolStore::new(Arc::new(connection), Some(STAKE_POOL_PRUNE_EPOCH_THRESHOLD));

        store
            .prune(Epoch(2) + STAKE_POOL_PRUNE_EPOCH_THRESHOLD)
            .await
            .unwrap();

        let epoch1_stakes = store.get_stakes(Epoch(1)).await.unwrap();
        let epoch2_stakes = store.get_stakes(Epoch(2)).await.unwrap();

        assert_eq!(
            None, epoch1_stakes,
            "Stakes at epoch 1 should have been pruned",
        );
        assert!(
            epoch2_stakes.is_some(),
            "Stakes at epoch 2 should still exist",
        );
    }

    /// Without a retention limit, pruning must leave every epoch untouched.
    #[tokio::test]
    async fn without_threshold_nothing_is_pruned() {
        let connection = main_db_connection().unwrap();
        insert_stake_pool(&connection, &[1, 2]).unwrap();
        let store = StakePoolStore::new(Arc::new(connection), None);

        store.prune(Epoch(100)).await.unwrap();

        let epoch1_stakes = store.get_stakes(Epoch(1)).await.unwrap();
        let epoch2_stakes = store.get_stakes(Epoch(2)).await.unwrap();

        // Both epochs are expected to survive: the message must describe that expectation.
        assert!(
            epoch1_stakes.is_some(),
            "Stakes at epoch 1 should still exist",
        );
        assert!(
            epoch2_stakes.is_some(),
            "Stakes at epoch 2 should still exist",
        );
    }

    #[tokio::test]
    async fn retrieve_with_no_stakes_returns_none() {
        let connection = main_db_connection().unwrap();
        let store = StakePoolStore::new(Arc::new(connection), None);

        let result = store.retrieve(Epoch(1)).await.unwrap();

        assert!(result.is_none());
    }

    #[tokio::test]
    async fn retrieve_returns_stake_distribution() {
        let stake_distribution_to_retrieve =
            StakeDistribution::from([("pool-123".to_string(), 123)]);
        let connection = main_db_connection().unwrap();
        let store = StakePoolStore::new(Arc::new(connection), None);
        store
            .save_stakes(Epoch(1), stake_distribution_to_retrieve.clone())
            .await
            .unwrap();

        let stake_distribution = store.retrieve(Epoch(1)).await.unwrap();

        assert_eq!(stake_distribution, Some(stake_distribution_to_retrieve));
    }
}