mithril_client/certificate_client/verify_cache/
memory_cache.rs1use async_trait::async_trait;
2use chrono::{DateTime, TimeDelta, Utc};
3use std::collections::HashMap;
4use std::ops::Add;
5use tokio::sync::RwLock;
6
7use crate::certificate_client::CertificateVerifierCache;
8use crate::MithrilResult;
9
10pub type CertificateHash = str;
11pub type PreviousCertificateHash = str;
12
13pub struct MemoryCertificateVerifierCache {
15 expiration_delay: TimeDelta,
16 cache: RwLock<HashMap<String, CachedCertificate>>,
17}
18
19#[derive(Debug, PartialEq, Eq, Clone)]
20struct CachedCertificate {
21 previous_hash: String,
22 expire_at: DateTime<Utc>,
23}
24
25impl CachedCertificate {
26 fn new<TPreviousHash: Into<String>>(
27 previous_hash: TPreviousHash,
28 expire_at: DateTime<Utc>,
29 ) -> Self {
30 CachedCertificate {
31 previous_hash: previous_hash.into(),
32 expire_at,
33 }
34 }
35}
36
37impl MemoryCertificateVerifierCache {
38 pub fn new(expiration_delay: TimeDelta) -> Self {
40 MemoryCertificateVerifierCache {
41 expiration_delay,
42 cache: RwLock::new(HashMap::new()),
43 }
44 }
45
46 pub async fn len(&self) -> usize {
48 self.cache.read().await.len()
49 }
50
51 pub async fn is_empty(&self) -> bool {
53 self.cache.read().await.is_empty()
54 }
55}
56
57#[cfg_attr(target_family = "wasm", async_trait(?Send))]
58#[cfg_attr(not(target_family = "wasm"), async_trait)]
59impl CertificateVerifierCache for MemoryCertificateVerifierCache {
60 async fn store_validated_certificate(
61 &self,
62 certificate_hash: &CertificateHash,
63 previous_certificate_hash: &PreviousCertificateHash,
64 ) -> MithrilResult<()> {
65 let mut cache = self.cache.write().await;
67 cache.insert(
68 certificate_hash.to_string(),
69 CachedCertificate::new(
70 previous_certificate_hash,
71 Utc::now().add(self.expiration_delay),
72 ),
73 );
74 Ok(())
75 }
76
77 async fn get_previous_hash(
78 &self,
79 certificate_hash: &CertificateHash,
80 ) -> MithrilResult<Option<String>> {
81 let cache = self.cache.read().await;
82 Ok(cache
83 .get(certificate_hash)
84 .filter(|cached| cached.expire_at >= Utc::now())
85 .map(|cached| cached.previous_hash.clone()))
86 }
87
88 async fn reset(&self) -> MithrilResult<()> {
89 let mut cache = self.cache.write().await;
90 cache.clear();
91 Ok(())
92 }
93}
94
#[cfg(test)]
pub(crate) mod test_tools {
    //! Test-only helpers to seed and inspect a [`MemoryCertificateVerifierCache`].
    use mithril_common::entities::Certificate;

    use super::*;

    impl MemoryCertificateVerifierCache {
        /// Replace the cache content with the given `(certificate hash, previous hash)` pairs.
        ///
        /// Every seeded entry shares one expiration: `Utc::now() + expiration_delay`.
        /// If a hash appears more than once, the last pair wins.
        pub(crate) fn with_items<'a, T>(self, key_values: T) -> Self
        where
            T: IntoIterator<Item = (&'a CertificateHash, &'a PreviousCertificateHash)>,
        {
            let expire_at = Utc::now() + self.expiration_delay;
            let mut entries = HashMap::new();
            for (hash, previous_hash) in key_values {
                entries.insert(
                    hash.to_string(),
                    CachedCertificate::new(previous_hash, expire_at),
                );
            }

            Self {
                cache: RwLock::new(entries),
                ..self
            }
        }

        /// Seed the cache from a certificate chain, keyed by each certificate hash.
        pub(crate) fn with_items_from_chain<'a, T>(self, chain: T) -> Self
        where
            T: IntoIterator<Item = &'a Certificate>,
        {
            let pairs = chain.into_iter().map(|certificate| {
                (
                    certificate.hash.as_str(),
                    certificate.previous_hash.as_str(),
                )
            });
            self.with_items(pairs)
        }

        /// Snapshot of the cache as a `certificate hash -> previous hash` map, including
        /// expired entries.
        pub(crate) async fn content(&self) -> HashMap<String, String> {
            let entries = self.cache.read().await;
            let mut snapshot = HashMap::with_capacity(entries.len());
            for (hash, cached) in entries.iter() {
                snapshot.insert(hash.clone(), cached.previous_hash.clone());
            }
            snapshot
        }

        /// Force the expiration instant of an existing entry.
        ///
        /// # Panics
        ///
        /// Panics if `certificate_hash` is not in the cache.
        pub(crate) async fn overwrite_expiration_date(
            &self,
            certificate_hash: &CertificateHash,
            expire_at: DateTime<Utc>,
        ) {
            let mut entries = self.cache.write().await;
            let cached = entries.get_mut(certificate_hash).expect("Key not found");
            cached.expire_at = expire_at;
        }

        /// Raw entry stored for `certificate_hash`, returned even if it has expired.
        pub(super) async fn get_cached_value(
            &self,
            certificate_hash: &CertificateHash,
        ) -> Option<CachedCertificate> {
            let entries = self.cache.read().await;
            let cached = entries.get(certificate_hash)?;
            Some(cached.clone())
        }
    }
}
163
#[cfg(test)]
mod tests {
    //! Unit tests of the in-memory certificate verifier cache.
    use mithril_common::entities::Certificate;
    use mithril_common::test_utils::fake_data;

    use super::*;

    #[tokio::test]
    async fn from_str_iterator() {
        let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
            .with_items([("first", "one"), ("second", "two")]);

        assert_eq!(
            HashMap::from_iter([
                ("first".to_string(), "one".to_string()),
                ("second".to_string(), "two".to_string())
            ]),
            cache.content().await
        );
    }

    #[tokio::test]
    async fn from_certificate_iterator() {
        let chain = vec![
            Certificate {
                previous_hash: "first_parent".to_string(),
                ..fake_data::certificate("first")
            },
            Certificate {
                previous_hash: "second_parent".to_string(),
                ..fake_data::certificate("second")
            },
        ];
        let cache =
            MemoryCertificateVerifierCache::new(TimeDelta::hours(1)).with_items_from_chain(&chain);

        assert_eq!(
            HashMap::from_iter([
                ("first".to_string(), "first_parent".to_string()),
                ("second".to_string(), "second_parent".to_string())
            ]),
            cache.content().await
        );
    }

    mod len_and_is_empty {
        use super::*;

        #[tokio::test]
        async fn new_cache_is_empty() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1));

            assert!(cache.is_empty().await);
            assert_eq!(0, cache.len().await);
        }

        #[tokio::test]
        async fn len_count_every_item_including_expired_ones() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent"), ("another_hash", "another_parent")]);
            cache
                .overwrite_expiration_date("hash", Utc::now() - TimeDelta::days(5))
                .await;

            assert!(!cache.is_empty().await);
            assert_eq!(2, cache.len().await);
        }
    }

    mod store_validated_certificate {
        use super::*;

        #[tokio::test]
        async fn store_in_empty_cache_add_new_item_that_expire_after_parametrized_delay() {
            let expiration_delay = TimeDelta::hours(1);
            let start_time = Utc::now();
            let cache = MemoryCertificateVerifierCache::new(expiration_delay);
            cache
                .store_validated_certificate("hash", "parent")
                .await
                .unwrap();

            let cached = cache
                .get_cached_value("hash")
                .await
                .expect("Cache should have been populated");

            assert_eq!(1, cache.len().await);
            assert_eq!("parent", cached.previous_hash);
            assert!(cached.expire_at - start_time >= expiration_delay);
        }

        // A HashMap has no ordering: this only checks that a new key is added and that the
        // existing entries are untouched.
        #[tokio::test]
        async fn store_new_hash_add_new_key_and_dont_alter_existing_values() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1)).with_items([
                ("existing_hash", "existing_parent"),
                ("another_hash", "another_parent"),
            ]);
            cache
                .store_validated_certificate("new_hash", "new_parent")
                .await
                .unwrap();

            assert_eq!(
                HashMap::from_iter([
                    ("existing_hash".to_string(), "existing_parent".to_string()),
                    ("another_hash".to_string(), "another_parent".to_string()),
                    ("new_hash".to_string(), "new_parent".to_string()),
                ]),
                cache.content().await
            );
        }

        #[tokio::test]
        async fn storing_same_hash_update_parent_hash_and_expiration_time() {
            let expiration_delay = TimeDelta::days(2);
            let start_time = Utc::now();
            let cache = MemoryCertificateVerifierCache::new(expiration_delay)
                .with_items([("hash", "first_parent"), ("another_hash", "another_parent")]);

            let initial_value = cache.get_cached_value("hash").await.unwrap();

            cache
                .store_validated_certificate("hash", "updated_parent")
                .await
                .unwrap();

            let updated_value = cache.get_cached_value("hash").await.unwrap();

            assert_eq!(2, cache.len().await);
            assert_eq!(
                Some("another_parent".to_string()),
                cache.get_previous_hash("another_hash").await.unwrap(),
                "Existing but not updated value should not have been altered"
            );
            assert_ne!(initial_value, updated_value);
            assert_eq!("updated_parent", updated_value.previous_hash);
            assert!(updated_value.expire_at - start_time >= expiration_delay);
        }

        #[tokio::test]
        async fn storing_an_expired_hash_make_it_available_again() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent")]);
            cache
                .overwrite_expiration_date("hash", Utc::now() - TimeDelta::days(5))
                .await;

            cache
                .store_validated_certificate("hash", "parent")
                .await
                .unwrap();

            assert_eq!(
                Some("parent".to_string()),
                cache.get_previous_hash("hash").await.unwrap()
            );
        }
    }

    mod get_previous_hash {
        use super::*;

        #[tokio::test]
        async fn get_previous_hash_when_key_exists() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent"), ("another_hash", "another_parent")]);

            assert_eq!(
                Some("parent".to_string()),
                cache.get_previous_hash("hash").await.unwrap()
            );
        }

        #[tokio::test]
        async fn get_previous_hash_return_none_if_not_found() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent"), ("another_hash", "another_parent")]);

            assert_eq!(None, cache.get_previous_hash("not_found").await.unwrap());
        }

        #[tokio::test]
        async fn get_expired_previous_hash_return_none() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent")]);
            cache
                .overwrite_expiration_date("hash", Utc::now() - TimeDelta::days(5))
                .await;

            assert_eq!(None, cache.get_previous_hash("hash").await.unwrap());
        }
    }

    mod reset {
        use super::*;

        #[tokio::test]
        async fn reset_empty_cache_dont_raise_error() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1));

            cache.reset().await.unwrap();

            assert_eq!(HashMap::new(), cache.content().await);
            assert!(cache.is_empty().await);
        }

        #[tokio::test]
        async fn reset_not_empty_cache() {
            let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(1))
                .with_items([("hash", "parent"), ("another_hash", "another_parent")]);

            cache.reset().await.unwrap();

            assert_eq!(HashMap::new(), cache.content().await);
            assert!(cache.is_empty().await);
        }
    }
}