mithril_aggregator/database/query/certificate/get_certificate.rs

1use sqlite::Value;
2
3#[cfg(test)]
4use mithril_common::StdResult;
5#[cfg(test)]
6use mithril_common::entities::Epoch;
7use mithril_persistence::sqlite::{Query, SourceAlias, SqLiteEntity, WhereCondition};
8
9use crate::database::record::CertificateRecord;
10
11/// Simple queries to retrieve [CertificateRecord] from the sqlite database.
12pub struct GetCertificateRecordQuery {
13    condition: WhereCondition,
14}
15
16impl GetCertificateRecordQuery {
17    pub fn all() -> Self {
18        Self {
19            condition: WhereCondition::default(),
20        }
21    }
22
23    pub fn all_genesis() -> Self {
24        Self {
25            condition: WhereCondition::new("parent_certificate_id is null", vec![]),
26        }
27    }
28
29    pub fn by_certificate_id(certificate_id: &str) -> Self {
30        Self {
31            condition: WhereCondition::new(
32                "certificate_id = ?*",
33                vec![Value::String(certificate_id.to_owned())],
34            ),
35        }
36    }
37
38    #[cfg(test)]
39    pub fn by_epoch(epoch: Epoch) -> StdResult<Self> {
40        Ok(Self {
41            condition: WhereCondition::new("epoch = ?*", vec![Value::Integer(epoch.try_into()?)]),
42        })
43    }
44}
45
46impl Query for GetCertificateRecordQuery {
47    type Entity = CertificateRecord;
48
49    fn filters(&self) -> WhereCondition {
50        self.condition.clone()
51    }
52
53    fn get_definition(&self, condition: &str) -> String {
54        let aliases = SourceAlias::new(&[("{:certificate:}", "c")]);
55        let projection = Self::Entity::get_projection().expand(aliases);
56        format!("select {projection} from certificate as c where {condition} order by ROWID desc")
57    }
58}
59
#[cfg(test)]
mod tests {
    use mithril_common::crypto_helper::ProtocolParameters;
    use mithril_common::test::builder::CertificateChainBuilder;
    use mithril_common::test::crypto_helper::setup_certificate_chain;

    use mithril_persistence::sqlite::ConnectionExtensions;

    use crate::database::test_helper::{insert_certificate_records, main_db_connection};

    use super::*;

    #[test]
    fn test_get_certificate_records_by_epoch() {
        let certificates = setup_certificate_chain(20, 7);

        let connection = main_db_connection().unwrap();
        insert_certificate_records(&connection, certificates.certificates_chained.clone());

        // Records actually returned by the `by_epoch` query.
        let fetch_records_for_epoch = |epoch: Epoch| -> Vec<CertificateRecord> {
            connection
                .fetch_collect(GetCertificateRecordQuery::by_epoch(epoch).unwrap())
                .unwrap()
        };
        // Records expected for `epoch`; `reversed_chain` is expected to match the query's
        // `order by ROWID desc` ordering.
        let expected_records_for_epoch = |epoch: Epoch| -> Vec<CertificateRecord> {
            certificates
                .reversed_chain()
                .into_iter()
                // Filter first so only matching certificates are converted: the previous
                // `bool::then_some` form evaluated the conversion for every certificate.
                .filter(|c| c.epoch == epoch)
                .map(|c| c.try_into().unwrap())
                .collect()
        };

        assert_eq!(expected_records_for_epoch(Epoch(1)), fetch_records_for_epoch(Epoch(1)));
        assert_eq!(expected_records_for_epoch(Epoch(3)), fetch_records_for_epoch(Epoch(3)));

        let cursor = connection
            .fetch(GetCertificateRecordQuery::by_epoch(Epoch(5)).unwrap())
            .unwrap();
        assert_eq!(0, cursor.count());
    }

    #[test]
    fn test_get_all_certificate_records() {
        let certificates = setup_certificate_chain(5, 2);
        let expected_certificate_records: Vec<CertificateRecord> = certificates
            .reversed_chain()
            .into_iter()
            .map(|c| c.try_into().unwrap())
            .collect();

        let connection = main_db_connection().unwrap();
        insert_certificate_records(&connection, certificates.certificates_chained.clone());

        let certificate_records: Vec<CertificateRecord> =
            connection.fetch_collect(GetCertificateRecordQuery::all()).unwrap();
        assert_eq!(expected_certificate_records, certificate_records);
    }

    #[test]
    fn test_get_all_genesis_certificate_records() {
        // Two chains with different protocol parameters so generated certificates are different.
        let first_certificates_chain = CertificateChainBuilder::new()
            .with_total_certificates(2)
            .with_protocol_parameters(ProtocolParameters {
                m: 90,
                k: 4,
                phi_f: 0.65,
            })
            .build();
        let first_chain_genesis: CertificateRecord = first_certificates_chain
            .genesis_certificate()
            .clone()
            .try_into()
            .unwrap();
        let second_certificates_chain = CertificateChainBuilder::new()
            .with_total_certificates(2)
            .with_protocol_parameters(ProtocolParameters {
                m: 100,
                k: 5,
                phi_f: 0.65,
            })
            .build();
        let second_chain_genesis: CertificateRecord = second_certificates_chain
            .genesis_certificate()
            .clone()
            .try_into()
            .unwrap();
        assert_ne!(first_chain_genesis, second_chain_genesis);

        let connection = main_db_connection().unwrap();
        let certificate_records: Vec<CertificateRecord> = connection
            .fetch_collect(GetCertificateRecordQuery::all_genesis())
            .unwrap();
        assert_eq!(Vec::<CertificateRecord>::new(), certificate_records);

        insert_certificate_records(&connection, first_certificates_chain.certificates_chained);

        let certificate_records: Vec<CertificateRecord> = connection
            .fetch_collect(GetCertificateRecordQuery::all_genesis())
            .unwrap();
        assert_eq!(vec![first_chain_genesis.to_owned()], certificate_records);

        insert_certificate_records(&connection, second_certificates_chain.certificates_chained);

        // The query orders by descending ROWID, so the most recently inserted genesis comes first.
        let certificate_records: Vec<CertificateRecord> = connection
            .fetch_collect(GetCertificateRecordQuery::all_genesis())
            .unwrap();
        assert_eq!(
            vec![second_chain_genesis, first_chain_genesis],
            certificate_records
        );
    }
}