mithril_aggregator/database/query/certificate/get_certificate.rs

1use sqlite::Value;
2
3#[cfg(test)]
4use mithril_common::entities::Epoch;
5#[cfg(test)]
6use mithril_common::StdResult;
7use mithril_persistence::sqlite::{Query, SourceAlias, SqLiteEntity, WhereCondition};
8
9use crate::database::record::CertificateRecord;
10
11/// Simple queries to retrieve [CertificateRecord] from the sqlite database.
12pub struct GetCertificateRecordQuery {
13    condition: WhereCondition,
14}
15
16impl GetCertificateRecordQuery {
17    pub fn all() -> Self {
18        Self {
19            condition: WhereCondition::default(),
20        }
21    }
22
23    pub fn by_certificate_id(certificate_id: &str) -> Self {
24        Self {
25            condition: WhereCondition::new(
26                "certificate_id = ?*",
27                vec![Value::String(certificate_id.to_owned())],
28            ),
29        }
30    }
31
32    #[cfg(test)]
33    pub fn by_epoch(epoch: Epoch) -> StdResult<Self> {
34        Ok(Self {
35            condition: WhereCondition::new("epoch = ?*", vec![Value::Integer(epoch.try_into()?)]),
36        })
37    }
38}
39
40impl Query for GetCertificateRecordQuery {
41    type Entity = CertificateRecord;
42
43    fn filters(&self) -> WhereCondition {
44        self.condition.clone()
45    }
46
47    fn get_definition(&self, condition: &str) -> String {
48        let aliases = SourceAlias::new(&[("{:certificate:}", "c")]);
49        let projection = Self::Entity::get_projection().expand(aliases);
50        format!("select {projection} from certificate as c where {condition} order by ROWID desc")
51    }
52}
53
#[cfg(test)]
mod tests {
    use mithril_common::crypto_helper::tests_setup::setup_certificate_chain;
    use mithril_persistence::sqlite::ConnectionExtensions;

    use crate::database::test_helper::{insert_certificate_records, main_db_connection};

    use super::*;

    #[test]
    fn test_get_certificate_records_by_epoch() {
        let (certificates, _) = setup_certificate_chain(20, 7);

        let connection = main_db_connection().unwrap();
        // `certificates` is still needed below to build the expected records, hence the clone.
        insert_certificate_records(&connection, certificates.clone());

        for epoch_number in [1, 3] {
            // Filter first so only matching certificates are cloned/converted, then reverse
            // to match the query's `order by ROWID desc`.
            let expected_certificate_records: Vec<CertificateRecord> = certificates
                .iter()
                .filter(|c| c.epoch == Epoch(epoch_number))
                .map(|c| c.to_owned().into())
                .rev()
                .collect();

            let certificate_records: Vec<CertificateRecord> = connection
                .fetch_collect(GetCertificateRecordQuery::by_epoch(Epoch(epoch_number)).unwrap())
                .unwrap();

            assert_eq!(
                expected_certificate_records, certificate_records,
                "unexpected certificate records for epoch {epoch_number}"
            );
        }

        // The query at epoch 5 is expected to yield no rows.
        let cursor = connection
            .fetch(GetCertificateRecordQuery::by_epoch(Epoch(5)).unwrap())
            .unwrap();
        assert_eq!(0, cursor.count());
    }

    #[test]
    fn test_get_all_certificate_records() {
        let (certificates, _) = setup_certificate_chain(5, 2);
        // Reversed because the query orders by descending ROWID.
        let expected_certificate_records: Vec<CertificateRecord> = certificates
            .iter()
            .map(|c| c.to_owned().into())
            .rev()
            .collect();

        let connection = main_db_connection().unwrap();
        // `certificates` is not used after this call, so it is moved rather than cloned.
        insert_certificate_records(&connection, certificates);

        let certificate_records: Vec<CertificateRecord> = connection
            .fetch_collect(GetCertificateRecordQuery::all())
            .unwrap();
        assert_eq!(expected_certificate_records, certificate_records);
    }
}