mithril_aggregator/database/repository/signer_store.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use chrono::Utc;
6
7use mithril_common::StdResult;
8use mithril_persistence::sqlite::{ConnectionExtensions, SqliteConnection};
9
10use crate::SignerRecorder;
11use crate::database::query::{
12    GetSignerRecordQuery, ImportSignerRecordQuery, RegisterSignerRecordQuery,
13};
14use crate::database::record::SignerRecord;
15
16/// Service to get [SignerRecord].
17#[cfg_attr(test, mockall::automock)]
18#[async_trait]
19pub trait SignerGetter: Sync + Send {
20    /// Return all stored records.
21    async fn get_all(&self) -> StdResult<Vec<SignerRecord>>;
22}
23
24/// Service to deal with signer (read & write).
25pub struct SignerStore {
26    connection: Arc<SqliteConnection>,
27}
28
29impl SignerStore {
30    /// Create a new SignerStore service
31    pub fn new(connection: Arc<SqliteConnection>) -> Self {
32        Self { connection }
33    }
34
35    /// Import a signer in the database, its last_registered_at date will be left empty
36    pub async fn import_signer(
37        &self,
38        signer_id: String,
39        pool_ticker: Option<String>,
40    ) -> StdResult<()> {
41        let created_at = Utc::now();
42        let updated_at = created_at;
43        let signer_record = SignerRecord {
44            signer_id,
45            pool_ticker,
46            created_at,
47            updated_at,
48            last_registered_at: None,
49        };
50        self.connection
51            .fetch_first(ImportSignerRecordQuery::one(signer_record))?;
52
53        Ok(())
54    }
55
56    /// Create many signers at once in the database, their last_registered_at date will be left empty
57    pub async fn import_many_signers(
58        &self,
59        pool_ticker_by_id: HashMap<String, Option<String>>,
60    ) -> StdResult<()> {
61        let created_at = Utc::now();
62        let updated_at = created_at;
63        let signer_records: Vec<_> = pool_ticker_by_id
64            .into_iter()
65            .map(|(signer_id, pool_ticker)| SignerRecord {
66                signer_id,
67                pool_ticker,
68                created_at,
69                updated_at,
70                last_registered_at: None,
71            })
72            .collect();
73        self.connection
74            .fetch_first(ImportSignerRecordQuery::many(signer_records))?;
75
76        Ok(())
77    }
78}
79
80#[async_trait]
81impl SignerRecorder for SignerStore {
82    async fn record_signer_registration(&self, signer_id: String) -> StdResult<()> {
83        let created_at = Utc::now();
84        let updated_at = created_at;
85        let registered_at = Some(created_at);
86        let signer_record = SignerRecord {
87            signer_id,
88            pool_ticker: None,
89            created_at,
90            updated_at,
91            last_registered_at: registered_at,
92        };
93        self.connection
94            .fetch_first(RegisterSignerRecordQuery::one(signer_record))?;
95
96        Ok(())
97    }
98}
99
100#[async_trait]
101impl SignerGetter for SignerStore {
102    async fn get_all(&self) -> StdResult<Vec<SignerRecord>> {
103        self.connection.fetch_collect(GetSignerRecordQuery::all())
104    }
105}
106
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::database::test_helper::{insert_signers, main_db_connection};

    use super::*;

    #[tokio::test]
    async fn test_get_all_signers() {
        let signer_records = SignerRecord::fake_records(5);
        // `get_all` is expected to list the signers in the reverse of their insertion order.
        let expected: Vec<_> = signer_records.iter().rev().cloned().collect();
        let connection = main_db_connection().unwrap();
        insert_signers(&connection, signer_records).unwrap();

        let store = SignerStore::new(Arc::new(connection));

        let stored_signers = store.get_all().await.expect("getting all signers should not fail");

        assert_eq!(expected, stored_signers);
    }

    #[tokio::test]
    async fn test_signer_recorder() {
        let signer_records_fake = SignerRecord::fake_records(5);

        let connection = Arc::new(main_db_connection().unwrap());
        let store_recorder = SignerStore::new(connection.clone());

        for signer_record in signer_records_fake {
            store_recorder
                .record_signer_registration(signer_record.signer_id.clone())
                .await
                .expect("record_signer_registration should not fail");
            let signer_record_stored = connection
                .fetch_first(GetSignerRecordQuery::by_signer_id(signer_record.signer_id))
                .unwrap()
                .expect("a registered signer should be stored");
            assert!(
                signer_record_stored.last_registered_at.is_some(),
                "registering a signer should set the registration date"
            );
        }
    }

    #[tokio::test]
    async fn test_store_import_signer() {
        let signer_records_fake = SignerRecord::fake_records(5);

        let connection = Arc::new(main_db_connection().unwrap());
        let store = SignerStore::new(connection.clone());

        for signer_record in signer_records_fake {
            store
                .import_signer(
                    signer_record.signer_id.clone(),
                    signer_record.pool_ticker.clone(),
                )
                .await
                .expect("import_signer should not fail");
            let signer_record_stored = connection
                .fetch_first(GetSignerRecordQuery::by_signer_id(signer_record.signer_id))
                .unwrap()
                .expect("an imported signer should be stored");
            assert!(
                signer_record_stored.last_registered_at.is_none(),
                "imported signer should not have a registration date"
            );
        }
    }

    #[tokio::test]
    async fn test_store_import_many_signers() {
        let signers_fake: BTreeMap<_, _> = SignerRecord::fake_records(5)
            .into_iter()
            .map(|r| (r.signer_id, r.pool_ticker))
            .collect();

        let connection = main_db_connection().unwrap();
        let store = SignerStore::new(Arc::new(connection));

        store
            .import_many_signers(signers_fake.clone().into_iter().collect())
            .await
            .expect("import_many_signers should not fail");

        let signer_records_stored = store.get_all().await.unwrap();
        // Only the two compared fields are cloned, not the whole records.
        let signers_stored: BTreeMap<_, _> = signer_records_stored
            .iter()
            .map(|r| (r.signer_id.clone(), r.pool_ticker.clone()))
            .collect();
        assert_eq!(signers_fake, signers_stored);
        assert!(
            signer_records_stored.iter().all(|s| s.last_registered_at.is_none()),
            "imported signer should not have a registration date"
        );
    }
}