mithril_aggregator/database/query/certificate/insert_certificate.rs

1use std::iter::repeat_n;
2
3use sqlite::Value;
4
5use mithril_persistence::sqlite::{Query, SourceAlias, SqLiteEntity, WhereCondition};
6
7use crate::database::record::CertificateRecord;
8
9/// Query to insert [CertificateRecord] in the sqlite database
10pub struct InsertCertificateRecordQuery {
11    condition: WhereCondition,
12}
13
14impl InsertCertificateRecordQuery {
15    pub fn one(certificate_record: CertificateRecord) -> Self {
16        Self::many(vec![certificate_record])
17    }
18
19    pub fn many(certificates_records: Vec<CertificateRecord>) -> Self {
20        let columns = "(\
21        certificate_id, \
22        parent_certificate_id, \
23        message, \
24        signature, \
25        aggregate_verification_key, \
26        epoch, \
27        network, \
28        signed_entity_type_id, \
29        signed_entity_beacon, \
30        protocol_version, \
31        protocol_parameters, \
32        protocol_message, \
33        signers, \
34        initiated_at, \
35        sealed_at)";
36        let values_columns: Vec<&str> = repeat_n(
37            "(?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*)",
38            certificates_records.len(),
39        )
40        .collect();
41
42        let values: Vec<Value> = certificates_records
43            .into_iter()
44            .flat_map(|certificate_record| {
45                vec![
46                    Value::String(certificate_record.certificate_id),
47                    match certificate_record.parent_certificate_id {
48                        Some(parent_certificate_id) => Value::String(parent_certificate_id),
49                        None => Value::Null,
50                    },
51                    Value::String(certificate_record.message),
52                    Value::String(certificate_record.signature),
53                    Value::String(certificate_record.aggregate_verification_key),
54                    Value::Integer(certificate_record.epoch.try_into().unwrap()),
55                    Value::String(certificate_record.network),
56                    Value::Integer(certificate_record.signed_entity_type.index() as i64),
57                    Value::String(certificate_record.signed_entity_type.get_json_beacon().unwrap()),
58                    Value::String(certificate_record.protocol_version),
59                    Value::String(
60                        serde_json::to_string(&certificate_record.protocol_parameters).unwrap(),
61                    ),
62                    Value::String(
63                        serde_json::to_string(&certificate_record.protocol_message).unwrap(),
64                    ),
65                    Value::String(serde_json::to_string(&certificate_record.signers).unwrap()),
66                    Value::String(certificate_record.initiated_at.to_rfc3339()),
67                    Value::String(certificate_record.sealed_at.to_rfc3339()),
68                ]
69            })
70            .collect();
71
72        let condition = WhereCondition::new(
73            format!("{columns} values {}", values_columns.join(", ")).as_str(),
74            values,
75        );
76
77        Self { condition }
78    }
79}
80
81impl Query for InsertCertificateRecordQuery {
82    type Entity = CertificateRecord;
83
84    fn filters(&self) -> WhereCondition {
85        self.condition.clone()
86    }
87
88    fn get_definition(&self, condition: &str) -> String {
89        // it is important to alias the fields with the same name as the table
90        // since the table cannot be aliased in a RETURNING statement in SQLite.
91        let projection = Self::Entity::get_projection()
92            .expand(SourceAlias::new(&[("{:certificate:}", "certificate")]));
93
94        format!("insert into certificate {condition} returning {projection}")
95    }
96}
97
#[cfg(test)]
mod tests {
    use mithril_common::crypto_helper::tests_setup::setup_certificate_chain;
    use mithril_persistence::sqlite::ConnectionExtensions;

    use crate::database::test_helper::main_db_connection;

    use super::*;

    /// Each record inserted on its own must come back unchanged from the RETURNING clause.
    #[test]
    fn test_insert_certificate_record() {
        let chain = setup_certificate_chain(5, 2);
        let connection = main_db_connection().unwrap();

        let records = chain
            .certificates_chained
            .into_iter()
            .map(Into::<CertificateRecord>::into);

        for record in records {
            let query = InsertCertificateRecordQuery::one(record.clone());
            let returned = connection.fetch_first(query).unwrap();

            assert_eq!(Some(record), returned);
        }
    }

    /// A batch insert must return every record, unchanged and in insertion order.
    #[test]
    fn test_insert_many_certificates_records() {
        let chain = setup_certificate_chain(5, 2);
        let expected: Vec<CertificateRecord> = chain.into();

        let connection = main_db_connection().unwrap();

        let query = InsertCertificateRecordQuery::many(expected.clone());
        let returned: Vec<CertificateRecord> = connection
            .fetch_collect(query)
            .expect("saving many records should not fail");

        assert_eq!(expected, returned);
    }
}