mithril_dmq/consumer/pallas.rs

1use std::{fmt::Debug, marker::PhantomData, path::PathBuf};
2
3use anyhow::{Context, anyhow};
4use pallas_network::facades::DmqClient;
5use slog::{Logger, debug, error};
6use tokio::sync::{Mutex, MutexGuard};
7
8use mithril_common::{
9    CardanoNetwork, StdResult,
10    crypto_helper::{OpCert, TryFromBytes},
11    entities::PartyId,
12    logging::LoggerExtensions,
13};
14
15use crate::DmqConsumer;
16
17/// A DMQ consumer implementation.
18///
19/// This implementation is built upon the n2c mini-protocols DMQ implementation in Pallas.
20pub struct DmqConsumerPallas<M: TryFromBytes + Debug> {
21    socket: PathBuf,
22    network: CardanoNetwork,
23    client: Mutex<Option<DmqClient>>,
24    logger: Logger,
25    phantom: PhantomData<M>,
26}
27
28impl<M: TryFromBytes + Debug> DmqConsumerPallas<M> {
29    /// Creates a new `DmqConsumerPallas` instance.
30    pub fn new(socket: PathBuf, network: CardanoNetwork, logger: Logger) -> Self {
31        Self {
32            socket,
33            network,
34            client: Mutex::new(None),
35            logger: logger.new_with_component_name::<Self>(),
36            phantom: PhantomData,
37        }
38    }
39
40    /// Creates and returns a new `DmqClient` connected to the specified socket.
41    async fn new_client(&self) -> StdResult<DmqClient> {
42        debug!(
43            self.logger,
44            "Create new DMQ client";
45            "socket" => ?self.socket,
46            "network" => ?self.network
47        );
48        DmqClient::connect(&self.socket, self.network.code())
49            .await
50            .with_context(|| "DmqConsumerPallas failed to create a new client")
51    }
52
53    /// Gets the cached `DmqClient`, creating a new one if it does not exist.
54    async fn get_client(&self) -> StdResult<MutexGuard<Option<DmqClient>>> {
55        {
56            // Run this in a separate block to avoid dead lock on the Mutex
57            let client_lock = self.client.lock().await;
58            if client_lock.as_ref().is_some() {
59                return Ok(client_lock);
60            }
61        }
62
63        let mut client_lock = self.client.lock().await;
64        *client_lock = Some(self.new_client().await?);
65
66        Ok(client_lock)
67    }
68
69    /// Drops the current `DmqClient`, if it exists.
70    async fn drop_client(&self) -> StdResult<()> {
71        debug!(
72            self.logger,
73            "Drop existing DMQ client";
74            "socket" => ?self.socket,
75            "network" => ?self.network
76        );
77        let mut client_lock = self.client.lock().await;
78        if let Some(client) = client_lock.take() {
79            client.abort().await;
80        }
81
82        Ok(())
83    }
84
85    #[cfg(test)]
86    /// Check if the client already exists (test only).
87    async fn has_client(&self) -> bool {
88        let client_lock = self.client.lock().await;
89
90        client_lock.as_ref().is_some()
91    }
92
93    async fn consume_messages_internal(&self) -> StdResult<Vec<(M, PartyId)>> {
94        debug!(self.logger, "Waiting for messages from DMQ...");
95        let mut client_guard = self.get_client().await?;
96        let client = client_guard.as_mut().ok_or(anyhow!("DMQ client does not exist"))?;
97        client
98            .msg_notification()
99            .send_request_messages_blocking()
100            .await
101            .with_context(|| "Failed to request notifications from DMQ server: {}")?;
102
103        let reply = client
104            .msg_notification()
105            .recv_next_reply()
106            .await
107            .with_context(|| "Failed to receive notifications from DMQ server")?;
108        debug!(self.logger, "Received single signatures from DMQ"; "messages" => ?reply);
109        if let Err(e) = client.msg_notification().send_done().await {
110            error!(self.logger, "Failed to send Done"; "error" => ?e);
111        }
112
113        reply
114            .0
115            .into_iter()
116            .map(|dmq_message| {
117                let opcert = OpCert::try_from_bytes(&dmq_message.operational_certificate)
118                    .with_context(|| "Failed to parse operational certificate")?;
119                let party_id = opcert.compute_protocol_party_id()?;
120                let payload = M::try_from_bytes(&dmq_message.msg_body)
121                    .with_context(|| "Failed to parse DMQ message body")?;
122
123                Ok((payload, party_id))
124            })
125            .collect::<StdResult<Vec<_>>>()
126            .with_context(|| "Failed to parse DMQ messages")
127    }
128}
129
130#[async_trait::async_trait]
131impl<M: TryFromBytes + Debug + Sync + Send> DmqConsumer<M> for DmqConsumerPallas<M> {
132    async fn consume_messages(&self) -> StdResult<Vec<(M, PartyId)>> {
133        let messages = self.consume_messages_internal().await;
134        if messages.is_err() {
135            self.drop_client().await?;
136        }
137
138        messages
139    }
140}
141
#[cfg(all(test, unix))]
mod tests {

    use std::{fs, future, time::Duration, vec};

    use mithril_common::{crypto_helper::TryToBytes, current_function, test_utils::TempDir};
    use pallas_network::{
        facades::DmqServer,
        miniprotocols::{localmsgnotification, localmsgsubmission::DmqMsg},
    };
    use tokio::{net::UnixListener, task::JoinHandle, time::sleep};

    use crate::{test::payload::DmqMessageTestPayload, test_tools::TestLogger};

    use super::*;

    /// Creates a short-path temporary directory dedicated to one test.
    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("dmq_consumer", folder_name)
    }

    /// Two DMQ messages whose operational certificates map to known pool ids
    /// (see `pallas_dmq_consumer_succeeds_when_messages_are_available`).
    fn fake_msgs() -> Vec<DmqMsg> {
        vec![
            DmqMsg {
                msg_id: vec![0, 1],
                msg_body: DmqMessageTestPayload::new(b"msg_1").to_bytes_vec().unwrap(),
                block_number: 10,
                ttl: 100,
                kes_signature: vec![0, 1, 2, 3],
                operational_certificate: vec![
                    130, 132, 88, 32, 230, 80, 215, 83, 21, 9, 187, 108, 255, 215, 153, 140, 40,
                    198, 142, 78, 200, 250, 98, 26, 9, 82, 32, 110, 161, 30, 176, 63, 205, 125,
                    203, 41, 0, 0, 88, 64, 212, 171, 206, 39, 218, 5, 255, 3, 193, 52, 44, 198,
                    171, 83, 19, 80, 114, 225, 186, 191, 156, 192, 84, 146, 245, 159, 31, 240, 9,
                    247, 4, 87, 170, 168, 98, 199, 21, 139, 19, 190, 12, 251, 65, 215, 169, 26, 86,
                    37, 137, 188, 17, 14, 178, 205, 175, 93, 39, 86, 4, 138, 187, 234, 95, 5, 88,
                    32, 32, 253, 186, 201, 177, 11, 117, 135, 187, 167, 181, 188, 22, 59, 206, 105,
                    231, 150, 215, 30, 78, 212, 76, 16, 252, 180, 72, 134, 137, 247, 161, 68,
                ],
                kes_period: 10,
            },
            DmqMsg {
                msg_id: vec![1, 2],
                msg_body: DmqMessageTestPayload::new(b"msg_2").to_bytes_vec().unwrap(),
                block_number: 11,
                ttl: 100,
                kes_signature: vec![1, 2, 3, 4],
                operational_certificate: vec![
                    130, 132, 88, 32, 230, 80, 215, 83, 21, 9, 187, 108, 255, 215, 153, 140, 40,
                    198, 142, 78, 200, 250, 98, 26, 9, 82, 32, 110, 161, 30, 176, 63, 205, 125,
                    203, 41, 0, 0, 88, 64, 132, 4, 199, 39, 190, 173, 88, 102, 121, 117, 55, 62,
                    39, 189, 113, 96, 175, 24, 171, 240, 74, 42, 139, 202, 128, 185, 44, 130, 209,
                    77, 191, 122, 196, 224, 33, 158, 187, 156, 203, 190, 173, 150, 247, 87, 172,
                    58, 153, 185, 157, 87, 128, 14, 187, 107, 187, 215, 105, 195, 107, 135, 172,
                    43, 173, 9, 88, 32, 77, 75, 24, 6, 47, 133, 2, 89, 141, 224, 69, 202, 123, 105,
                    240, 103, 245, 159, 147, 177, 110, 58, 248, 115, 58, 152, 138, 220, 35, 65,
                    245, 200,
                ],
                kes_period: 11,
            },
        ]
    }

    /// Spawns a one-shot DMQ server listening on `socket_path`. The server:
    /// - accepts a single client;
    /// - expects one blocking request from it;
    /// - if `reply_messages` is non-empty, replies with them and waits for the client's Done;
    /// - if `reply_messages` is empty, never answers.
    ///
    /// The task resolves to the server so that the caller can abort it.
    fn setup_dmq_server(
        socket_path: PathBuf,
        reply_messages: Vec<DmqMsg>,
    ) -> JoinHandle<DmqServer> {
        tokio::spawn({
            async move {
                // server setup: remove any stale socket left by a previous run
                if socket_path.exists() {
                    fs::remove_file(&socket_path).unwrap();
                }
                let listener = UnixListener::bind(socket_path).unwrap();
                let mut server = DmqServer::accept(&listener, 0).await.unwrap();

                // init local msg notification server
                let server_msg = server.msg_notification();

                // server waits for blocking request from client
                let request = server_msg.recv_next_request().await.unwrap();
                assert_eq!(request, localmsgnotification::Request::Blocking);

                if !reply_messages.is_empty() {
                    // server replies with messages if any
                    server_msg.send_reply_messages_blocking(reply_messages).await.unwrap();
                    assert_eq!(*server_msg.state(), localmsgnotification::State::Idle);

                    // server receives done from client
                    server_msg.recv_done().await.unwrap();
                    assert_eq!(*server_msg.state(), localmsgnotification::State::Done);
                } else {
                    // server waits if no message available
                    future::pending().await
                }

                server
            }
        })
    }

    #[tokio::test]
    async fn pallas_dmq_consumer_succeeds_when_messages_are_available() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = fake_msgs();
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            let consumer = DmqConsumerPallas::new(
                socket_path,
                CardanoNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            consumer.consume_messages().await.unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let messages = client_res.unwrap();

        assert_eq!(
            vec![
                (
                    DmqMessageTestPayload::new(b"msg_1"),
                    "pool1mxyec46067n3querj9cxkk0g0zlag93pf3ya9vuyr3wgkq2e6t7".to_string()
                ),
                (
                    DmqMessageTestPayload::new(b"msg_2"),
                    "pool17sln0evyk5tfj6zh2qrlk9vttgy6264sfe2fkec5mheasnlx3yd".to_string()
                ),
            ],
            messages
        );
    }

    #[tokio::test]
    async fn pallas_dmq_consumer_blocks_when_no_message_available() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = vec![];
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            let consumer = DmqConsumerPallas::<DmqMessageTestPayload>::new(
                socket_path,
                CardanoNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            consumer.consume_messages().await.unwrap();
        });

        // Neither side may finish: the consumer must still be waiting when the timeout fires.
        let result = tokio::select!(
            _res = sleep(Duration::from_millis(100)) => {Err(anyhow!("Timeout"))},
            _res = client => {Ok(())},
            _res = server => {Ok(())},
        );

        result.expect_err("Should have timed out");
    }

    #[tokio::test]
    async fn pallas_dmq_consumer_client_is_dropped_when_returning_error() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = fake_msgs();
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            let consumer = DmqConsumerPallas::<DmqMessageTestPayload>::new(
                socket_path,
                CardanoNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            consumer.consume_messages().await.unwrap();

            consumer
        });

        let (server_res, client_res) = tokio::join!(server, client);
        let consumer = client_res.unwrap();
        let server = server_res.unwrap();
        // Shut the server down so that the next consumption fails.
        server.abort().await;

        let client = tokio::spawn(async move {
            assert!(consumer.has_client().await, "Client should exist");

            consumer
                .consume_messages()
                .await
                .expect_err("Consuming messages should fail");

            assert!(
                !consumer.has_client().await,
                "Client should have been dropped after error"
            );

            consumer
        });
        client.await.unwrap();
    }
}