mithril_dmq/consumer/client/
pallas.rs

1use std::{fmt::Debug, marker::PhantomData, path::PathBuf};
2
3use anyhow::{Context, anyhow};
4use pallas_network::{facades::DmqClient, miniprotocols::localmsgnotification::State};
5use slog::{Logger, debug, error};
6use tokio::sync::{Mutex, MutexGuard};
7
8use mithril_common::{
9    CardanoNetwork, StdResult,
10    crypto_helper::{OpCert, TryFromBytes},
11    entities::PartyId,
12    logging::LoggerExtensions,
13};
14
15use crate::DmqConsumerClient;
16
17/// A DMQ client consumer implementation.
18///
19/// This implementation is built upon the n2c mini-protocols DMQ implementation in Pallas.
20pub struct DmqConsumerClientPallas<M: TryFromBytes + Debug> {
21    socket: PathBuf,
22    network: CardanoNetwork,
23    client: Mutex<Option<DmqClient>>,
24    logger: Logger,
25    phantom: PhantomData<M>,
26}
27
28impl<M: TryFromBytes + Debug> DmqConsumerClientPallas<M> {
29    /// Creates a new `DmqConsumerClientPallas` instance.
30    pub fn new(socket: PathBuf, network: CardanoNetwork, logger: Logger) -> Self {
31        Self {
32            socket,
33            network,
34            client: Mutex::new(None),
35            logger: logger.new_with_component_name::<Self>(),
36            phantom: PhantomData,
37        }
38    }
39
40    /// Creates and returns a new `DmqClient` connected to the specified socket.
41    async fn new_client(&self) -> StdResult<DmqClient> {
42        debug!(
43            self.logger,
44            "Create new DMQ client";
45            "socket" => ?self.socket,
46            "network" => ?self.network
47        );
48        DmqClient::connect(&self.socket, self.network.magic_id())
49            .await
50            .with_context(|| "DmqConsumerClientPallas failed to create a new client")
51    }
52
53    /// Gets the cached `DmqClient`, creating a new one if it does not exist.
54    async fn get_client(&self) -> StdResult<MutexGuard<'_, Option<DmqClient>>> {
55        {
56            // Run this in a separate block to avoid dead lock on the Mutex
57            let client_lock = self.client.lock().await;
58            if client_lock.as_ref().is_some() {
59                return Ok(client_lock);
60            }
61        }
62
63        let mut client_lock = self.client.lock().await;
64        *client_lock = Some(self.new_client().await?);
65
66        Ok(client_lock)
67    }
68
69    /// Drops the current `DmqClient`, if it exists.
70    async fn drop_client(&self) -> StdResult<()> {
71        debug!(
72            self.logger,
73            "Drop existing DMQ client";
74            "socket" => ?self.socket,
75            "network" => ?self.network
76        );
77        let mut client_lock = self.client.try_lock()?;
78        if let Some(client) = client_lock.take() {
79            client.abort().await;
80        }
81
82        Ok(())
83    }
84
85    #[cfg(test)]
86    /// Check if the client already exists (test only).
87    async fn has_client(&self) -> bool {
88        let client_lock = self.client.lock().await;
89
90        client_lock.as_ref().is_some()
91    }
92
93    async fn consume_messages_internal(&self) -> StdResult<Vec<(M, PartyId)>> {
94        debug!(self.logger, "Waiting for messages from DMQ...");
95        let mut client_guard = self.get_client().await?;
96        let client = client_guard.as_mut().ok_or(anyhow!("DMQ client does not exist"))?;
97        if *client.msg_notification().state() == State::Idle {
98            client
99                .msg_notification()
100                .send_request_messages_blocking()
101                .await
102                .with_context(|| "Failed to request notifications from DMQ server: {}")?;
103        }
104
105        let reply = client
106            .msg_notification()
107            .recv_next_reply()
108            .await
109            .with_context(|| "Failed to receive notifications from DMQ server")?;
110        debug!(self.logger, "Received single signatures from DMQ"; "messages" => ?reply);
111
112        reply
113            .0
114            .into_iter()
115            .map(|dmq_message| {
116                let opcert = OpCert::try_from_bytes(&dmq_message.operational_certificate)
117                    .with_context(|| "Failed to parse operational certificate")?;
118                let party_id = opcert.compute_protocol_party_id()?;
119                let payload = M::try_from_bytes(&dmq_message.msg_body)
120                    .with_context(|| "Failed to parse DMQ message body")?;
121
122                Ok((payload, party_id))
123            })
124            .collect::<StdResult<Vec<_>>>()
125            .with_context(|| "Failed to parse DMQ messages")
126    }
127}
128
129#[async_trait::async_trait]
130impl<M: TryFromBytes + Debug + Sync + Send> DmqConsumerClient<M> for DmqConsumerClientPallas<M> {
131    async fn consume_messages(&self) -> StdResult<Vec<(M, PartyId)>> {
132        let messages = self.consume_messages_internal().await;
133        if messages.is_err() {
134            self.drop_client().await?;
135        }
136
137        messages
138    }
139}
140
141impl<M: TryFromBytes + Debug> Drop for DmqConsumerClientPallas<M> {
142    fn drop(&mut self) {
143        tokio::task::block_in_place(|| {
144            tokio::runtime::Handle::current().block_on(async {
145                if let Err(e) = self.drop_client().await {
146                    error!(self.logger, "Failed to drop DMQ consumer client: {}", e);
147                }
148            });
149        });
150    }
151}
152
#[cfg(all(test, unix))]
mod tests {

    use std::{fs, future, time::Duration, vec};

    use mithril_common::{crypto_helper::TryToBytes, current_function, test::TempDir};
    use pallas_network::{
        facades::DmqServer,
        miniprotocols::{localmsgnotification, localmsgsubmission::DmqMsg},
    };
    use tokio::{net::UnixListener, task::JoinHandle, time::sleep};

    use crate::test::{TestLogger, payload::DmqMessageTestPayload};

    use super::*;

    /// Creates a temporary directory dedicated to the test named `folder_name`.
    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("dmq_consumer", folder_name)
    }

    /// Builds a consumer for the test payload type, targeting `socket_path` on test network 0.
    fn test_consumer(socket_path: PathBuf) -> DmqConsumerClientPallas<DmqMessageTestPayload> {
        DmqConsumerClientPallas::new(socket_path, CardanoNetwork::TestNet(0), TestLogger::stdout())
    }

    /// Returns the two DMQ messages served by the fake server in the tests below.
    fn fake_msgs() -> Vec<DmqMsg> {
        let first_message = DmqMsg {
            msg_id: vec![0, 1],
            msg_body: DmqMessageTestPayload::new(b"msg_1").to_bytes_vec().unwrap(),
            block_number: 10,
            ttl: 100,
            kes_signature: vec![0, 1, 2, 3],
            operational_certificate: vec![
                130, 132, 88, 32, 230, 80, 215, 83, 21, 9, 187, 108, 255, 215, 153, 140, 40,
                198, 142, 78, 200, 250, 98, 26, 9, 82, 32, 110, 161, 30, 176, 63, 205, 125,
                203, 41, 0, 0, 88, 64, 212, 171, 206, 39, 218, 5, 255, 3, 193, 52, 44, 198,
                171, 83, 19, 80, 114, 225, 186, 191, 156, 192, 84, 146, 245, 159, 31, 240, 9,
                247, 4, 87, 170, 168, 98, 199, 21, 139, 19, 190, 12, 251, 65, 215, 169, 26, 86,
                37, 137, 188, 17, 14, 178, 205, 175, 93, 39, 86, 4, 138, 187, 234, 95, 5, 88,
                32, 32, 253, 186, 201, 177, 11, 117, 135, 187, 167, 181, 188, 22, 59, 206, 105,
                231, 150, 215, 30, 78, 212, 76, 16, 252, 180, 72, 134, 137, 247, 161, 68,
            ],
            kes_period: 10,
        };
        let second_message = DmqMsg {
            msg_id: vec![1, 2],
            msg_body: DmqMessageTestPayload::new(b"msg_2").to_bytes_vec().unwrap(),
            block_number: 11,
            ttl: 100,
            kes_signature: vec![1, 2, 3, 4],
            operational_certificate: vec![
                130, 132, 88, 32, 230, 80, 215, 83, 21, 9, 187, 108, 255, 215, 153, 140, 40,
                198, 142, 78, 200, 250, 98, 26, 9, 82, 32, 110, 161, 30, 176, 63, 205, 125,
                203, 41, 0, 0, 88, 64, 132, 4, 199, 39, 190, 173, 88, 102, 121, 117, 55, 62,
                39, 189, 113, 96, 175, 24, 171, 240, 74, 42, 139, 202, 128, 185, 44, 130, 209,
                77, 191, 122, 196, 224, 33, 158, 187, 156, 203, 190, 173, 150, 247, 87, 172,
                58, 153, 185, 157, 87, 128, 14, 187, 107, 187, 215, 105, 195, 107, 135, 172,
                43, 173, 9, 88, 32, 77, 75, 24, 6, 47, 133, 2, 89, 141, 224, 69, 202, 123, 105,
                240, 103, 245, 159, 147, 177, 110, 58, 248, 115, 58, 152, 138, 220, 35, 65,
                245, 200,
            ],
            kes_period: 11,
        };

        vec![first_message, second_message]
    }

    /// Spawns a fake DMQ server listening on `socket_path`.
    ///
    /// The server accepts one connection and expects a blocking request. It then replies with
    /// `reply_messages`, or never replies when `reply_messages` is empty. The task resolves to
    /// the server so that the caller can keep using it (e.g. to abort it).
    fn setup_dmq_server(
        socket_path: PathBuf,
        reply_messages: Vec<DmqMsg>,
    ) -> JoinHandle<DmqServer> {
        tokio::spawn(async move {
            // A socket file left over at this path is removed before binding.
            if socket_path.exists() {
                fs::remove_file(&socket_path).unwrap();
            }
            let listener = UnixListener::bind(&socket_path).unwrap();
            let mut server = DmqServer::accept(&listener, 0).await.unwrap();

            // The local message notification side of the server.
            let notification = server.msg_notification();

            // The client is expected to open with a blocking request.
            let request = notification.recv_next_request().await.unwrap();
            assert_eq!(request, localmsgnotification::Request::Blocking);

            if reply_messages.is_empty() {
                // Nothing to send: the server waits forever.
                future::pending::<()>().await;
            } else {
                notification.send_reply_messages_blocking(reply_messages).await.unwrap();
                assert_eq!(*notification.state(), localmsgnotification::State::Idle);
            }

            server
        })
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_client_succeeds_when_messages_are_available() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let server = setup_dmq_server(socket_path.clone(), fake_msgs());
        let client = tokio::spawn(async move {
            // Give the server time to bind, otherwise the connection is refused.
            sleep(Duration::from_millis(10)).await;

            let consumer = test_consumer(socket_path);

            consumer.consume_messages().await.unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let messages = client_res.unwrap();

        let expected_messages = vec![
            (
                DmqMessageTestPayload::new(b"msg_1"),
                "pool1mxyec46067n3querj9cxkk0g0zlag93pf3ya9vuyr3wgkq2e6t7".to_string(),
            ),
            (
                DmqMessageTestPayload::new(b"msg_2"),
                "pool17sln0evyk5tfj6zh2qrlk9vttgy6264sfe2fkec5mheasnlx3yd".to_string(),
            ),
        ];
        assert_eq!(expected_messages, messages);
    }

    #[tokio::test]
    async fn pallas_dmq_consumer_client_blocks_when_no_message_available() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let server = setup_dmq_server(socket_path.clone(), Vec::new());
        let client = tokio::spawn(async move {
            // Give the server time to bind, otherwise the connection is refused.
            sleep(Duration::from_millis(10)).await;

            let consumer = test_consumer(socket_path);

            consumer.consume_messages().await.unwrap();
        });

        // Neither task is expected to finish: only the timeout branch yields an error.
        let result = tokio::select! {
            _ = sleep(Duration::from_millis(100)) => Err(anyhow!("Timeout")),
            _ = client => Ok(()),
            _ = server => Ok(()),
        };

        result.expect_err("Should have timed out");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_client_is_dropped_when_returning_error() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let server = setup_dmq_server(socket_path.clone(), fake_msgs());
        let first_consumption = tokio::spawn(async move {
            // Give the server time to bind, otherwise the connection is refused.
            sleep(Duration::from_millis(10)).await;

            let consumer = test_consumer(socket_path);

            consumer.consume_messages().await.unwrap();

            consumer
        });

        let (server_res, client_res) = tokio::join!(server, first_consumption);
        let consumer = client_res.unwrap();
        // Abort the server so that the next consumption fails.
        let server = server_res.unwrap();
        server.abort().await;

        let second_consumption = tokio::spawn(async move {
            assert!(consumer.has_client().await, "Client should exist");

            consumer
                .consume_messages()
                .await
                .expect_err("Consuming messages should fail");

            assert!(
                !consumer.has_client().await,
                "Client should have been dropped after error"
            );

            consumer
        });
        second_consumption.await.unwrap();
    }
}