mithril_aggregator/database/query/certificate/get_certificate.rs

1use sqlite::Value;
2
3#[cfg(test)]
4use mithril_common::StdResult;
5#[cfg(test)]
6use mithril_common::entities::Epoch;
7use mithril_persistence::sqlite::{Query, SourceAlias, SqLiteEntity, WhereCondition};
8
9use crate::database::record::CertificateRecord;
10
11/// Simple queries to retrieve [CertificateRecord] from the sqlite database.
12pub struct GetCertificateRecordQuery {
13    condition: WhereCondition,
14}
15
16impl GetCertificateRecordQuery {
17    pub fn all() -> Self {
18        Self {
19            condition: WhereCondition::default(),
20        }
21    }
22
23    pub fn all_genesis() -> Self {
24        Self {
25            condition: WhereCondition::new("parent_certificate_id is null", vec![]),
26        }
27    }
28
29    pub fn by_certificate_id(certificate_id: &str) -> Self {
30        Self {
31            condition: WhereCondition::new(
32                "certificate_id = ?*",
33                vec![Value::String(certificate_id.to_owned())],
34            ),
35        }
36    }
37
38    #[cfg(test)]
39    pub fn by_epoch(epoch: Epoch) -> StdResult<Self> {
40        Ok(Self {
41            condition: WhereCondition::new("epoch = ?*", vec![Value::Integer(epoch.try_into()?)]),
42        })
43    }
44}
45
46impl Query for GetCertificateRecordQuery {
47    type Entity = CertificateRecord;
48
49    fn filters(&self) -> WhereCondition {
50        self.condition.clone()
51    }
52
53    fn get_definition(&self, condition: &str) -> String {
54        let aliases = SourceAlias::new(&[("{:certificate:}", "c")]);
55        let projection = Self::Entity::get_projection().expand(aliases);
56        format!("select {projection} from certificate as c where {condition} order by ROWID desc")
57    }
58}
59
#[cfg(test)]
mod tests {
    use mithril_common::crypto_helper::ProtocolParameters;
    use mithril_common::crypto_helper::tests_setup::setup_certificate_chain;
    use mithril_common::test_utils::CertificateChainBuilder;

    use mithril_persistence::sqlite::ConnectionExtensions;

    use crate::database::test_helper::{insert_certificate_records, main_db_connection};

    use super::*;

    #[test]
    fn test_get_certificate_records_by_epoch() {
        let certificates = setup_certificate_chain(20, 7);

        let connection = main_db_connection().unwrap();
        insert_certificate_records(&connection, certificates.certificates_chained.clone());

        let certificate_records: Vec<CertificateRecord> = connection
            .fetch_collect(GetCertificateRecordQuery::by_epoch(Epoch(1)).unwrap())
            .unwrap();
        let expected_certificate_records: Vec<CertificateRecord> = certificates
            .reversed_chain()
            .into_iter()
            .filter_map(|c| (c.epoch == Epoch(1)).then_some(c.to_owned().into()))
            .collect();
        assert_eq!(expected_certificate_records, certificate_records);

        let certificate_records: Vec<CertificateRecord> = connection
            .fetch_collect(GetCertificateRecordQuery::by_epoch(Epoch(3)).unwrap())
            .unwrap();
        let expected_certificate_records: Vec<CertificateRecord> = certificates
            .reversed_chain()
            .into_iter()
            .filter_map(|c| (c.epoch == Epoch(3)).then_some(c.to_owned().into()))
            .collect();
        assert_eq!(expected_certificate_records, certificate_records);

        let cursor = connection
            .fetch(GetCertificateRecordQuery::by_epoch(Epoch(5)).unwrap())
            .unwrap();
        assert_eq!(0, cursor.count());
    }

    /// `by_certificate_id` must return exactly the matching record, and nothing for an
    /// identifier that is not stored.
    #[test]
    fn test_get_certificate_record_by_certificate_id() {
        let certificates = setup_certificate_chain(5, 2);

        let connection = main_db_connection().unwrap();
        insert_certificate_records(&connection, certificates.certificates_chained.clone());

        let expected_certificate_record: CertificateRecord =
            certificates.certificates_chained[2].clone().into();
        let certificate_records: Vec<CertificateRecord> = connection
            .fetch_collect(GetCertificateRecordQuery::by_certificate_id(
                &expected_certificate_record.certificate_id,
            ))
            .unwrap();
        assert_eq!(vec![expected_certificate_record], certificate_records);

        let certificate_records: Vec<CertificateRecord> = connection
            .fetch_collect(GetCertificateRecordQuery::by_certificate_id(
                "unknown-certificate-id",
            ))
            .unwrap();
        assert_eq!(Vec::<CertificateRecord>::new(), certificate_records);
    }

    #[test]
    fn test_get_all_certificate_records() {
        let certificates = setup_certificate_chain(5, 2);
        let expected_certificate_records: Vec<CertificateRecord> =
            certificates.reversed_chain().into_iter().map(Into::into).collect();

        let connection = main_db_connection().unwrap();
        insert_certificate_records(&connection, certificates.certificates_chained.clone());

        let certificate_records: Vec<CertificateRecord> =
            connection.fetch_collect(GetCertificateRecordQuery::all()).unwrap();
        assert_eq!(expected_certificate_records, certificate_records);
    }

    #[test]
    fn test_get_all_genesis_certificate_records() {
        // Two chains with different protocol parameters so generated certificates are different.
        let first_certificates_chain = CertificateChainBuilder::new()
            .with_total_certificates(2)
            .with_protocol_parameters(ProtocolParameters {
                m: 90,
                k: 4,
                phi_f: 0.65,
            })
            .build();
        let first_chain_genesis: CertificateRecord =
            first_certificates_chain.genesis_certificate().clone().into();
        let second_certificates_chain = CertificateChainBuilder::new()
            .with_total_certificates(2)
            .with_protocol_parameters(ProtocolParameters {
                m: 100,
                k: 5,
                phi_f: 0.65,
            })
            .build();
        let second_chain_genesis: CertificateRecord =
            second_certificates_chain.genesis_certificate().clone().into();
        assert_ne!(first_chain_genesis, second_chain_genesis);

        let connection = main_db_connection().unwrap();
        let certificate_records: Vec<CertificateRecord> = connection
            .fetch_collect(GetCertificateRecordQuery::all_genesis())
            .unwrap();
        assert_eq!(Vec::<CertificateRecord>::new(), certificate_records);

        insert_certificate_records(&connection, first_certificates_chain.certificates_chained);

        let certificate_records: Vec<CertificateRecord> = connection
            .fetch_collect(GetCertificateRecordQuery::all_genesis())
            .unwrap();
        assert_eq!(vec![first_chain_genesis.to_owned()], certificate_records);

        insert_certificate_records(&connection, second_certificates_chain.certificates_chained);

        let certificate_records: Vec<CertificateRecord> = connection
            .fetch_collect(GetCertificateRecordQuery::all_genesis())
            .unwrap();
        assert_eq!(
            vec![second_chain_genesis, first_chain_genesis],
            certificate_records
        );
    }
}