mithril_aggregator/database/repository/
signer_store.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use chrono::Utc;
6
7use mithril_common::StdResult;
8use mithril_persistence::sqlite::{ConnectionExtensions, SqliteConnection};
9
10use crate::database::query::{
11 GetSignerRecordQuery, ImportSignerRecordQuery, RegisterSignerRecordQuery,
12};
13use crate::database::record::SignerRecord;
14use crate::SignerRecorder;
15
16#[cfg_attr(test, mockall::automock)]
18#[async_trait]
19pub trait SignerGetter: Sync + Send {
20 async fn get_all(&self) -> StdResult<Vec<SignerRecord>>;
22}
23
24pub struct SignerStore {
26 connection: Arc<SqliteConnection>,
27}
28
29impl SignerStore {
30 pub fn new(connection: Arc<SqliteConnection>) -> Self {
32 Self { connection }
33 }
34
35 pub async fn import_signer(
37 &self,
38 signer_id: String,
39 pool_ticker: Option<String>,
40 ) -> StdResult<()> {
41 let created_at = Utc::now();
42 let updated_at = created_at;
43 let signer_record = SignerRecord {
44 signer_id,
45 pool_ticker,
46 created_at,
47 updated_at,
48 last_registered_at: None,
49 };
50 self.connection
51 .fetch_first(ImportSignerRecordQuery::one(signer_record))?;
52
53 Ok(())
54 }
55
56 pub async fn import_many_signers(
58 &self,
59 pool_ticker_by_id: HashMap<String, Option<String>>,
60 ) -> StdResult<()> {
61 let created_at = Utc::now();
62 let updated_at = created_at;
63 let signer_records: Vec<_> = pool_ticker_by_id
64 .into_iter()
65 .map(|(signer_id, pool_ticker)| SignerRecord {
66 signer_id,
67 pool_ticker,
68 created_at,
69 updated_at,
70 last_registered_at: None,
71 })
72 .collect();
73 self.connection
74 .fetch_first(ImportSignerRecordQuery::many(signer_records))?;
75
76 Ok(())
77 }
78}
79
80#[async_trait]
81impl SignerRecorder for SignerStore {
82 async fn record_signer_registration(&self, signer_id: String) -> StdResult<()> {
83 let created_at = Utc::now();
84 let updated_at = created_at;
85 let registered_at = Some(created_at);
86 let signer_record = SignerRecord {
87 signer_id,
88 pool_ticker: None,
89 created_at,
90 updated_at,
91 last_registered_at: registered_at,
92 };
93 self.connection
94 .fetch_first(RegisterSignerRecordQuery::one(signer_record))?;
95
96 Ok(())
97 }
98}
99
100#[async_trait]
101impl SignerGetter for SignerStore {
102 async fn get_all(&self) -> StdResult<Vec<SignerRecord>> {
103 self.connection.fetch_collect(GetSignerRecordQuery::all())
104 }
105}
106
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::database::test_helper::{insert_signers, main_db_connection};

    use super::*;

    #[tokio::test]
    async fn test_get_all_signers() {
        let signer_records = SignerRecord::fake_records(5);
        // `get_all` is expected to list signers in reverse insertion order.
        let expected: Vec<_> = signer_records.iter().rev().cloned().collect();
        let connection = main_db_connection().unwrap();
        insert_signers(&connection, signer_records).unwrap();

        let store = SignerStore::new(Arc::new(connection));

        let stored_signers = store
            .get_all()
            .await
            .expect("getting all signers should not fail");

        assert_eq!(expected, stored_signers);
    }

    #[tokio::test]
    async fn test_signer_recorder() {
        let signer_records_fake = SignerRecord::fake_records(5);

        let connection = Arc::new(main_db_connection().unwrap());
        let store_recorder = SignerStore::new(Arc::clone(&connection));

        for signer_record in signer_records_fake {
            store_recorder
                .record_signer_registration(signer_record.signer_id.clone())
                .await
                .expect("record_signer_registration should not fail");
            let signer_record_stored = connection
                .fetch_first(GetSignerRecordQuery::by_signer_id(signer_record.signer_id))
                .unwrap()
                .expect("a registered signer should be stored");
            assert!(
                signer_record_stored.last_registered_at.is_some(),
                "registering a signer should set the registration date"
            );
        }
    }

    #[tokio::test]
    async fn test_store_import_signer() {
        let signer_records_fake = SignerRecord::fake_records(5);

        let connection = Arc::new(main_db_connection().unwrap());
        let store = SignerStore::new(Arc::clone(&connection));

        for signer_record in signer_records_fake {
            store
                .import_signer(
                    signer_record.signer_id.clone(),
                    signer_record.pool_ticker.clone(),
                )
                .await
                .expect("import_signer should not fail");
            let signer_record_stored = connection
                .fetch_first(GetSignerRecordQuery::by_signer_id(signer_record.signer_id))
                .unwrap()
                .expect("an imported signer should be stored");
            assert_eq!(
                signer_record.pool_ticker, signer_record_stored.pool_ticker,
                "imported signer should keep its pool ticker"
            );
            assert!(
                signer_record_stored.last_registered_at.is_none(),
                "imported signer should not have a registration date"
            );
        }
    }

    #[tokio::test]
    async fn test_store_import_many_signers() {
        let signers_fake: BTreeMap<_, _> = SignerRecord::fake_records(5)
            .into_iter()
            .map(|r| (r.signer_id, r.pool_ticker))
            .collect();

        let connection = main_db_connection().unwrap();
        let store = SignerStore::new(Arc::new(connection));

        store
            .import_many_signers(signers_fake.clone().into_iter().collect())
            .await
            .expect("import_many_signers should not fail");

        let signer_records_stored = store.get_all().await.unwrap();
        // Compare as ordered maps so the check is independent of storage order.
        let signers_stored: BTreeMap<_, _> = signer_records_stored
            .iter()
            .map(|r| (r.signer_id.clone(), r.pool_ticker.clone()))
            .collect();
        assert_eq!(signers_fake, signers_stored);
        assert!(
            signer_records_stored
                .iter()
                .all(|s| s.last_registered_at.is_none()),
            "imported signer should not have a registration date"
        );
    }
}