mithril_aggregator/database/query/signer/
import_signer.rs

1use std::iter::repeat_n;
2
3use sqlite::Value;
4
5use mithril_persistence::sqlite::{Query, SourceAlias, SqLiteEntity, WhereCondition};
6
7use crate::database::record::SignerRecord;
8
9/// Query used by the [signer importer][crate::tools::SignersImporter] to register a [SignerRecord]
10/// in the sqlite database.
11///
12/// If it already exists it's `pool_ticker` and `updated_at` fields will be updated.
13pub struct ImportSignerRecordQuery {
14    condition: WhereCondition,
15}
16
17impl ImportSignerRecordQuery {
18    pub fn one(signer_record: SignerRecord) -> Self {
19        Self::many(vec![signer_record])
20    }
21
22    pub fn many(signer_records: Vec<SignerRecord>) -> Self {
23        let columns = "(signer_id, pool_ticker, created_at, updated_at, last_registered_at)";
24        let values_columns: Vec<&str> =
25            repeat_n("(?*, ?*, ?*, ?*, ?*)", signer_records.len()).collect();
26        let values = signer_records
27            .into_iter()
28            .flat_map(|signer_record| {
29                vec![
30                    Value::String(signer_record.signer_id),
31                    signer_record
32                        .pool_ticker
33                        .map(Value::String)
34                        .unwrap_or(Value::Null),
35                    Value::String(signer_record.created_at.to_rfc3339()),
36                    Value::String(signer_record.updated_at.to_rfc3339()),
37                    signer_record
38                        .last_registered_at
39                        .map(|d| Value::String(d.to_rfc3339()))
40                        .unwrap_or(Value::Null),
41                ]
42            })
43            .collect();
44
45        let condition = WhereCondition::new(
46            format!("{columns} values {}", values_columns.join(", ")).as_str(),
47            values,
48        );
49
50        Self { condition }
51    }
52}
53
54impl Query for ImportSignerRecordQuery {
55    type Entity = SignerRecord;
56
57    fn filters(&self) -> WhereCondition {
58        self.condition.clone()
59    }
60
61    fn get_definition(&self, condition: &str) -> String {
62        // it is important to alias the fields with the same name as the table
63        // since the table cannot be aliased in a RETURNING statement in SQLite.
64        let projection =
65            Self::Entity::get_projection().expand(SourceAlias::new(&[("{:signer:}", "signer")]));
66
67        format!(
68            "insert into signer {condition} on conflict(signer_id) do update \
69            set pool_ticker = excluded.pool_ticker, updated_at = excluded.updated_at returning {projection}"
70        )
71    }
72}
73
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mithril_persistence::sqlite::ConnectionExtensions;

    use crate::database::test_helper::{insert_signers, main_db_connection};

    use super::*;

    #[test]
    fn test_update_signer_record() {
        let signer_records_fake = SignerRecord::fake_records(5);

        let connection = main_db_connection().unwrap();
        insert_signers(&connection, signer_records_fake.clone()).unwrap();

        for signer_record in signer_records_fake.clone() {
            let signer_record_saved = connection
                .fetch_first(ImportSignerRecordQuery::one(signer_record.clone()))
                .unwrap();
            assert_eq!(Some(signer_record), signer_record_saved);
        }

        for mut signer_record in signer_records_fake {
            signer_record.pool_ticker = Some(format!("new-pool-{}", signer_record.signer_id));
            signer_record.updated_at += Duration::try_hours(1).unwrap();
            let signer_record_saved = connection
                .fetch_first(ImportSignerRecordQuery::one(signer_record.clone()))
                .unwrap();
            assert_eq!(Some(signer_record), signer_record_saved);
        }
    }

    #[test]
    fn test_update_many_signer_records() {
        let mut signer_records_fake = SignerRecord::fake_records(5);
        signer_records_fake.sort_by(|a, b| a.signer_id.cmp(&b.signer_id));

        let connection = main_db_connection().unwrap();
        insert_signers(&connection, signer_records_fake.clone()).unwrap();

        let mut saved_records: Vec<SignerRecord> = connection
            .fetch_collect(ImportSignerRecordQuery::many(signer_records_fake.clone()))
            .unwrap();
        saved_records.sort_by(|a, b| a.signer_id.cmp(&b.signer_id));
        assert_eq!(signer_records_fake, saved_records);

        for signer_record in signer_records_fake.iter_mut() {
            signer_record.pool_ticker = Some(format!("new-pool-{}", signer_record.signer_id));
            signer_record.updated_at += Duration::try_hours(1).unwrap();
        }
        let mut saved_records: Vec<SignerRecord> = connection
            .fetch_collect(ImportSignerRecordQuery::many(signer_records_fake.clone()))
            .unwrap();
        saved_records.sort_by(|a, b| a.signer_id.cmp(&b.signer_id));
        assert_eq!(signer_records_fake, saved_records);
    }

    /// On conflict, only `pool_ticker` and `updated_at` must be overwritten: the stored
    /// `created_at` and `last_registered_at` must be kept even if the imported record differs.
    #[test]
    fn test_import_existing_signer_keeps_created_at_and_last_registered_at() {
        let signer_records_fake = SignerRecord::fake_records(5);

        let connection = main_db_connection().unwrap();
        insert_signers(&connection, signer_records_fake.clone()).unwrap();

        for signer_record in signer_records_fake {
            let mut signer_record_to_import = signer_record.clone();
            signer_record_to_import.pool_ticker =
                Some(format!("new-pool-{}", signer_record.signer_id));
            signer_record_to_import.updated_at += Duration::try_hours(1).unwrap();
            // Neither field below is part of the `on conflict ... do update set` clause.
            signer_record_to_import.created_at += Duration::try_hours(2).unwrap();
            signer_record_to_import.last_registered_at =
                Some(signer_record.updated_at + Duration::try_hours(3).unwrap());

            let mut expected = signer_record.clone();
            expected.pool_ticker = signer_record_to_import.pool_ticker.clone();
            expected.updated_at = signer_record_to_import.updated_at;

            let signer_record_saved = connection
                .fetch_first(ImportSignerRecordQuery::one(signer_record_to_import))
                .unwrap();
            assert_eq!(Some(expected), signer_record_saved);
        }
    }
}