mithril_common/test/double/certificate_retriever.rs

//! A fake implementation of a certificate chain retriever, intended for use in tests.
//!
3
4use anyhow::anyhow;
5use async_trait::async_trait;
6use std::collections::HashMap;
7use tokio::sync::RwLock;
8
9use crate::certificate_chain::{CertificateRetriever, CertificateRetrieverError};
10use crate::entities::Certificate;
11
12/// A fake [CertificateRetriever] that returns a [Certificate] given its hash
13pub struct FakeCertificaterRetriever {
14    certificates_map: RwLock<HashMap<String, Certificate>>,
15}
16
17impl FakeCertificaterRetriever {
18    /// Create a new [FakeCertificaterRetriever]
19    pub fn from_certificates(certificates: &[Certificate]) -> Self {
20        let certificates_map = certificates
21            .iter()
22            .map(|certificate| (certificate.hash.clone(), certificate.clone()))
23            .collect::<HashMap<_, _>>();
24        let certificates_map = RwLock::new(certificates_map);
25
26        Self { certificates_map }
27    }
28}
29
30#[cfg_attr(target_family = "wasm", async_trait(?Send))]
31#[cfg_attr(not(target_family = "wasm"), async_trait)]
32impl CertificateRetriever for FakeCertificaterRetriever {
33    async fn get_certificate_details(
34        &self,
35        certificate_hash: &str,
36    ) -> Result<Certificate, CertificateRetrieverError> {
37        let certificates_map = self.certificates_map.read().await;
38        certificates_map
39            .get(certificate_hash)
40            .cloned()
41            .ok_or_else(|| CertificateRetrieverError(anyhow!("Certificate not found")))
42    }
43}
44
#[cfg(test)]
mod tests {
    use crate::test::double::fake_data;

    use super::*;

    #[tokio::test]
    async fn fake_certificate_retriever_retrieves_existing_certificate() {
        let certificate = fake_data::certificate("certificate-hash-123".to_string());
        let certificate_hash = certificate.hash.clone();
        let certificate_retriever =
            FakeCertificaterRetriever::from_certificates(&[certificate.clone()]);

        let retrieved_certificate = certificate_retriever
            .get_certificate_details(&certificate_hash)
            .await
            .expect("Should retrieve certificate");

        assert_eq!(retrieved_certificate, certificate);
    }

    #[tokio::test]
    async fn fake_certificate_retriever_fails_retrieving_unknown_certificate() {
        let certificate = fake_data::certificate("certificate-hash-123".to_string());
        let certificate_retriever = FakeCertificaterRetriever::from_certificates(&[certificate]);

        let retrieved_certificate = certificate_retriever
            .get_certificate_details("certificate-hash-not-found")
            .await;

        retrieved_certificate.expect_err("get_certificate_details should fail");
    }

    #[tokio::test]
    async fn fake_certificate_retriever_without_certificates_fails_retrieving_any_certificate() {
        // Edge case: a retriever built from an empty slice knows no certificate at all.
        let certificate_retriever = FakeCertificaterRetriever::from_certificates(&[]);

        let retrieved_certificate = certificate_retriever
            .get_certificate_details("certificate-hash-123")
            .await;

        retrieved_certificate.expect_err("get_certificate_details should fail on an empty retriever");
    }
}