mithril_client/certificate_client/verify_cache/
memory_cache.rs

1use async_trait::async_trait;
2use chrono::{DateTime, TimeDelta, Utc};
3use std::collections::HashMap;
4use std::ops::Add;
5use tokio::sync::RwLock;
6
7use crate::certificate_client::CertificateVerifierCache;
8use crate::MithrilResult;
9
/// Hash identifying a certificate; used as the lookup key of the cache.
pub type CertificateHash = str;

/// Hash of the previous certificate recorded alongside a cached certificate hash.
pub type PreviousCertificateHash = str;
12
13/// A in-memory cache for the certificate verifier.
14pub struct MemoryCertificateVerifierCache {
15    expiration_delay: TimeDelta,
16    cache: RwLock<HashMap<String, CachedCertificate>>,
17}
18
19#[derive(Debug, PartialEq, Eq, Clone)]
20struct CachedCertificate {
21    previous_hash: String,
22    expire_at: DateTime<Utc>,
23}
24
25impl CachedCertificate {
26    fn new<TPreviousHash: Into<String>>(
27        previous_hash: TPreviousHash,
28        expire_at: DateTime<Utc>,
29    ) -> Self {
30        CachedCertificate {
31            previous_hash: previous_hash.into(),
32            expire_at,
33        }
34    }
35}
36
37impl MemoryCertificateVerifierCache {
38    /// `MemoryCertificateVerifierCache` factory
39    pub fn new(expiration_delay: TimeDelta) -> Self {
40        MemoryCertificateVerifierCache {
41            expiration_delay,
42            cache: RwLock::new(HashMap::new()),
43        }
44    }
45
46    /// Get the number of elements in the cache
47    pub async fn len(&self) -> usize {
48        self.cache.read().await.len()
49    }
50
51    /// Return true if the cache is empty
52    pub async fn is_empty(&self) -> bool {
53        self.cache.read().await.is_empty()
54    }
55}
56
57#[cfg_attr(target_family = "wasm", async_trait(?Send))]
58#[cfg_attr(not(target_family = "wasm"), async_trait)]
59impl CertificateVerifierCache for MemoryCertificateVerifierCache {
60    async fn store_validated_certificate(
61        &self,
62        certificate_hash: &CertificateHash,
63        previous_certificate_hash: &PreviousCertificateHash,
64    ) -> MithrilResult<()> {
65        // todo: should we raise an error if an empty string is given for previous_certificate_hash ? (or any other kind of validation)
66        let mut cache = self.cache.write().await;
67        cache.insert(
68            certificate_hash.to_string(),
69            CachedCertificate::new(
70                previous_certificate_hash,
71                Utc::now().add(self.expiration_delay),
72            ),
73        );
74        Ok(())
75    }
76
77    async fn get_previous_hash(
78        &self,
79        certificate_hash: &CertificateHash,
80    ) -> MithrilResult<Option<String>> {
81        let cache = self.cache.read().await;
82        Ok(cache
83            .get(certificate_hash)
84            .filter(|cached| cached.expire_at >= Utc::now())
85            .map(|cached| cached.previous_hash.clone()))
86    }
87
88    async fn reset(&self) -> MithrilResult<()> {
89        let mut cache = self.cache.write().await;
90        cache.clear();
91        Ok(())
92    }
93}
94
#[cfg(test)]
pub(crate) mod test_tools {
    use mithril_common::entities::Certificate;

    use super::*;

    impl MemoryCertificateVerifierCache {
        /// `Test only` Replace the cache content with the given (hash, previous hash) pairs
        ///
        /// Every entry gets the same expiration date: now + the cache expiration delay.
        pub(crate) fn with_items<'a, T>(mut self, key_values: T) -> Self
        where
            T: IntoIterator<Item = (&'a CertificateHash, &'a PreviousCertificateHash)>,
        {
            let expire_at = Utc::now() + self.expiration_delay;

            let mut entries = HashMap::new();
            for (hash, previous_hash) in key_values {
                entries.insert(
                    hash.to_string(),
                    CachedCertificate::new(previous_hash, expire_at),
                );
            }
            self.cache = RwLock::new(entries);

            self
        }

        /// `Test only` Replace the cache content with the hash and previous hash of each given
        /// certificate
        pub(crate) fn with_items_from_chain<'a, T>(self, chain: T) -> Self
        where
            T: IntoIterator<Item = &'a Certificate>,
        {
            let hash_pairs = chain
                .into_iter()
                .map(|certificate| (certificate.hash.as_str(), certificate.previous_hash.as_str()));

            self.with_items(hash_pairs)
        }

        /// `Test only` Return the content of the cache (without the expiration date)
        pub(crate) async fn content(&self) -> HashMap<String, String> {
            let entries = self.cache.read().await;

            let mut content = HashMap::with_capacity(entries.len());
            for (hash, entry) in entries.iter() {
                content.insert(hash.clone(), entry.previous_hash.clone());
            }

            content
        }

        /// `Test only` Overwrite the expiration date of the entry for the given certificate hash.
        ///
        /// panic if the key is not found
        pub(crate) async fn overwrite_expiration_date(
            &self,
            certificate_hash: &CertificateHash,
            expire_at: DateTime<Utc>,
        ) {
            let mut entries = self.cache.write().await;
            let entry = entries.get_mut(certificate_hash).expect("Key not found");

            entry.expire_at = expire_at;
        }

        /// `Test only` Get a copy of the cached value for the given certificate hash, expired
        /// or not
        pub(super) async fn get_cached_value(
            &self,
            certificate_hash: &CertificateHash,
        ) -> Option<CachedCertificate> {
            let entries = self.cache.read().await;

            entries.get(certificate_hash).cloned()
        }
    }
}
163
#[cfg(test)]
mod tests {
    use mithril_common::entities::Certificate;
    use mithril_common::test_utils::fake_data;

    use super::*;

    // The two tests below cover the test-only builders used by the rest of the suite.

    #[tokio::test]
    async fn from_str_iterator() {
        let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
            .with_items([("first", "one"), ("second", "two")]);

        let expected = HashMap::from([
            ("first".to_string(), "one".to_string()),
            ("second".to_string(), "two".to_string()),
        ]);

        assert_eq!(expected, cache.content().await);
    }

    #[tokio::test]
    async fn from_certificate_iterator() {
        let first = Certificate {
            previous_hash: "first_parent".to_string(),
            ..fake_data::certificate("first")
        };
        let second = Certificate {
            previous_hash: "second_parent".to_string(),
            ..fake_data::certificate("second")
        };
        let chain = vec![first, second];

        let cache =
            MemoryCertificateVerifierCache::new(TimeDelta::hours(1)).with_items_from_chain(&chain);

        let expected = HashMap::from([
            ("first".to_string(), "first_parent".to_string()),
            ("second".to_string(), "second_parent".to_string()),
        ]);

        assert_eq!(expected, cache.content().await);
    }

    mod store_validated_certificate {
        use super::*;

        #[tokio::test]
        async fn store_in_empty_cache_add_new_item_that_expire_after_parametrized_delay() {
            let expiration_delay = TimeDelta::hours(1);
            // Taken before storing so the stored expiration is at least `start_time + delay`.
            let start_time = Utc::now();
            let cache = MemoryCertificateVerifierCache::new(expiration_delay);

            cache
                .store_validated_certificate("hash", "parent")
                .await
                .unwrap();

            let stored_entry = cache
                .get_cached_value("hash")
                .await
                .expect("Cache should have been populated");

            assert_eq!(1, cache.len().await);
            assert_eq!("parent", stored_entry.previous_hash);
            assert!(stored_entry.expire_at - start_time >= expiration_delay);
        }

        #[tokio::test]
        async fn store_new_hash_push_new_key_at_end_and_dont_alter_existing_values() {
            let initial_items = [
                ("existing_hash", "existing_parent"),
                ("another_hash", "another_parent"),
            ];
            let cache =
                MemoryCertificateVerifierCache::new(TimeDelta::hours(1)).with_items(initial_items);

            cache
                .store_validated_certificate("new_hash", "new_parent")
                .await
                .unwrap();

            let expected = HashMap::from([
                ("existing_hash".to_string(), "existing_parent".to_string()),
                ("another_hash".to_string(), "another_parent".to_string()),
                ("new_hash".to_string(), "new_parent".to_string()),
            ]);

            assert_eq!(expected, cache.content().await);
        }

        #[tokio::test]
        async fn storing_same_hash_update_parent_hash_and_expiration_time() {
            let expiration_delay = TimeDelta::days(2);
            let start_time = Utc::now();
            let cache = MemoryCertificateVerifierCache::new(expiration_delay)
                .with_items([("hash", "first_parent"), ("another_hash", "another_parent")]);

            let before_update = cache.get_cached_value("hash").await.unwrap();

            cache
                .store_validated_certificate("hash", "updated_parent")
                .await
                .unwrap();

            let after_update = cache.get_cached_value("hash").await.unwrap();

            assert_eq!(2, cache.len().await);
            assert_eq!(
                Some("another_parent".to_string()),
                cache.get_previous_hash("another_hash").await.unwrap(),
                "Existing but not updated value should not have been altered"
            );
            assert_ne!(before_update, after_update);
            assert_eq!("updated_parent", after_update.previous_hash);
            assert!(after_update.expire_at - start_time >= expiration_delay);
        }
    }

    mod get_previous_hash {
        use super::*;

        #[tokio::test]
        async fn get_previous_hash_when_key_exists() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent"), ("another_hash", "another_parent")]);

            let previous_hash = cache.get_previous_hash("hash").await.unwrap();

            assert_eq!(Some("parent".to_string()), previous_hash);
        }

        #[tokio::test]
        async fn get_previous_hash_return_none_if_not_found() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent"), ("another_hash", "another_parent")]);

            let previous_hash = cache.get_previous_hash("not_found").await.unwrap();

            assert_eq!(None, previous_hash);
        }

        #[tokio::test]
        async fn get_expired_previous_hash_return_none() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent")]);
            // Force the entry to be long expired.
            let expired_since = Utc::now() - TimeDelta::days(5);
            cache.overwrite_expiration_date("hash", expired_since).await;

            let previous_hash = cache.get_previous_hash("hash").await.unwrap();

            assert_eq!(None, previous_hash);
        }
    }

    mod reset {
        use super::*;

        #[tokio::test]
        async fn reset_empty_cache_dont_raise_error() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1));

            cache.reset().await.unwrap();

            let expected: HashMap<String, String> = HashMap::new();
            assert_eq!(expected, cache.content().await);
        }

        #[tokio::test]
        async fn reset_not_empty_cache() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent"), ("another_hash", "another_parent")]);

            cache.reset().await.unwrap();

            let expected: HashMap<String, String> = HashMap::new();
            assert_eq!(expected, cache.content().await);
        }
    }
}