mithril_dmq/consumer/client/
deduplicator.rs1use std::{collections::HashMap, fmt::Debug, sync::Arc, time::Duration};
2
3use blake2::{Blake2b, Digest, digest::consts::U64};
4use tokio::sync::Mutex;
5
6use mithril_common::{
7 StdResult,
8 crypto_helper::{TryFromBytes, TryToBytes},
9 entities::PartyId,
10};
11
12use crate::{DmqConsumerClient, model::UnixTimestampProvider};
13
/// Default time-to-live of a seen message entry (30 minutes), used by
/// [`DmqConsumerClientDeduplicator::new_with_default_ttl`].
pub const DMQ_MESSAGE_DEDUPLICATOR_TTL: Duration = Duration::from_secs(30 * 60);

/// Key under which a message is tracked by the deduplicator: the raw bytes of its hash.
type MessageKey = Vec<u8>;
19
20pub struct DmqConsumerClientDeduplicator<
26 M: TryFromBytes + TryToBytes + Debug + Send + Sync + Clone + Eq,
27> {
28 inner: Arc<dyn DmqConsumerClient<M>>,
29 timestamp_provider: Arc<dyn UnixTimestampProvider>,
30 seen_messages: Mutex<HashMap<MessageKey, u64>>,
31 ttl: Duration,
32}
33
34impl<M: TryFromBytes + TryToBytes + Debug + Send + Sync + Clone + Eq>
35 DmqConsumerClientDeduplicator<M>
36{
37 pub fn new(
41 inner: Arc<dyn DmqConsumerClient<M>>,
42 timestamp_provider: Arc<dyn UnixTimestampProvider>,
43 ttl: Duration,
44 ) -> Self {
45 Self {
46 inner,
47 timestamp_provider,
48 seen_messages: Mutex::new(HashMap::new()),
49 ttl,
50 }
51 }
52
53 pub fn new_with_default_ttl(
55 inner: Arc<dyn DmqConsumerClient<M>>,
56 timestamp_provider: Arc<dyn UnixTimestampProvider>,
57 ) -> Self {
58 Self::new(inner, timestamp_provider, DMQ_MESSAGE_DEDUPLICATOR_TTL)
59 }
60
61 fn try_compute_message_key(&self, message: &M) -> StdResult<MessageKey> {
63 let mut hasher = Blake2b::<U64>::new();
64 hasher.update(&message.to_bytes_vec()?);
65
66 Ok(hasher.finalize().to_vec())
67 }
68
69 fn current_timestamp(&self) -> StdResult<u64> {
71 self.timestamp_provider.current_timestamp()
72 }
73
74 fn is_expired_timestamp(&self, timestamp: u64, current_timestamp: u64) -> bool {
76 current_timestamp.saturating_sub(timestamp) > self.ttl.as_secs()
77 }
78
79 async fn remove_expired_messages(&self, current_timestamp: u64) {
81 let mut seen_messages = self.seen_messages.lock().await;
82 seen_messages
83 .retain(|_, timestamp| !self.is_expired_timestamp(*timestamp, current_timestamp));
84 }
85
86 async fn has_message_been_seen(&self, message: &M) -> StdResult<bool> {
88 let seen_messages = self.seen_messages.lock().await;
89
90 Ok(seen_messages.contains_key(&self.try_compute_message_key(message)?))
91 }
92
93 async fn mark_message_as_seen(&self, message: M, timestamp: u64) -> StdResult<()> {
95 let mut seen_messages = self.seen_messages.lock().await;
96 seen_messages.insert(self.try_compute_message_key(&message)?, timestamp);
97
98 Ok(())
99 }
100
101 #[cfg(test)]
102 async fn set_seen_messages_expiration_timestamp(&self, timestamp: u64) {
104 let mut seen_messages = self.seen_messages.lock().await;
105 for seen_timestamp in seen_messages.values_mut() {
106 *seen_timestamp = timestamp;
107 }
108 }
109
110 #[cfg(test)]
111 async fn seen_messages_count(&self) -> usize {
113 self.seen_messages.lock().await.len()
114 }
115
116 #[cfg(test)]
117 async fn has_seen_message(&self, message: &M) -> StdResult<bool> {
119 Ok(self
120 .seen_messages
121 .lock()
122 .await
123 .contains_key(&self.try_compute_message_key(message)?))
124 }
125}
126
127#[async_trait::async_trait]
128impl<M: TryFromBytes + TryToBytes + Debug + Send + Sync + Clone + Eq + Clone> DmqConsumerClient<M>
129 for DmqConsumerClientDeduplicator<M>
130{
131 async fn consume_messages(&self) -> StdResult<Vec<(M, PartyId)>> {
132 let messages = self.inner.consume_messages().await;
133 let current_timestamp = self.current_timestamp()?;
134 self.remove_expired_messages(current_timestamp).await;
135 let messages = messages?;
136
137 let mut deduplicated_messages = Vec::new();
138 for (message, party_id) in messages {
139 if !self.has_message_been_seen(&message).await? {
140 self.mark_message_as_seen(message.clone(), current_timestamp).await?;
141 deduplicated_messages.push((message, party_id));
142 }
143 }
144
145 Ok(deduplicated_messages)
146 }
147}
148
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::test::{
        double::DmqConsumerFake, double::FakeUnixTimestampProvider, payload::DmqMessageTestPayload,
    };

    use super::*;

    /// Timestamp reported by the fake timestamp provider in every test.
    const NOW: u64 = 1000;

    /// TTL long enough that nothing expires during a test.
    const LONG_TTL: Duration = Duration::from_secs(600);

    /// TTL short enough that entries re-stamped at `STALE_TIMESTAMP` are expired at `NOW`.
    const SHORT_TTL: Duration = Duration::from_secs(100);

    /// Timestamp older than `NOW - SHORT_TTL`.
    const STALE_TIMESTAMP: u64 = 800;

    /// Builds a deduplicator over a fake inner client yielding `inner_results` in order.
    fn create_deduplicator(
        inner_results: Vec<StdResult<Vec<(DmqMessageTestPayload, PartyId)>>>,
        current_timestamp: u64,
        ttl: Duration,
    ) -> DmqConsumerClientDeduplicator<DmqMessageTestPayload> {
        let inner = Arc::new(DmqConsumerFake::new(inner_results));
        let timestamp_provider = Arc::new(FakeUnixTimestampProvider::new(current_timestamp));

        DmqConsumerClientDeduplicator::new(inner, timestamp_provider, ttl)
    }

    /// Builds a party id from a string literal.
    fn party(name: &str) -> PartyId {
        name.to_string()
    }

    #[tokio::test]
    async fn returns_not_already_seen_messages_from_inner_client() {
        let expected_messages = vec![
            (DmqMessageTestPayload::new(b"message-1"), party("party-1")),
            (DmqMessageTestPayload::new(b"message-2"), party("party-2")),
        ];
        let deduplicator =
            create_deduplicator(vec![Ok(expected_messages.clone())], NOW, LONG_TTL);

        let messages = deduplicator.consume_messages().await.unwrap();

        assert_eq!(expected_messages, messages);
    }

    #[tokio::test]
    async fn returns_nothing_when_inner_returns_nothing() {
        let deduplicator = create_deduplicator(vec![Ok(vec![])], NOW, LONG_TTL);

        let messages = deduplicator.consume_messages().await.unwrap();

        assert!(messages.is_empty());
    }

    #[tokio::test]
    async fn returns_error_from_failing_inner_client() {
        let failing_results = vec![Err(anyhow::anyhow!("Inner client error"))];
        let deduplicator = create_deduplicator(failing_results, NOW, LONG_TTL);

        let result = deduplicator.consume_messages().await;

        result.expect_err("Should return an error");
    }

    #[tokio::test]
    async fn filters_out_already_seen_messages_in_same_call() {
        let duplicate_message = DmqMessageTestPayload::new(b"duplicate");
        let unique_message = DmqMessageTestPayload::new(b"unique");
        let single_batch = vec![
            (duplicate_message.clone(), party("party-1")),
            (duplicate_message.clone(), party("party-2")),
            (unique_message.clone(), party("party-3")),
        ];
        let deduplicator = create_deduplicator(vec![Ok(single_batch)], NOW, LONG_TTL);

        let messages = deduplicator.consume_messages().await.unwrap();

        // Only the first occurrence of the duplicated message survives.
        let expected_messages = vec![
            (duplicate_message, party("party-1")),
            (unique_message, party("party-3")),
        ];
        assert_eq!(expected_messages, messages);
    }

    #[tokio::test]
    async fn filters_out_already_seen_messages_across_calls() {
        let message_1 = DmqMessageTestPayload::new(b"message-1");
        let message_2 = DmqMessageTestPayload::new(b"message-2");
        let first_batch = vec![(message_1.clone(), party("party-1"))];
        let second_batch = vec![
            (message_1.clone(), party("party-1-duplicate")),
            (message_2.clone(), party("party-2")),
        ];
        let deduplicator =
            create_deduplicator(vec![Ok(first_batch), Ok(second_batch)], NOW, LONG_TTL);

        let batch_1 = deduplicator.consume_messages().await.unwrap();
        let batch_2 = deduplicator.consume_messages().await.unwrap();

        assert_eq!(vec![(message_1, party("party-1"))], batch_1);
        assert_eq!(vec![(message_2, party("party-2"))], batch_2);
    }

    #[tokio::test]
    async fn cleans_up_expired_entries_on_consume() {
        let message_1 = DmqMessageTestPayload::new(b"message-1");
        let message_2 = DmqMessageTestPayload::new(b"message-2");
        let inner_results = vec![
            Ok(vec![(message_1.clone(), party("party-1"))]),
            Ok(vec![(message_2.clone(), party("party-2"))]),
        ];
        let deduplicator = create_deduplicator(inner_results, NOW, SHORT_TTL);

        deduplicator.consume_messages().await.unwrap();
        assert_eq!(1, deduplicator.seen_messages_count().await);

        // Age the first entry past the TTL, then consume again to trigger the cleanup.
        deduplicator
            .set_seen_messages_expiration_timestamp(STALE_TIMESTAMP)
            .await;
        deduplicator.consume_messages().await.unwrap();

        assert_eq!(1, deduplicator.seen_messages_count().await);
        assert!(deduplicator.has_seen_message(&message_2).await.unwrap());
        assert!(!deduplicator.has_seen_message(&message_1).await.unwrap());
    }

    #[tokio::test]
    async fn cleans_up_expired_entries_when_inner_client_fails() {
        let message = DmqMessageTestPayload::new(b"message");
        let inner_results = vec![
            Ok(vec![(message.clone(), party("party-1"))]),
            Err(anyhow::anyhow!("Inner client error")),
        ];
        let deduplicator = create_deduplicator(inner_results, NOW, SHORT_TTL);

        deduplicator.consume_messages().await.unwrap();
        assert_eq!(1, deduplicator.seen_messages_count().await);

        // Age the entry past the TTL; the failing second call must still purge it.
        deduplicator
            .set_seen_messages_expiration_timestamp(STALE_TIMESTAMP)
            .await;
        deduplicator
            .consume_messages()
            .await
            .expect_err("Should return an error");

        assert_eq!(0, deduplicator.seen_messages_count().await);
    }
}