mithril_aggregator/database/repository/signer_store.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use chrono::Utc;
6
7use mithril_common::StdResult;
8use mithril_persistence::sqlite::{ConnectionExtensions, SqliteConnection};
9
10use crate::database::query::{
11    GetSignerRecordQuery, ImportSignerRecordQuery, RegisterSignerRecordQuery,
12};
13use crate::database::record::SignerRecord;
14use crate::SignerRecorder;
15
16/// Service to get [SignerRecord].
17#[cfg_attr(test, mockall::automock)]
18#[async_trait]
19pub trait SignerGetter: Sync + Send {
20    /// Return all stored records.
21    async fn get_all(&self) -> StdResult<Vec<SignerRecord>>;
22}
23
24/// Service to deal with signer (read & write).
25pub struct SignerStore {
26    connection: Arc<SqliteConnection>,
27}
28
29impl SignerStore {
30    /// Create a new SignerStore service
31    pub fn new(connection: Arc<SqliteConnection>) -> Self {
32        Self { connection }
33    }
34
35    /// Import a signer in the database, its last_registered_at date will be left empty
36    pub async fn import_signer(
37        &self,
38        signer_id: String,
39        pool_ticker: Option<String>,
40    ) -> StdResult<()> {
41        let created_at = Utc::now();
42        let updated_at = created_at;
43        let signer_record = SignerRecord {
44            signer_id,
45            pool_ticker,
46            created_at,
47            updated_at,
48            last_registered_at: None,
49        };
50        self.connection
51            .fetch_first(ImportSignerRecordQuery::one(signer_record))?;
52
53        Ok(())
54    }
55
56    /// Create many signers at once in the database, their last_registered_at date will be left empty
57    pub async fn import_many_signers(
58        &self,
59        pool_ticker_by_id: HashMap<String, Option<String>>,
60    ) -> StdResult<()> {
61        let created_at = Utc::now();
62        let updated_at = created_at;
63        let signer_records: Vec<_> = pool_ticker_by_id
64            .into_iter()
65            .map(|(signer_id, pool_ticker)| SignerRecord {
66                signer_id,
67                pool_ticker,
68                created_at,
69                updated_at,
70                last_registered_at: None,
71            })
72            .collect();
73        self.connection
74            .fetch_first(ImportSignerRecordQuery::many(signer_records))?;
75
76        Ok(())
77    }
78}
79
80#[async_trait]
81impl SignerRecorder for SignerStore {
82    async fn record_signer_registration(&self, signer_id: String) -> StdResult<()> {
83        let created_at = Utc::now();
84        let updated_at = created_at;
85        let registered_at = Some(created_at);
86        let signer_record = SignerRecord {
87            signer_id,
88            pool_ticker: None,
89            created_at,
90            updated_at,
91            last_registered_at: registered_at,
92        };
93        self.connection
94            .fetch_first(RegisterSignerRecordQuery::one(signer_record))?;
95
96        Ok(())
97    }
98}
99
100#[async_trait]
101impl SignerGetter for SignerStore {
102    async fn get_all(&self) -> StdResult<Vec<SignerRecord>> {
103        self.connection.fetch_collect(GetSignerRecordQuery::all())
104    }
105}
106
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::database::test_helper::{insert_signers, main_db_connection};

    use super::*;

    #[tokio::test]
    async fn test_get_all_signers() {
        let signer_records = SignerRecord::fake_records(5);
        // `get_all` is expected to yield the records in reverse insertion order.
        let expected: Vec<_> = signer_records.iter().rev().cloned().collect();
        let connection = main_db_connection().unwrap();
        insert_signers(&connection, signer_records).unwrap();

        let store = SignerStore::new(Arc::new(connection));

        let stored_signers = store
            .get_all()
            .await
            .expect("getting all signers should not fail");

        assert_eq!(expected, stored_signers);
    }

    #[tokio::test]
    async fn test_signer_recorder() {
        let signer_records_fake = SignerRecord::fake_records(5);

        let connection = Arc::new(main_db_connection().unwrap());
        let store_recorder = SignerStore::new(connection.clone());

        // The fake records are not reused after the loop, so they can be consumed directly.
        for signer_record in signer_records_fake {
            store_recorder
                .record_signer_registration(signer_record.signer_id.clone())
                .await
                .expect("record_signer_registration should not fail");
            let signer_record_stored = connection
                .fetch_first(GetSignerRecordQuery::by_signer_id(signer_record.signer_id))
                .unwrap()
                .expect("a registered signer should be stored");
            assert!(
                signer_record_stored.last_registered_at.is_some(),
                "registering a signer should set the registration date"
            );
        }
    }

    #[tokio::test]
    async fn test_store_import_signer() {
        let signer_records_fake = SignerRecord::fake_records(5);

        let connection = Arc::new(main_db_connection().unwrap());
        let store = SignerStore::new(connection.clone());

        for signer_record in signer_records_fake {
            store
                .import_signer(
                    signer_record.signer_id.clone(),
                    signer_record.pool_ticker.clone(),
                )
                .await
                .expect("import_signer should not fail");
            let signer_record_stored = connection
                .fetch_first(GetSignerRecordQuery::by_signer_id(signer_record.signer_id))
                .unwrap()
                .expect("an imported signer should be stored");
            assert!(
                signer_record_stored.last_registered_at.is_none(),
                "imported signer should not have a registration date"
            );
        }
    }

    #[tokio::test]
    async fn test_store_import_many_signers() {
        // A BTreeMap gives an order-independent comparison with what is read back.
        let signers_fake: BTreeMap<_, _> = SignerRecord::fake_records(5)
            .into_iter()
            .map(|r| (r.signer_id, r.pool_ticker))
            .collect();

        let connection = main_db_connection().unwrap();
        let store = SignerStore::new(Arc::new(connection));

        store
            .import_many_signers(signers_fake.clone().into_iter().collect())
            .await
            .expect("import_many_signers should not fail");

        let signer_records_stored = store.get_all().await.unwrap();
        // Only the two compared fields are cloned, not whole records.
        let signers_stored: BTreeMap<_, _> = signer_records_stored
            .iter()
            .map(|r| (r.signer_id.clone(), r.pool_ticker.clone()))
            .collect();
        assert_eq!(signers_fake, signers_stored);
        assert!(
            signer_records_stored
                .iter()
                .all(|s| s.last_registered_at.is_none()),
            "imported signer should not have a registration date"
        );
    }
}