mithril_dmq/consumer/client/
deduplicator.rs

1use std::{collections::HashMap, fmt::Debug, sync::Arc, time::Duration};
2
3use blake2::{Blake2b, Digest, digest::consts::U64};
4use tokio::sync::Mutex;
5
6use mithril_common::{
7    StdResult,
8    crypto_helper::{TryFromBytes, TryToBytes},
9    entities::PartyId,
10};
11
12use crate::{DmqConsumerClient, model::UnixTimestampProvider};
13
/// Default lifetime of an entry in the deduplicator's seen-messages cache: 30 minutes.
pub const DMQ_MESSAGE_DEDUPLICATOR_TTL: Duration = Duration::from_secs(30 * 60);

/// Cache key identifying a message in the deduplicator (a digest of the message bytes).
type MessageKey = Vec<u8>;
19
20/// A DMQ consumer client that filters out duplicate messages.
21///
22/// This implementation wraps an inner [`DmqConsumerClient`] and maintains a cache of recently seen messages.
23/// When a message is consumed, if it has already been seen, the message is skipped. Otherwise, the message is passed through.
24/// Expired entries are lazily cleaned up during call to `consume_messages`.
25pub struct DmqConsumerClientDeduplicator<
26    M: TryFromBytes + TryToBytes + Debug + Send + Sync + Clone + Eq,
27> {
28    inner: Arc<dyn DmqConsumerClient<M>>,
29    timestamp_provider: Arc<dyn UnixTimestampProvider>,
30    seen_messages: Mutex<HashMap<MessageKey, u64>>,
31    ttl: Duration,
32}
33
34impl<M: TryFromBytes + TryToBytes + Debug + Send + Sync + Clone + Eq>
35    DmqConsumerClientDeduplicator<M>
36{
37    /// Creates a new `DmqConsumerClientDeduplicator` wrapping the given inner client.
38    ///
39    /// The `ttl` parameter specifies how long a message is kept in the cache.
40    pub fn new(
41        inner: Arc<dyn DmqConsumerClient<M>>,
42        timestamp_provider: Arc<dyn UnixTimestampProvider>,
43        ttl: Duration,
44    ) -> Self {
45        Self {
46            inner,
47            timestamp_provider,
48            seen_messages: Mutex::new(HashMap::new()),
49            ttl,
50        }
51    }
52
53    /// Creates a new `DmqConsumerClientDeduplicator` with the default TTL.
54    pub fn new_with_default_ttl(
55        inner: Arc<dyn DmqConsumerClient<M>>,
56        timestamp_provider: Arc<dyn UnixTimestampProvider>,
57    ) -> Self {
58        Self::new(inner, timestamp_provider, DMQ_MESSAGE_DEDUPLICATOR_TTL)
59    }
60
61    /// Computes a key for the message to be used in the seen messages cache.
62    fn try_compute_message_key(&self, message: &M) -> StdResult<MessageKey> {
63        let mut hasher = Blake2b::<U64>::new();
64        hasher.update(&message.to_bytes_vec()?);
65
66        Ok(hasher.finalize().to_vec())
67    }
68
69    /// Gets the current timestamp from the timestamp provider.
70    fn current_timestamp(&self) -> StdResult<u64> {
71        self.timestamp_provider.current_timestamp()
72    }
73
74    /// Checks if a message timestamp is expired based on the current timestamp and TTL.
75    fn is_expired_timestamp(&self, timestamp: u64, current_timestamp: u64) -> bool {
76        current_timestamp.saturating_sub(timestamp) > self.ttl.as_secs()
77    }
78
79    /// Removes expired messages from the seen messages cache.
80    async fn remove_expired_messages(&self, current_timestamp: u64) {
81        let mut seen_messages = self.seen_messages.lock().await;
82        seen_messages
83            .retain(|_, timestamp| !self.is_expired_timestamp(*timestamp, current_timestamp));
84    }
85
86    /// Checks if a message has already been seen.
87    async fn has_message_been_seen(&self, message: &M) -> StdResult<bool> {
88        let seen_messages = self.seen_messages.lock().await;
89
90        Ok(seen_messages.contains_key(&self.try_compute_message_key(message)?))
91    }
92
93    /// Marks a message as seen with the given timestamp.
94    async fn mark_message_as_seen(&self, message: M, timestamp: u64) -> StdResult<()> {
95        let mut seen_messages = self.seen_messages.lock().await;
96        seen_messages.insert(self.try_compute_message_key(&message)?, timestamp);
97
98        Ok(())
99    }
100
101    #[cfg(test)]
102    /// Sets the expiration timestamp for all seen messages (for testing purposes).
103    async fn set_seen_messages_expiration_timestamp(&self, timestamp: u64) {
104        let mut seen_messages = self.seen_messages.lock().await;
105        for seen_timestamp in seen_messages.values_mut() {
106            *seen_timestamp = timestamp;
107        }
108    }
109
110    #[cfg(test)]
111    /// Returns the count of seen messages (for testing purposes).
112    async fn seen_messages_count(&self) -> usize {
113        self.seen_messages.lock().await.len()
114    }
115
116    #[cfg(test)]
117    /// Checks if a message has been seen (for testing purposes).
118    async fn has_seen_message(&self, message: &M) -> StdResult<bool> {
119        Ok(self
120            .seen_messages
121            .lock()
122            .await
123            .contains_key(&self.try_compute_message_key(message)?))
124    }
125}
126
127#[async_trait::async_trait]
128impl<M: TryFromBytes + TryToBytes + Debug + Send + Sync + Clone + Eq + Clone> DmqConsumerClient<M>
129    for DmqConsumerClientDeduplicator<M>
130{
131    async fn consume_messages(&self) -> StdResult<Vec<(M, PartyId)>> {
132        let messages = self.inner.consume_messages().await;
133        let current_timestamp = self.current_timestamp()?;
134        self.remove_expired_messages(current_timestamp).await;
135        let messages = messages?;
136
137        let mut deduplicated_messages = Vec::new();
138        for (message, party_id) in messages {
139            if !self.has_message_been_seen(&message).await? {
140                self.mark_message_as_seen(message.clone(), current_timestamp).await?;
141                deduplicated_messages.push((message, party_id));
142            }
143        }
144
145        Ok(deduplicated_messages)
146    }
147}
148
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::test::{
        double::DmqConsumerFake, double::FakeUnixTimestampProvider, payload::DmqMessageTestPayload,
    };

    use super::*;

    /// Builds a deduplicator over a fake inner client returning `inner_results` in order, with a
    /// fake clock fixed at `current_timestamp`.
    fn create_deduplicator(
        inner_results: Vec<StdResult<Vec<(DmqMessageTestPayload, PartyId)>>>,
        current_timestamp: u64,
        ttl: Duration,
    ) -> DmqConsumerClientDeduplicator<DmqMessageTestPayload> {
        DmqConsumerClientDeduplicator::new(
            Arc::new(DmqConsumerFake::new(inner_results)),
            Arc::new(FakeUnixTimestampProvider::new(current_timestamp)),
            ttl,
        )
    }

    #[tokio::test]
    async fn returns_not_already_seen_messages_from_inner_client() {
        let expected_messages = vec![
            (
                DmqMessageTestPayload::new(b"message-1"),
                "party-1".to_string(),
            ),
            (
                DmqMessageTestPayload::new(b"message-2"),
                "party-2".to_string(),
            ),
        ];
        let deduplicator = create_deduplicator(
            vec![Ok(expected_messages.clone())],
            1000,
            Duration::from_secs(600),
        );

        let messages = deduplicator.consume_messages().await.unwrap();

        assert_eq!(expected_messages, messages);
    }

    #[tokio::test]
    async fn returns_nothing_when_inner_returns_nothing() {
        let deduplicator = create_deduplicator(vec![Ok(vec![])], 1000, Duration::from_secs(600));

        let messages = deduplicator.consume_messages().await.unwrap();

        assert!(messages.is_empty());
    }

    #[tokio::test]
    async fn returns_error_from_failing_inner_client() {
        let deduplicator = create_deduplicator(
            vec![Err(anyhow::anyhow!("Inner client error"))],
            1000,
            Duration::from_secs(600),
        );

        let result = deduplicator.consume_messages().await;

        result.expect_err("Should return an error");
    }

    #[tokio::test]
    async fn filters_out_already_seen_messages_in_same_call() {
        let duplicate_message = DmqMessageTestPayload::new(b"duplicate");
        let unique_message = DmqMessageTestPayload::new(b"unique");
        let inner_results = vec![Ok(vec![
            (duplicate_message.clone(), "party-1".to_string()),
            (duplicate_message.clone(), "party-2".to_string()),
            (unique_message.clone(), "party-3".to_string()),
        ])];
        let deduplicator = create_deduplicator(inner_results, 1000, Duration::from_secs(600));

        let messages = deduplicator.consume_messages().await.unwrap();

        assert_eq!(
            vec![
                (duplicate_message, "party-1".to_string()),
                (unique_message, "party-3".to_string()),
            ],
            messages
        );
    }

    #[tokio::test]
    async fn filters_out_already_seen_messages_across_calls() {
        let message_1 = DmqMessageTestPayload::new(b"message-1");
        let message_2 = DmqMessageTestPayload::new(b"message-2");
        let inner_results = vec![
            Ok(vec![(message_1.clone(), "party-1".to_string())]),
            Ok(vec![
                (message_1.clone(), "party-1-duplicate".to_string()),
                (message_2.clone(), "party-2".to_string()),
            ]),
        ];
        let deduplicator = create_deduplicator(inner_results, 1000, Duration::from_secs(600));

        let batch_1 = deduplicator.consume_messages().await.unwrap();
        let batch_2 = deduplicator.consume_messages().await.unwrap();

        assert_eq!(vec![(message_1, "party-1".to_string())], batch_1);
        assert_eq!(vec![(message_2, "party-2".to_string())], batch_2);
    }

    #[tokio::test]
    async fn cleans_up_expired_entries_on_consume() {
        let message_1 = DmqMessageTestPayload::new(b"message-1");
        let message_2 = DmqMessageTestPayload::new(b"message-2");
        let deduplicator = create_deduplicator(
            vec![
                Ok(vec![(message_1.clone(), "party-1".to_string())]),
                Ok(vec![(message_2.clone(), "party-2".to_string())]),
            ],
            1000,
            Duration::from_secs(100),
        );

        deduplicator.consume_messages().await.unwrap();
        assert_eq!(1, deduplicator.seen_messages_count().await);

        deduplicator.set_seen_messages_expiration_timestamp(800).await;
        deduplicator.consume_messages().await.unwrap();

        assert_eq!(1, deduplicator.seen_messages_count().await);
        assert!(deduplicator.has_seen_message(&message_2).await.unwrap());
        assert!(!deduplicator.has_seen_message(&message_1).await.unwrap());
    }

    #[tokio::test]
    async fn cleans_up_expired_entries_when_inner_client_fails() {
        let message = DmqMessageTestPayload::new(b"message");
        let deduplicator = create_deduplicator(
            vec![
                Ok(vec![(message.clone(), "party-1".to_string())]),
                Err(anyhow::anyhow!("Inner client error")),
            ],
            1000,
            Duration::from_secs(100),
        );

        deduplicator.consume_messages().await.unwrap();
        assert_eq!(1, deduplicator.seen_messages_count().await);

        deduplicator.set_seen_messages_expiration_timestamp(800).await;
        deduplicator
            .consume_messages()
            .await
            .expect_err("Should return an error");

        assert_eq!(0, deduplicator.seen_messages_count().await);
    }

    #[tokio::test]
    async fn keeps_entries_whose_age_equals_ttl() {
        let message = DmqMessageTestPayload::new(b"message");
        let deduplicator = create_deduplicator(
            vec![
                Ok(vec![(message.clone(), "party-1".to_string())]),
                Ok(vec![(message.clone(), "party-1".to_string())]),
            ],
            1000,
            Duration::from_secs(100),
        );

        deduplicator.consume_messages().await.unwrap();
        // Age is exactly the TTL (1000 - 900 == 100): the entry must not be expired yet.
        deduplicator.set_seen_messages_expiration_timestamp(900).await;
        let messages = deduplicator.consume_messages().await.unwrap();

        assert!(messages.is_empty());
        assert_eq!(1, deduplicator.seen_messages_count().await);
    }

    #[tokio::test]
    async fn returns_again_a_message_whose_entry_has_expired() {
        let message = DmqMessageTestPayload::new(b"message");
        let deduplicator = create_deduplicator(
            vec![
                Ok(vec![(message.clone(), "party-1".to_string())]),
                Ok(vec![(message.clone(), "party-1".to_string())]),
            ],
            1000,
            Duration::from_secs(100),
        );

        deduplicator.consume_messages().await.unwrap();
        deduplicator.set_seen_messages_expiration_timestamp(800).await;
        let messages = deduplicator.consume_messages().await.unwrap();

        assert_eq!(vec![(message.clone(), "party-1".to_string())], messages);
        assert!(deduplicator.has_seen_message(&message).await.unwrap());
    }
}