mithril_dmq/consumer/client/pallas.rs

1use std::{fmt::Debug, marker::PhantomData, path::PathBuf};
2
3use anyhow::Context;
4use pallas_network::{facades::DmqClient, miniprotocols::localmsgnotification::State};
5use slog::{Logger, debug, error, trace};
6use tokio::sync::{Mutex, MutexGuard};
7
8use mithril_common::{
9    StdResult,
10    crypto_helper::{
11        KesPeriod, OpCert, OpCertWithoutColdVerificationKey, TryFromBytes,
12        ed25519::Ed25519VerificationKey,
13    },
14    entities::PartyId,
15    logging::LoggerExtensions,
16};
17
18use crate::{DmqConsumerClient, model::DmqNetwork};
19
20/// A DMQ client consumer implementation.
21///
22/// This implementation is built upon the n2c mini-protocols DMQ implementation in Pallas.
23pub struct DmqConsumerClientPallas<M: TryFromBytes + Debug> {
24    socket: PathBuf,
25    network: DmqNetwork,
26    client: Mutex<Option<DmqClient>>,
27    logger: Logger,
28    phantom: PhantomData<M>,
29}
30
31impl<M: TryFromBytes + Debug> DmqConsumerClientPallas<M> {
32    /// Creates a new `DmqConsumerClientPallas` instance.
33    pub fn new(socket: PathBuf, network: DmqNetwork, logger: Logger) -> Self {
34        Self {
35            socket,
36            network,
37            client: Mutex::new(None),
38            logger: logger.new_with_component_name::<Self>(),
39            phantom: PhantomData,
40        }
41    }
42
43    /// Creates and returns a new `DmqClient` connected to the specified socket.
44    async fn new_client(&self) -> StdResult<DmqClient> {
45        debug!(
46            self.logger,
47            "Create new DMQ client";
48            "socket" => ?self.socket,
49            "network" => ?self.network
50        );
51        DmqClient::connect(&self.socket, self.network.magic_id())
52            .await
53            .with_context(|| "DmqConsumerClientPallas failed to create a new client")
54    }
55
56    /// Gets the cached `DmqClient`, creating a new one if it does not exist.
57    async fn get_client(&self) -> StdResult<MutexGuard<'_, Option<DmqClient>>> {
58        {
59            // Run this in a separate block to avoid dead lock on the Mutex
60            let client_lock = self.client.lock().await;
61            if client_lock.as_ref().is_some() {
62                return Ok(client_lock);
63            }
64        }
65
66        let mut client_lock = self.client.lock().await;
67        *client_lock = Some(self.new_client().await?);
68
69        Ok(client_lock)
70    }
71
72    /// Drops the current `DmqClient`, if it exists.
73    async fn drop_client(&self) -> StdResult<()> {
74        debug!(
75            self.logger,
76            "Drop existing DMQ client";
77            "socket" => ?self.socket,
78            "network" => ?self.network
79        );
80        let mut client_lock = self.client.try_lock()?;
81        if let Some(client) = client_lock.take() {
82            client.abort().await;
83        }
84
85        Ok(())
86    }
87
88    #[cfg(test)]
89    /// Check if the client already exists (test only).
90    async fn has_client(&self) -> bool {
91        let client_lock = self.client.lock().await;
92
93        client_lock.as_ref().is_some()
94    }
95
96    async fn consume_messages_internal(&self) -> StdResult<Vec<(M, PartyId)>> {
97        debug!(self.logger, "Waiting for messages from DMQ...");
98        let mut client_guard = self.get_client().await?;
99        let client = client_guard.as_mut().with_context(|| "DMQ client does not exist")?;
100        if *client.msg_notification().state() == State::Idle {
101            client
102                .msg_notification()
103                .send_request_messages_blocking()
104                .await
105                .with_context(|| "Failed to request notifications from DMQ server: {}")?;
106        }
107
108        let reply = client
109            .msg_notification()
110            .recv_next_reply()
111            .await
112            .with_context(|| "Failed to receive notifications from DMQ server")?;
113        debug!(self.logger, "Received single signatures from DMQ"; "total_messages" => reply.0.len());
114        trace!(self.logger, "Received single signatures from DMQ"; "messages" => ?reply);
115
116        reply
117            .0
118            .into_iter()
119            .map(|dmq_message| {
120                debug!(
121                    self.logger,
122                    "Received DMQ message";
123                    "msg_id" => hex::encode(&dmq_message.msg_payload.msg_id),
124                    "kes_period" => dmq_message.msg_payload.kes_period,
125                    "expires_at" => dmq_message.msg_payload.expires_at,
126                );
127                let opcert_without_verification_key = OpCertWithoutColdVerificationKey::try_new(
128                    &dmq_message.operational_certificate.kes_vk,
129                    dmq_message.operational_certificate.issue_number,
130                    KesPeriod(dmq_message.operational_certificate.start_kes_period),
131                    &dmq_message.operational_certificate.cert_sig,
132                )
133                .with_context(|| "Failed to parse operational certificate")?;
134                let cold_verification_key =
135                    Ed25519VerificationKey::from_bytes(&dmq_message.cold_verification_key)
136                        .with_context(|| "Failed to parse cold verification key")?
137                        .into_inner();
138                let opcert: OpCert =
139                    (opcert_without_verification_key, cold_verification_key).into();
140                let party_id = opcert.compute_protocol_party_id()?;
141                let payload = M::try_from_bytes(&dmq_message.msg_payload.msg_body)
142                    .with_context(|| "Failed to parse DMQ message body")?;
143
144                Ok((payload, party_id))
145            })
146            .collect::<StdResult<Vec<_>>>()
147            .with_context(|| "Failed to parse DMQ messages")
148    }
149}
150
151#[async_trait::async_trait]
152impl<M: TryFromBytes + Debug + Sync + Send> DmqConsumerClient<M> for DmqConsumerClientPallas<M> {
153    async fn consume_messages(&self) -> StdResult<Vec<(M, PartyId)>> {
154        let messages = self.consume_messages_internal().await;
155        if messages.is_err() {
156            self.drop_client().await?;
157        }
158
159        messages
160    }
161}
162
163impl<M: TryFromBytes + Debug> Drop for DmqConsumerClientPallas<M> {
164    fn drop(&mut self) {
165        tokio::task::block_in_place(|| {
166            tokio::runtime::Handle::current().block_on(async {
167                if let Err(e) = self.drop_client().await {
168                    error!(self.logger, "Failed to drop DMQ consumer client: {}", e);
169                }
170            });
171        });
172    }
173}
174
#[cfg(all(test, unix))]
mod tests {
    use anyhow::anyhow;
    use pallas_network::{
        facades::DmqServer,
        miniprotocols::{
            localmsgnotification,
            localmsgsubmission::{DmqMsg, DmqMsgOperationalCertificate, DmqMsgPayload},
        },
    };
    use std::{fs, future, time::Duration, vec};
    use tokio::{net::UnixListener, task::JoinHandle, time::sleep};

    use mithril_common::{crypto_helper::TryToBytes, current_function, test::TempDir};

    use crate::test::{TestLogger, payload::DmqMessageTestPayload};

    use super::*;

    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("dmq_consumer", folder_name)
    }

    fn fake_msgs() -> Vec<DmqMsg> {
        vec![
            DmqMsg {
                msg_payload: DmqMsgPayload {
                    msg_id: vec![0, 1],
                    msg_body: DmqMessageTestPayload::new(b"msg_1").to_bytes_vec().unwrap(),
                    kes_period: 10,
                    expires_at: 100,
                },
                kes_signature: vec![0, 1, 2, 3],
                operational_certificate: DmqMsgOperationalCertificate {
                    kes_vk: vec![
                        50, 45, 160, 42, 80, 78, 184, 20, 210, 77, 140, 152, 63, 49, 165, 168, 5,
                        131, 101, 152, 110, 242, 144, 157, 176, 210, 5, 10, 166, 91, 196, 168,
                    ],
                    issue_number: 0,
                    start_kes_period: 0,
                    cert_sig: vec![
                        207, 135, 144, 168, 238, 41, 179, 216, 245, 74, 164, 231, 4, 158, 234, 141,
                        5, 19, 166, 11, 78, 34, 210, 211, 183, 72, 127, 83, 185, 156, 107, 55, 160,
                        190, 73, 251, 204, 47, 197, 86, 174, 231, 13, 49, 7, 83, 173, 177, 27, 53,
                        209, 66, 24, 203, 226, 152, 3, 91, 66, 56, 244, 206, 79, 0,
                    ],
                },
                cold_verification_key: vec![
                    32, 253, 186, 201, 177, 11, 117, 135, 187, 167, 181, 188, 22, 59, 206, 105,
                    231, 150, 215, 30, 78, 212, 76, 16, 252, 180, 72, 134, 137, 247, 161, 68,
                ],
            },
            DmqMsg {
                msg_payload: DmqMsgPayload {
                    msg_id: vec![1, 2],
                    msg_body: DmqMessageTestPayload::new(b"msg_2").to_bytes_vec().unwrap(),
                    kes_period: 11,
                    expires_at: 101,
                },
                kes_signature: vec![1, 2, 3, 4],
                operational_certificate: DmqMsgOperationalCertificate {
                    kes_vk: vec![
                        50, 45, 160, 42, 80, 78, 184, 20, 210, 77, 140, 152, 63, 49, 165, 168, 5,
                        131, 101, 152, 110, 242, 144, 157, 176, 210, 5, 10, 166, 91, 196, 168,
                    ],
                    issue_number: 0,
                    start_kes_period: 0,
                    cert_sig: vec![
                        207, 135, 144, 168, 238, 41, 179, 216, 245, 74, 164, 231, 4, 158, 234, 141,
                        5, 19, 166, 11, 78, 34, 210, 211, 183, 72, 127, 83, 185, 156, 107, 55, 160,
                        190, 73, 251, 204, 47, 197, 86, 174, 231, 13, 49, 7, 83, 173, 177, 27, 53,
                        209, 66, 24, 203, 226, 152, 3, 91, 66, 56, 244, 206, 79, 0,
                    ],
                },
                cold_verification_key: vec![
                    77, 75, 24, 6, 47, 133, 2, 89, 141, 224, 69, 202, 123, 105, 240, 103, 245, 159,
                    147, 177, 110, 58, 248, 115, 58, 152, 138, 220, 35, 65, 245, 200,
                ],
            },
        ]
    }

    /// Starts a fake DMQ server on `socket_path` that answers the first blocking request
    /// with `reply_messages`, or waits forever if there are none.
    ///
    /// The listener is bound synchronously, before the server task is spawned. Once this
    /// function returns, the socket exists and incoming connections queue in the listen
    /// backlog until `accept` runs. Clients can therefore connect right away instead of
    /// sleeping and hoping the server task was scheduled first.
    fn setup_dmq_server(
        socket_path: PathBuf,
        reply_messages: Vec<DmqMsg>,
    ) -> JoinHandle<DmqServer> {
        // server setup
        if socket_path.exists() {
            fs::remove_file(&socket_path).unwrap();
        }
        let listener = UnixListener::bind(&socket_path).unwrap();

        tokio::spawn(async move {
            let mut server = pallas_network::facades::DmqServer::accept(&listener, 0)
                .await
                .unwrap();

            // init local msg notification server
            let server_msg = server.msg_notification();

            // server waits for blocking request from client
            let request = server_msg.recv_next_request().await.unwrap();
            assert_eq!(request, localmsgnotification::Request::Blocking);

            if !reply_messages.is_empty() {
                // server replies with messages if any
                server_msg.send_reply_messages_blocking(reply_messages).await.unwrap();
                assert_eq!(*server_msg.state(), localmsgnotification::State::Idle);
            } else {
                // server waits if no message available
                future::pending().await
            }

            server
        })
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_client_succeeds_when_messages_are_available() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = fake_msgs();
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            let consumer = DmqConsumerClientPallas::new(
                socket_path,
                DmqNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            consumer.consume_messages().await.unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let messages = client_res.unwrap();

        assert_eq!(
            vec![
                (
                    DmqMessageTestPayload::new(b"msg_1"),
                    "pool1mxyec46067n3querj9cxkk0g0zlag93pf3ya9vuyr3wgkq2e6t7".to_string()
                ),
                (
                    DmqMessageTestPayload::new(b"msg_2"),
                    "pool17sln0evyk5tfj6zh2qrlk9vttgy6264sfe2fkec5mheasnlx3yd".to_string()
                ),
            ],
            messages
        );
    }

    #[tokio::test]
    async fn pallas_dmq_consumer_client_blocks_when_no_message_available() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = vec![];
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            let consumer = DmqConsumerClientPallas::<DmqMessageTestPayload>::new(
                socket_path,
                DmqNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            consumer.consume_messages().await.unwrap();
        });

        let result = tokio::select!(
            _res = sleep(Duration::from_millis(100)) => {Err(anyhow!("Timeout"))},
            _res =  client  => {Ok(())},
            _res =  server  => {Ok(())},
        );

        result.expect_err("Should have timed out");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_client_is_dropped_when_returning_error() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = fake_msgs();
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            let consumer = DmqConsumerClientPallas::<DmqMessageTestPayload>::new(
                socket_path,
                DmqNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            consumer.consume_messages().await.unwrap();

            consumer
        });

        let (server_res, client_res) = tokio::join!(server, client);
        let consumer = client_res.unwrap();
        let server = server_res.unwrap();
        server.abort().await;

        let client = tokio::spawn(async move {
            assert!(consumer.has_client().await, "Client should exist");

            consumer
                .consume_messages()
                .await
                .expect_err("Consuming messages should fail");

            assert!(
                !consumer.has_client().await,
                "Client should have been dropped after error"
            );

            consumer
        });
        client.await.unwrap();
    }
}