mithril_dmq/consumer/client/
pallas.rs

1use std::{fmt::Debug, marker::PhantomData, path::PathBuf};
2
3use anyhow::{Context, anyhow};
4use pallas_network::{facades::DmqClient, miniprotocols::localmsgnotification::State};
5use slog::{Logger, debug, error};
6use tokio::sync::{Mutex, MutexGuard};
7
8use mithril_common::{
9    StdResult,
10    crypto_helper::{
11        OpCert, OpCertWithoutColdVerificationKey, TryFromBytes, ed25519::Ed25519VerificationKey,
12    },
13    entities::PartyId,
14    logging::LoggerExtensions,
15};
16
17use crate::{DmqConsumerClient, model::DmqNetwork};
18
19/// A DMQ client consumer implementation.
20///
21/// This implementation is built upon the n2c mini-protocols DMQ implementation in Pallas.
22pub struct DmqConsumerClientPallas<M: TryFromBytes + Debug> {
23    socket: PathBuf,
24    network: DmqNetwork,
25    client: Mutex<Option<DmqClient>>,
26    logger: Logger,
27    phantom: PhantomData<M>,
28}
29
30impl<M: TryFromBytes + Debug> DmqConsumerClientPallas<M> {
31    /// Creates a new `DmqConsumerClientPallas` instance.
32    pub fn new(socket: PathBuf, network: DmqNetwork, logger: Logger) -> Self {
33        Self {
34            socket,
35            network,
36            client: Mutex::new(None),
37            logger: logger.new_with_component_name::<Self>(),
38            phantom: PhantomData,
39        }
40    }
41
42    /// Creates and returns a new `DmqClient` connected to the specified socket.
43    async fn new_client(&self) -> StdResult<DmqClient> {
44        debug!(
45            self.logger,
46            "Create new DMQ client";
47            "socket" => ?self.socket,
48            "network" => ?self.network
49        );
50        DmqClient::connect(&self.socket, self.network.magic_id())
51            .await
52            .with_context(|| "DmqConsumerClientPallas failed to create a new client")
53    }
54
55    /// Gets the cached `DmqClient`, creating a new one if it does not exist.
56    async fn get_client(&self) -> StdResult<MutexGuard<'_, Option<DmqClient>>> {
57        {
58            // Run this in a separate block to avoid dead lock on the Mutex
59            let client_lock = self.client.lock().await;
60            if client_lock.as_ref().is_some() {
61                return Ok(client_lock);
62            }
63        }
64
65        let mut client_lock = self.client.lock().await;
66        *client_lock = Some(self.new_client().await?);
67
68        Ok(client_lock)
69    }
70
71    /// Drops the current `DmqClient`, if it exists.
72    async fn drop_client(&self) -> StdResult<()> {
73        debug!(
74            self.logger,
75            "Drop existing DMQ client";
76            "socket" => ?self.socket,
77            "network" => ?self.network
78        );
79        let mut client_lock = self.client.try_lock()?;
80        if let Some(client) = client_lock.take() {
81            client.abort().await;
82        }
83
84        Ok(())
85    }
86
87    #[cfg(test)]
88    /// Check if the client already exists (test only).
89    async fn has_client(&self) -> bool {
90        let client_lock = self.client.lock().await;
91
92        client_lock.as_ref().is_some()
93    }
94
95    async fn consume_messages_internal(&self) -> StdResult<Vec<(M, PartyId)>> {
96        debug!(self.logger, "Waiting for messages from DMQ...");
97        let mut client_guard = self.get_client().await?;
98        let client = client_guard.as_mut().ok_or(anyhow!("DMQ client does not exist"))?;
99        if *client.msg_notification().state() == State::Idle {
100            client
101                .msg_notification()
102                .send_request_messages_blocking()
103                .await
104                .with_context(|| "Failed to request notifications from DMQ server: {}")?;
105        }
106
107        let reply = client
108            .msg_notification()
109            .recv_next_reply()
110            .await
111            .with_context(|| "Failed to receive notifications from DMQ server")?;
112        debug!(self.logger, "Received single signatures from DMQ"; "messages" => ?reply);
113
114        reply
115            .0
116            .into_iter()
117            .map(|dmq_message| {
118                let opcert_without_verification_key = OpCertWithoutColdVerificationKey::try_new(
119                    &dmq_message.operational_certificate.kes_vk,
120                    dmq_message.operational_certificate.issue_number,
121                    dmq_message.operational_certificate.start_kes_period,
122                    &dmq_message.operational_certificate.cert_sig,
123                )
124                .with_context(|| "Failed to parse operational certificate")?;
125                let cold_verification_key =
126                    Ed25519VerificationKey::from_bytes(&dmq_message.cold_verification_key)
127                        .with_context(|| "Failed to parse cold verification key")?
128                        .into_inner();
129                let opcert: OpCert =
130                    (opcert_without_verification_key, cold_verification_key).into();
131                let party_id = opcert.compute_protocol_party_id()?;
132                let payload = M::try_from_bytes(&dmq_message.msg_payload.msg_body)
133                    .with_context(|| "Failed to parse DMQ message body")?;
134
135                Ok((payload, party_id))
136            })
137            .collect::<StdResult<Vec<_>>>()
138            .with_context(|| "Failed to parse DMQ messages")
139    }
140}
141
142#[async_trait::async_trait]
143impl<M: TryFromBytes + Debug + Sync + Send> DmqConsumerClient<M> for DmqConsumerClientPallas<M> {
144    async fn consume_messages(&self) -> StdResult<Vec<(M, PartyId)>> {
145        let messages = self.consume_messages_internal().await;
146        if messages.is_err() {
147            self.drop_client().await?;
148        }
149
150        messages
151    }
152}
153
154impl<M: TryFromBytes + Debug> Drop for DmqConsumerClientPallas<M> {
155    fn drop(&mut self) {
156        tokio::task::block_in_place(|| {
157            tokio::runtime::Handle::current().block_on(async {
158                if let Err(e) = self.drop_client().await {
159                    error!(self.logger, "Failed to drop DMQ consumer client: {}", e);
160                }
161            });
162        });
163    }
164}
165
166#[cfg(all(test, unix))]
167mod tests {
168
169    use std::{fs, future, time::Duration, vec};
170
171    use mithril_common::{crypto_helper::TryToBytes, current_function, test::TempDir};
172    use pallas_network::{
173        facades::DmqServer,
174        miniprotocols::{
175            localmsgnotification,
176            localmsgsubmission::{DmqMsg, DmqMsgOperationalCertificate, DmqMsgPayload},
177        },
178    };
179    use tokio::{net::UnixListener, task::JoinHandle, time::sleep};
180
181    use crate::test::{TestLogger, payload::DmqMessageTestPayload};
182
183    use super::*;
184
185    fn create_temp_dir(folder_name: &str) -> PathBuf {
186        TempDir::create_with_short_path("dmq_consumer", folder_name)
187    }
188
189    fn fake_msgs() -> Vec<DmqMsg> {
190        vec![
191            DmqMsg {
192                msg_payload: DmqMsgPayload {
193                    msg_id: vec![0, 1],
194                    msg_body: DmqMessageTestPayload::new(b"msg_1").to_bytes_vec().unwrap(),
195                    kes_period: 10,
196                    expires_at: 100,
197                },
198                kes_signature: vec![0, 1, 2, 3],
199                operational_certificate: DmqMsgOperationalCertificate {
200                    kes_vk: vec![
201                        50, 45, 160, 42, 80, 78, 184, 20, 210, 77, 140, 152, 63, 49, 165, 168, 5,
202                        131, 101, 152, 110, 242, 144, 157, 176, 210, 5, 10, 166, 91, 196, 168,
203                    ],
204                    issue_number: 0,
205                    start_kes_period: 0,
206                    cert_sig: vec![
207                        207, 135, 144, 168, 238, 41, 179, 216, 245, 74, 164, 231, 4, 158, 234, 141,
208                        5, 19, 166, 11, 78, 34, 210, 211, 183, 72, 127, 83, 185, 156, 107, 55, 160,
209                        190, 73, 251, 204, 47, 197, 86, 174, 231, 13, 49, 7, 83, 173, 177, 27, 53,
210                        209, 66, 24, 203, 226, 152, 3, 91, 66, 56, 244, 206, 79, 0,
211                    ],
212                },
213                cold_verification_key: vec![
214                    32, 253, 186, 201, 177, 11, 117, 135, 187, 167, 181, 188, 22, 59, 206, 105,
215                    231, 150, 215, 30, 78, 212, 76, 16, 252, 180, 72, 134, 137, 247, 161, 68,
216                ],
217            },
218            DmqMsg {
219                msg_payload: DmqMsgPayload {
220                    msg_id: vec![1, 2],
221                    msg_body: DmqMessageTestPayload::new(b"msg_2").to_bytes_vec().unwrap(),
222                    kes_period: 11,
223                    expires_at: 101,
224                },
225                kes_signature: vec![1, 2, 3, 4],
226                operational_certificate: DmqMsgOperationalCertificate {
227                    kes_vk: vec![
228                        50, 45, 160, 42, 80, 78, 184, 20, 210, 77, 140, 152, 63, 49, 165, 168, 5,
229                        131, 101, 152, 110, 242, 144, 157, 176, 210, 5, 10, 166, 91, 196, 168,
230                    ],
231                    issue_number: 0,
232                    start_kes_period: 0,
233                    cert_sig: vec![
234                        207, 135, 144, 168, 238, 41, 179, 216, 245, 74, 164, 231, 4, 158, 234, 141,
235                        5, 19, 166, 11, 78, 34, 210, 211, 183, 72, 127, 83, 185, 156, 107, 55, 160,
236                        190, 73, 251, 204, 47, 197, 86, 174, 231, 13, 49, 7, 83, 173, 177, 27, 53,
237                        209, 66, 24, 203, 226, 152, 3, 91, 66, 56, 244, 206, 79, 0,
238                    ],
239                },
240                cold_verification_key: vec![
241                    77, 75, 24, 6, 47, 133, 2, 89, 141, 224, 69, 202, 123, 105, 240, 103, 245, 159,
242                    147, 177, 110, 58, 248, 115, 58, 152, 138, 220, 35, 65, 245, 200,
243                ],
244            },
245        ]
246    }
247
248    fn setup_dmq_server(
249        socket_path: PathBuf,
250        reply_messages: Vec<DmqMsg>,
251    ) -> JoinHandle<DmqServer> {
252        tokio::spawn({
253            async move {
254                // server setup
255                if socket_path.exists() {
256                    fs::remove_file(socket_path.clone()).unwrap();
257                }
258                let listener = UnixListener::bind(socket_path).unwrap();
259                let mut server = pallas_network::facades::DmqServer::accept(&listener, 0)
260                    .await
261                    .unwrap();
262
263                // init local msg notification server
264                let server_msg = server.msg_notification();
265
266                // server waits for blocking request from client
267                let request = server_msg.recv_next_request().await.unwrap();
268                assert_eq!(request, localmsgnotification::Request::Blocking);
269
270                if !reply_messages.is_empty() {
271                    // server replies with messages if any
272                    server_msg.send_reply_messages_blocking(reply_messages).await.unwrap();
273                    assert_eq!(*server_msg.state(), localmsgnotification::State::Idle);
274                } else {
275                    // server waits if no message available
276                    future::pending().await
277                }
278
279                server
280            }
281        })
282    }
283
284    #[tokio::test(flavor = "multi_thread")]
285    async fn pallas_dmq_consumer_client_succeeds_when_messages_are_available() {
286        let socket_path = create_temp_dir(current_function!()).join("node.socket");
287        let reply_messages = fake_msgs();
288        let server = setup_dmq_server(socket_path.clone(), reply_messages);
289        let client = tokio::spawn(async move {
290            // sleep to avoid refused connection from the server
291            tokio::time::sleep(Duration::from_millis(10)).await;
292
293            let consumer = DmqConsumerClientPallas::new(
294                socket_path,
295                DmqNetwork::TestNet(0),
296                TestLogger::stdout(),
297            );
298
299            consumer.consume_messages().await.unwrap()
300        });
301
302        let (_, client_res) = tokio::join!(server, client);
303        let messages = client_res.unwrap();
304
305        assert_eq!(
306            vec![
307                (
308                    DmqMessageTestPayload::new(b"msg_1"),
309                    "pool1mxyec46067n3querj9cxkk0g0zlag93pf3ya9vuyr3wgkq2e6t7".to_string()
310                ),
311                (
312                    DmqMessageTestPayload::new(b"msg_2"),
313                    "pool17sln0evyk5tfj6zh2qrlk9vttgy6264sfe2fkec5mheasnlx3yd".to_string()
314                ),
315            ],
316            messages
317        );
318    }
319
320    #[tokio::test]
321    async fn pallas_dmq_consumer_client_blocks_when_no_message_available() {
322        let socket_path = create_temp_dir(current_function!()).join("node.socket");
323        let reply_messages = vec![];
324        let server = setup_dmq_server(socket_path.clone(), reply_messages);
325        let client = tokio::spawn(async move {
326            // sleep to avoid refused connection from the server
327            tokio::time::sleep(Duration::from_millis(10)).await;
328
329            let consumer = DmqConsumerClientPallas::<DmqMessageTestPayload>::new(
330                socket_path,
331                DmqNetwork::TestNet(0),
332                TestLogger::stdout(),
333            );
334
335            consumer.consume_messages().await.unwrap();
336        });
337
338        let result = tokio::select!(
339            _res = sleep(Duration::from_millis(100)) => {Err(anyhow!("Timeout"))},
340            _res =  client  => {Ok(())},
341            _res =  server  => {Ok(())},
342        );
343
344        result.expect_err("Should have timed out");
345    }
346
347    #[tokio::test(flavor = "multi_thread")]
348    async fn pallas_dmq_consumer_client_is_dropped_when_returning_error() {
349        let socket_path = create_temp_dir(current_function!()).join("node.socket");
350        let reply_messages = fake_msgs();
351        let server = setup_dmq_server(socket_path.clone(), reply_messages);
352        let client = tokio::spawn(async move {
353            // sleep to avoid refused connection from the server
354            tokio::time::sleep(Duration::from_millis(10)).await;
355
356            let consumer = DmqConsumerClientPallas::<DmqMessageTestPayload>::new(
357                socket_path,
358                DmqNetwork::TestNet(0),
359                TestLogger::stdout(),
360            );
361
362            consumer.consume_messages().await.unwrap();
363
364            consumer
365        });
366
367        let (server_res, client_res) = tokio::join!(server, client);
368        let consumer = client_res.unwrap();
369        let server = server_res.unwrap();
370        server.abort().await;
371
372        let client = tokio::spawn(async move {
373            assert!(consumer.has_client().await, "Client should exist");
374
375            consumer
376                .consume_messages()
377                .await
378                .expect_err("Consuming messages should fail");
379
380            assert!(
381                !consumer.has_client().await,
382                "Client should have been dropped after error"
383            );
384
385            consumer
386        });
387        client.await.unwrap();
388    }
389}