mithril_aggregator/database/query/signer/import_signer.rs

1use std::iter::repeat_n;
2
3use sqlite::Value;
4
5use mithril_persistence::sqlite::{Query, SourceAlias, SqLiteEntity, WhereCondition};
6
7use crate::database::record::SignerRecord;
8
9/// Query used by the [signer importer][crate::tools::SignersImporter] to register a [SignerRecord]
10/// in the sqlite database.
11///
12/// If it already exists it's `pool_ticker` and `updated_at` fields will be updated.
13pub struct ImportSignerRecordQuery {
14    condition: WhereCondition,
15}
16
17impl ImportSignerRecordQuery {
18    pub fn one(signer_record: SignerRecord) -> Self {
19        Self::many(vec![signer_record])
20    }
21
22    pub fn many(signer_records: Vec<SignerRecord>) -> Self {
23        let columns = "(signer_id, pool_ticker, created_at, updated_at, last_registered_at)";
24        let values_columns: Vec<&str> =
25            repeat_n("(?*, ?*, ?*, ?*, ?*)", signer_records.len()).collect();
26        let values = signer_records
27            .into_iter()
28            .flat_map(|signer_record| {
29                vec![
30                    Value::String(signer_record.signer_id),
31                    signer_record.pool_ticker.map(Value::String).unwrap_or(Value::Null),
32                    Value::String(signer_record.created_at.to_rfc3339()),
33                    Value::String(signer_record.updated_at.to_rfc3339()),
34                    signer_record
35                        .last_registered_at
36                        .map(|d| Value::String(d.to_rfc3339()))
37                        .unwrap_or(Value::Null),
38                ]
39            })
40            .collect();
41
42        let condition = WhereCondition::new(
43            format!("{columns} values {}", values_columns.join(", ")).as_str(),
44            values,
45        );
46
47        Self { condition }
48    }
49}
50
51impl Query for ImportSignerRecordQuery {
52    type Entity = SignerRecord;
53
54    fn filters(&self) -> WhereCondition {
55        self.condition.clone()
56    }
57
58    fn get_definition(&self, condition: &str) -> String {
59        // it is important to alias the fields with the same name as the table
60        // since the table cannot be aliased in a RETURNING statement in SQLite.
61        let projection =
62            Self::Entity::get_projection().expand(SourceAlias::new(&[("{:signer:}", "signer")]));
63
64        format!(
65            "insert into signer {condition} on conflict(signer_id) do update \
66            set pool_ticker = excluded.pool_ticker, updated_at = excluded.updated_at returning {projection}"
67        )
68    }
69}
70
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mithril_persistence::sqlite::ConnectionExtensions;

    use crate::database::test_helper::{insert_signers, main_db_connection};

    use super::*;

    /// Importing existing signers one by one returns them unchanged, then with updated
    /// `pool_ticker` and `updated_at` once those fields change.
    #[test]
    fn test_update_signer_record() {
        let signer_records_fake = SignerRecord::fake_records(5);

        let connection = main_db_connection().unwrap();
        insert_signers(&connection, signer_records_fake.clone()).unwrap();

        for signer_record in signer_records_fake.clone() {
            let signer_record_saved = connection
                .fetch_first(ImportSignerRecordQuery::one(signer_record.clone()))
                .unwrap();
            assert_eq!(Some(signer_record), signer_record_saved);
        }

        for mut signer_record in signer_records_fake {
            signer_record.pool_ticker = Some(format!("new-pool-{}", signer_record.signer_id));
            signer_record.updated_at += Duration::try_hours(1).unwrap();
            let signer_record_saved = connection
                .fetch_first(ImportSignerRecordQuery::one(signer_record.clone()))
                .unwrap();
            assert_eq!(Some(signer_record), signer_record_saved);
        }
    }

    /// On conflict, only `pool_ticker` and `updated_at` are overwritten: `created_at` and
    /// `last_registered_at` must keep their previously stored values.
    #[test]
    fn test_update_signer_record_keeps_created_at_and_last_registered_at() {
        let signer_records_fake = SignerRecord::fake_records(3);

        let connection = main_db_connection().unwrap();
        insert_signers(&connection, signer_records_fake.clone()).unwrap();

        for original in signer_records_fake {
            let mut updated = original.clone();
            updated.pool_ticker = Some(format!("new-pool-{}", updated.signer_id));
            updated.updated_at += Duration::try_hours(1).unwrap();
            updated.created_at += Duration::try_hours(2).unwrap();
            updated.last_registered_at = None;

            let saved = connection
                .fetch_first(ImportSignerRecordQuery::one(updated.clone()))
                .unwrap()
                .expect("the import query should return the upserted record");

            assert_eq!(updated.pool_ticker, saved.pool_ticker);
            assert_eq!(updated.updated_at, saved.updated_at);
            assert_eq!(original.created_at, saved.created_at);
            assert_eq!(original.last_registered_at, saved.last_registered_at);
        }
    }

    /// Importing several existing signers at once behaves like importing them one by one.
    #[test]
    fn test_update_many_signer_records() {
        let mut signer_records_fake = SignerRecord::fake_records(5);
        signer_records_fake.sort_by(|a, b| a.signer_id.cmp(&b.signer_id));

        let connection = main_db_connection().unwrap();
        insert_signers(&connection, signer_records_fake.clone()).unwrap();

        let mut saved_records: Vec<SignerRecord> = connection
            .fetch_collect(ImportSignerRecordQuery::many(signer_records_fake.clone()))
            .unwrap();
        saved_records.sort_by(|a, b| a.signer_id.cmp(&b.signer_id));
        assert_eq!(signer_records_fake, saved_records);

        for signer_record in signer_records_fake.iter_mut() {
            signer_record.pool_ticker = Some(format!("new-pool-{}", signer_record.signer_id));
            signer_record.updated_at += Duration::try_hours(1).unwrap();
        }
        let mut saved_records: Vec<SignerRecord> = connection
            .fetch_collect(ImportSignerRecordQuery::many(signer_records_fake.clone()))
            .unwrap();
        saved_records.sort_by(|a, b| a.signer_id.cmp(&b.signer_id));
        assert_eq!(signer_records_fake, saved_records);
    }
}
128}