mithril_aggregator/database/query/certificate/
get_certificate.rs1use sqlite::Value;
2
3#[cfg(test)]
4use mithril_common::entities::Epoch;
5#[cfg(test)]
6use mithril_common::StdResult;
7use mithril_persistence::sqlite::{Query, SourceAlias, SqLiteEntity, WhereCondition};
8
9use crate::database::record::CertificateRecord;
10
11pub struct GetCertificateRecordQuery {
13 condition: WhereCondition,
14}
15
16impl GetCertificateRecordQuery {
17 pub fn all() -> Self {
18 Self {
19 condition: WhereCondition::default(),
20 }
21 }
22
23 pub fn by_certificate_id(certificate_id: &str) -> Self {
24 Self {
25 condition: WhereCondition::new(
26 "certificate_id = ?*",
27 vec![Value::String(certificate_id.to_owned())],
28 ),
29 }
30 }
31
32 #[cfg(test)]
33 pub fn by_epoch(epoch: Epoch) -> StdResult<Self> {
34 Ok(Self {
35 condition: WhereCondition::new("epoch = ?*", vec![Value::Integer(epoch.try_into()?)]),
36 })
37 }
38}
39
40impl Query for GetCertificateRecordQuery {
41 type Entity = CertificateRecord;
42
43 fn filters(&self) -> WhereCondition {
44 self.condition.clone()
45 }
46
47 fn get_definition(&self, condition: &str) -> String {
48 let aliases = SourceAlias::new(&[("{:certificate:}", "c")]);
49 let projection = Self::Entity::get_projection().expand(aliases);
50 format!("select {projection} from certificate as c where {condition} order by ROWID desc")
51 }
52}
53
#[cfg(test)]
mod tests {
    use mithril_common::crypto_helper::tests_setup::setup_certificate_chain;
    use mithril_persistence::sqlite::ConnectionExtensions;

    use crate::database::test_helper::{insert_certificate_records, main_db_connection};

    use super::*;

    /// Querying by epoch returns exactly the certificates of that epoch, in reverse
    /// insertion order, and nothing for an epoch without certificates.
    #[test]
    fn test_get_certificate_records_by_epoch() {
        let (certificates, _) = setup_certificate_chain(20, 7);

        let connection = main_db_connection().unwrap();
        insert_certificate_records(&connection, certificates.clone());

        // Epochs for which the generated chain is expected to hold certificates.
        for epoch in [Epoch(1), Epoch(3)] {
            // The query orders by ROWID descending, hence the reversed iteration.
            let expected_certificate_records: Vec<CertificateRecord> = certificates
                .iter()
                .rev()
                .filter(|c| c.epoch == epoch)
                .map(|c| c.to_owned().into())
                .collect();

            let certificate_records: Vec<CertificateRecord> = connection
                .fetch_collect(GetCertificateRecordQuery::by_epoch(epoch).unwrap())
                .unwrap();

            assert_eq!(expected_certificate_records, certificate_records);
        }

        // No certificate is expected for this epoch.
        let records_count = connection
            .fetch(GetCertificateRecordQuery::by_epoch(Epoch(5)).unwrap())
            .unwrap()
            .count();
        assert_eq!(0, records_count);
    }

    /// Querying without a filter returns every inserted certificate, in reverse
    /// insertion order.
    #[test]
    fn test_get_all_certificate_records() {
        let (certificates, _) = setup_certificate_chain(5, 2);

        let connection = main_db_connection().unwrap();
        insert_certificate_records(&connection, certificates.clone());

        let certificate_records: Vec<CertificateRecord> = connection
            .fetch_collect(GetCertificateRecordQuery::all())
            .unwrap();

        let expected_certificate_records: Vec<CertificateRecord> = certificates
            .iter()
            .rev()
            .map(|c| c.to_owned().into())
            .collect();

        assert_eq!(expected_certificate_records, certificate_records);
    }
}