mithril_client/certificate_client/verify_cache/
memory_cache.rs

1use async_trait::async_trait;
2use chrono::{DateTime, TimeDelta, Utc};
3use std::collections::HashMap;
4use std::ops::Add;
5use tokio::sync::RwLock;
6
7use crate::MithrilResult;
8use crate::certificate_client::CertificateVerifierCache;
9
/// Hash of a certificate, used as the key of [MemoryCertificateVerifierCache].
pub type CertificateHash = str;
/// Hash of the previous certificate of a [CertificateHash], used as the cached value.
pub type PreviousCertificateHash = str;
12
13/// A in-memory cache for the certificate verifier.
14pub struct MemoryCertificateVerifierCache {
15    expiration_delay: TimeDelta,
16    cache: RwLock<HashMap<String, CachedCertificate>>,
17}
18
19#[derive(Debug, PartialEq, Eq, Clone)]
20struct CachedCertificate {
21    previous_hash: String,
22    expire_at: DateTime<Utc>,
23}
24
25impl CachedCertificate {
26    fn new<TPreviousHash: Into<String>>(
27        previous_hash: TPreviousHash,
28        expire_at: DateTime<Utc>,
29    ) -> Self {
30        CachedCertificate {
31            previous_hash: previous_hash.into(),
32            expire_at,
33        }
34    }
35}
36
37impl MemoryCertificateVerifierCache {
38    /// `MemoryCertificateVerifierCache` factory
39    pub fn new(expiration_delay: TimeDelta) -> Self {
40        MemoryCertificateVerifierCache {
41            expiration_delay,
42            cache: RwLock::new(HashMap::new()),
43        }
44    }
45
46    /// Get the number of elements in the cache
47    pub async fn len(&self) -> usize {
48        self.cache.read().await.len()
49    }
50
51    /// Return true if the cache is empty
52    pub async fn is_empty(&self) -> bool {
53        self.cache.read().await.is_empty()
54    }
55}
56
57#[cfg_attr(target_family = "wasm", async_trait(?Send))]
58#[cfg_attr(not(target_family = "wasm"), async_trait)]
59impl CertificateVerifierCache for MemoryCertificateVerifierCache {
60    async fn store_validated_certificate(
61        &self,
62        certificate_hash: &CertificateHash,
63        previous_certificate_hash: &PreviousCertificateHash,
64    ) -> MithrilResult<()> {
65        // todo: should we raise an error if an empty string is given for previous_certificate_hash ? (or any other kind of validation)
66        let mut cache = self.cache.write().await;
67        cache.insert(
68            certificate_hash.to_string(),
69            CachedCertificate::new(
70                previous_certificate_hash,
71                Utc::now().add(self.expiration_delay),
72            ),
73        );
74        Ok(())
75    }
76
77    async fn get_previous_hash(
78        &self,
79        certificate_hash: &CertificateHash,
80    ) -> MithrilResult<Option<String>> {
81        let cache = self.cache.read().await;
82        Ok(cache
83            .get(certificate_hash)
84            .filter(|cached| cached.expire_at >= Utc::now())
85            .map(|cached| cached.previous_hash.clone()))
86    }
87
88    async fn reset(&self) -> MithrilResult<()> {
89        let mut cache = self.cache.write().await;
90        cache.clear();
91        Ok(())
92    }
93}
94
#[cfg(test)]
pub(crate) mod test_tools {
    use mithril_common::entities::Certificate;

    use super::*;

    impl MemoryCertificateVerifierCache {
        /// `Test only` Replace the cache content with the given `(hash, previous hash)` pairs.
        ///
        /// Every entry expires `expiration_delay` from now; any previous content is discarded.
        pub(crate) fn with_items<'a, T>(mut self, key_values: T) -> Self
        where
            T: IntoIterator<Item = (&'a CertificateHash, &'a PreviousCertificateHash)>,
        {
            let expire_at = Utc::now() + self.expiration_delay;
            self.cache = RwLock::new(
                key_values
                    .into_iter()
                    .map(|(k, v)| (k.to_string(), CachedCertificate::new(v, expire_at)))
                    .collect(),
            );
            self
        }

        /// `Test only` Replace the cache content with the `(hash, previous hash)` pairs of the
        /// given certificates.
        ///
        /// Any previous content is discarded (see `with_items`).
        pub(crate) fn with_items_from_chain<'a, T>(self, chain: T) -> Self
        where
            T: IntoIterator<Item = &'a Certificate>,
        {
            self.with_items(
                chain
                    .into_iter()
                    .map(|cert| (cert.hash.as_str(), cert.previous_hash.as_str())),
            )
        }

        /// `Test only` Return the content of the cache as a `hash -> previous hash` map,
        /// without the expiration dates (expired entries included).
        pub(crate) async fn content(&self) -> HashMap<String, String> {
            self.cache
                .read()
                .await
                .iter()
                .map(|(hash, cached)| (hash.clone(), cached.previous_hash.clone()))
                .collect()
        }

        /// `Test only` Overwrite the expiration date of the entry for the given certificate hash.
        ///
        /// # Panics
        ///
        /// Panics if no entry exists for `certificate_hash`.
        pub(crate) async fn overwrite_expiration_date(
            &self,
            certificate_hash: &CertificateHash,
            expire_at: DateTime<Utc>,
        ) {
            let mut cache = self.cache.write().await;
            cache.get_mut(certificate_hash).expect("Key not found").expire_at = expire_at;
        }

        /// `Test only` Get the raw cached entry (previous hash and expiration date) for the
        /// given certificate hash, whether expired or not.
        pub(super) async fn get_cached_value(
            &self,
            certificate_hash: &CertificateHash,
        ) -> Option<CachedCertificate> {
            self.cache.read().await.get(certificate_hash).cloned()
        }
    }
}
160
#[cfg(test)]
mod tests {
    use mithril_common::entities::Certificate;
    use mithril_common::test::double::fake_data;

    use super::*;

    #[tokio::test]
    async fn from_str_iterator() {
        let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
            .with_items([("first", "one"), ("second", "two")]);

        assert_eq!(
            HashMap::from_iter([
                ("first".to_string(), "one".to_string()),
                ("second".to_string(), "two".to_string())
            ]),
            cache.content().await
        );
    }

    #[tokio::test]
    async fn from_certificate_iterator() {
        let chain = vec![
            Certificate {
                previous_hash: "first_parent".to_string(),
                ..fake_data::certificate("first")
            },
            Certificate {
                previous_hash: "second_parent".to_string(),
                ..fake_data::certificate("second")
            },
        ];
        let cache =
            MemoryCertificateVerifierCache::new(TimeDelta::hours(1)).with_items_from_chain(&chain);

        assert_eq!(
            HashMap::from_iter([
                ("first".to_string(), "first_parent".to_string()),
                ("second".to_string(), "second_parent".to_string())
            ]),
            cache.content().await
        );
    }

    mod len_and_is_empty {
        use super::*;

        #[tokio::test]
        async fn new_cache_is_empty() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1));

            assert!(cache.is_empty().await);
            assert_eq!(0, cache.len().await);
        }

        #[tokio::test]
        async fn populated_cache_is_not_empty() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent"), ("another_hash", "another_parent")]);

            assert!(!cache.is_empty().await);
            assert_eq!(2, cache.len().await);
        }
    }

    mod store_validated_certificate {
        use super::*;

        #[tokio::test]
        async fn store_in_empty_cache_add_new_item_that_expire_after_parametrized_delay() {
            let expiration_delay = TimeDelta::hours(1);
            let start_time = Utc::now();
            let cache = MemoryCertificateVerifierCache::new(expiration_delay);
            cache.store_validated_certificate("hash", "parent").await.unwrap();

            let cached = cache
                .get_cached_value("hash")
                .await
                .expect("Cache should have been populated");

            assert_eq!(1, cache.len().await);
            assert_eq!("parent", cached.previous_hash);
            assert!(cached.expire_at - start_time >= expiration_delay);
        }

        // The backing store is a `HashMap`: there is no ordering to check, only membership.
        #[tokio::test]
        async fn store_new_hash_add_new_key_and_dont_alter_existing_values() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1)).with_items([
                ("existing_hash", "existing_parent"),
                ("another_hash", "another_parent"),
            ]);
            cache
                .store_validated_certificate("new_hash", "new_parent")
                .await
                .unwrap();

            assert_eq!(
                HashMap::from_iter([
                    ("existing_hash".to_string(), "existing_parent".to_string()),
                    ("another_hash".to_string(), "another_parent".to_string()),
                    ("new_hash".to_string(), "new_parent".to_string()),
                ]),
                cache.content().await
            );
        }

        #[tokio::test]
        async fn storing_same_hash_update_parent_hash_and_expiration_time() {
            let expiration_delay = TimeDelta::days(2);
            let start_time = Utc::now();
            let cache = MemoryCertificateVerifierCache::new(expiration_delay)
                .with_items([("hash", "first_parent"), ("another_hash", "another_parent")]);
            // Move the initial expiration into the past: otherwise the value set by
            // `with_items` would already satisfy the expiration assertions below and the
            // test could not detect a missing refresh.
            cache
                .overwrite_expiration_date("hash", start_time - TimeDelta::days(1))
                .await;

            let initial_value = cache.get_cached_value("hash").await.unwrap();

            cache
                .store_validated_certificate("hash", "updated_parent")
                .await
                .unwrap();

            let updated_value = cache.get_cached_value("hash").await.unwrap();

            assert_eq!(2, cache.len().await);
            assert_eq!(
                Some("another_parent".to_string()),
                cache.get_previous_hash("another_hash").await.unwrap(),
                "Existing but not updated value should not have been altered"
            );
            assert_ne!(initial_value, updated_value);
            assert_eq!("updated_parent", updated_value.previous_hash);
            assert!(updated_value.expire_at > initial_value.expire_at);
            assert!(updated_value.expire_at - start_time >= expiration_delay);
        }
    }

    mod get_previous_hash {
        use super::*;

        #[tokio::test]
        async fn get_previous_hash_when_key_exists() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent"), ("another_hash", "another_parent")]);

            assert_eq!(
                Some("parent".to_string()),
                cache.get_previous_hash("hash").await.unwrap()
            );
        }

        #[tokio::test]
        async fn get_previous_hash_return_none_if_not_found() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent"), ("another_hash", "another_parent")]);

            assert_eq!(None, cache.get_previous_hash("not_found").await.unwrap());
        }

        #[tokio::test]
        async fn get_expired_previous_hash_return_none() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent")]);
            cache
                .overwrite_expiration_date("hash", Utc::now() - TimeDelta::days(5))
                .await;

            assert_eq!(None, cache.get_previous_hash("hash").await.unwrap());
        }
    }

    mod reset {
        use super::*;

        #[tokio::test]
        async fn reset_empty_cache_dont_raise_error() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1));

            cache.reset().await.unwrap();

            assert_eq!(HashMap::new(), cache.content().await);
        }

        #[tokio::test]
        async fn reset_not_empty_cache() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent"), ("another_hash", "another_parent")]);

            cache.reset().await.unwrap();

            assert_eq!(HashMap::new(), cache.content().await);
        }
    }
}