mithril_aggregator/database/repository/
signer_store.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use chrono::Utc;
6
7use mithril_common::StdResult;
8use mithril_persistence::sqlite::{ConnectionExtensions, SqliteConnection};
9
10use crate::SignerRecorder;
11use crate::database::query::{
12 GetSignerRecordQuery, ImportSignerRecordQuery, RegisterSignerRecordQuery,
13};
14use crate::database::record::SignerRecord;
15
16#[cfg_attr(test, mockall::automock)]
18#[async_trait]
19pub trait SignerGetter: Sync + Send {
20 async fn get_all(&self) -> StdResult<Vec<SignerRecord>>;
22}
23
24pub struct SignerStore {
26 connection: Arc<SqliteConnection>,
27}
28
29impl SignerStore {
30 pub fn new(connection: Arc<SqliteConnection>) -> Self {
32 Self { connection }
33 }
34
35 pub async fn import_signer(
37 &self,
38 signer_id: String,
39 pool_ticker: Option<String>,
40 ) -> StdResult<()> {
41 let created_at = Utc::now();
42 let updated_at = created_at;
43 let signer_record = SignerRecord {
44 signer_id,
45 pool_ticker,
46 created_at,
47 updated_at,
48 last_registered_at: None,
49 };
50 self.connection
51 .fetch_first(ImportSignerRecordQuery::one(signer_record))?;
52
53 Ok(())
54 }
55
56 pub async fn import_many_signers(
58 &self,
59 pool_ticker_by_id: HashMap<String, Option<String>>,
60 ) -> StdResult<()> {
61 let created_at = Utc::now();
62 let updated_at = created_at;
63 let signer_records: Vec<_> = pool_ticker_by_id
64 .into_iter()
65 .map(|(signer_id, pool_ticker)| SignerRecord {
66 signer_id,
67 pool_ticker,
68 created_at,
69 updated_at,
70 last_registered_at: None,
71 })
72 .collect();
73 self.connection
74 .fetch_first(ImportSignerRecordQuery::many(signer_records))?;
75
76 Ok(())
77 }
78}
79
80#[async_trait]
81impl SignerRecorder for SignerStore {
82 async fn record_signer_registration(&self, signer_id: String) -> StdResult<()> {
83 let created_at = Utc::now();
84 let updated_at = created_at;
85 let registered_at = Some(created_at);
86 let signer_record = SignerRecord {
87 signer_id,
88 pool_ticker: None,
89 created_at,
90 updated_at,
91 last_registered_at: registered_at,
92 };
93 self.connection
94 .fetch_first(RegisterSignerRecordQuery::one(signer_record))?;
95
96 Ok(())
97 }
98}
99
100#[async_trait]
101impl SignerGetter for SignerStore {
102 async fn get_all(&self) -> StdResult<Vec<SignerRecord>> {
103 self.connection.fetch_collect(GetSignerRecordQuery::all())
104 }
105}
106
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::database::test_helper::{insert_signers, main_db_connection};

    use super::*;

    #[tokio::test]
    async fn test_get_all_signers() {
        let signer_records = SignerRecord::fake_records(5);
        // `get_all` is expected to yield the records in reverse insertion order.
        let expected: Vec<_> = signer_records.iter().rev().cloned().collect();
        let connection = main_db_connection().unwrap();
        insert_signers(&connection, signer_records).unwrap();

        let store = SignerStore::new(Arc::new(connection));

        let stored_signers = store.get_all().await.expect("getting all signers should not fail");

        assert_eq!(expected, stored_signers);
    }

    #[tokio::test]
    async fn test_signer_recorder() {
        let signer_records_fake = SignerRecord::fake_records(5);

        let connection = Arc::new(main_db_connection().unwrap());
        let store_recorder = SignerStore::new(Arc::clone(&connection));

        for signer_record in signer_records_fake {
            store_recorder
                .record_signer_registration(signer_record.signer_id.clone())
                .await
                .expect("record_signer_registration should not fail");
            let signer_record_stored = connection
                .fetch_first(GetSignerRecordQuery::by_signer_id(signer_record.signer_id))
                .unwrap()
                .expect("a registered signer should be stored");
            assert!(
                signer_record_stored.last_registered_at.is_some(),
                "registering a signer should set the registration date"
            );
        }
    }

    #[tokio::test]
    async fn test_store_import_signer() {
        let signer_records_fake = SignerRecord::fake_records(5);

        let connection = Arc::new(main_db_connection().unwrap());
        let store = SignerStore::new(Arc::clone(&connection));

        for signer_record in signer_records_fake {
            store
                .import_signer(
                    signer_record.signer_id.clone(),
                    signer_record.pool_ticker.clone(),
                )
                .await
                .expect("import_signer should not fail");
            let signer_record_stored = connection
                .fetch_first(GetSignerRecordQuery::by_signer_id(signer_record.signer_id))
                .unwrap()
                .expect("an imported signer should be stored");
            assert!(
                signer_record_stored.last_registered_at.is_none(),
                "imported signer should not have a registration date"
            );
        }
    }

    #[tokio::test]
    async fn test_store_import_many_signers() {
        let signers_fake: BTreeMap<_, _> = SignerRecord::fake_records(5)
            .into_iter()
            .map(|r| (r.signer_id, r.pool_ticker))
            .collect();

        let connection = main_db_connection().unwrap();
        let store = SignerStore::new(Arc::new(connection));

        store
            .import_many_signers(signers_fake.clone().into_iter().collect())
            .await
            .expect("import_many_signers should not fail");

        let signer_records_stored = store.get_all().await.unwrap();
        // Compare through a BTreeMap so the check ignores storage order.
        let signers_stored: BTreeMap<_, _> = signer_records_stored
            .iter()
            .map(|r| (r.signer_id.clone(), r.pool_ticker.clone()))
            .collect();
        assert_eq!(signers_fake, signers_stored);
        assert!(
            signer_records_stored.iter().all(|s| s.last_registered_at.is_none()),
            "imported signer should not have a registration date"
        );
    }
}