mithril_dmq/consumer/server/
pallas.rs

1use std::{fs, path::PathBuf};
2
3use anyhow::{Context, anyhow};
4use pallas_network::{facades::DmqServer, miniprotocols::localmsgnotification::Request};
5use tokio::{
6    join,
7    net::UnixListener,
8    select,
9    sync::{Mutex, MutexGuard, mpsc::UnboundedReceiver, watch::Receiver},
10};
11
12use slog::{Logger, debug, error, info, warn};
13
14use mithril_common::{CardanoNetwork, StdResult, logging::LoggerExtensions};
15
16use crate::{DmqConsumerServer, DmqMessage};
17
18use super::queue::MessageQueue;
19
/// A DMQ server implementation for messages notification from a DMQ node.
///
/// Messages obtained from the registered receiver are buffered in an internal queue and sent
/// to the connected client when it issues a blocking or non-blocking notification request.
pub struct DmqConsumerServerPallas {
    /// Path of the Unix socket the server binds to.
    socket: PathBuf,
    /// Cardano network, whose magic id is used when accepting a connection.
    network: CardanoNetwork,
    /// Lazily created server; reset to `None` when the server is dropped.
    server: Mutex<Option<DmqServer>>,
    /// Receiver of incoming DMQ messages; `None` until one is registered.
    messages_receiver: Mutex<Option<UnboundedReceiver<DmqMessage>>>,
    /// Buffer of received messages waiting to be sent to the client.
    messages_buffer: MessageQueue,
    /// Channel watched by the receive and serve loops to know when to stop.
    stop_rx: Receiver<()>,
    /// Logger tagged with this component's name.
    logger: Logger,
}
30
31impl DmqConsumerServerPallas {
32    /// Creates a new instance of [DmqConsumerServerPallas].
33    pub fn new(
34        socket: PathBuf,
35        network: CardanoNetwork,
36        stop_rx: Receiver<()>,
37        logger: Logger,
38    ) -> Self {
39        Self {
40            socket,
41            network,
42            server: Mutex::new(None),
43            messages_receiver: Mutex::new(None),
44            messages_buffer: MessageQueue::default(),
45            stop_rx,
46            logger: logger.new_with_component_name::<Self>(),
47        }
48    }
49
50    /// Creates and returns a new `DmqServer` connected to the specified socket.
51    async fn new_server(&self) -> StdResult<DmqServer> {
52        info!(
53            self.logger,
54            "Creating a new DMQ consumer server";
55            "socket" => ?self.socket,
56            "network" => ?self.network
57        );
58        let magic = self.network.magic_id();
59        if self.socket.exists() {
60            fs::remove_file(self.socket.clone())?;
61        }
62        let listener = UnixListener::bind(&self.socket)
63            .map_err(|err| anyhow!(err))
64            .with_context(|| {
65                format!(
66                    "DmqConsumerServerPallas failed to bind Unix socket at {}",
67                    self.socket.display()
68                )
69            })?;
70
71        DmqServer::accept(&listener, magic)
72            .await
73            .map_err(|err| anyhow!(err))
74            .with_context(|| "DmqConsumerServerPallas failed to create a new server")
75    }
76
77    /// Gets the cached `DmqServer`, creating a new one if it does not exist.
78    async fn get_server(&self) -> StdResult<MutexGuard<'_, Option<DmqServer>>> {
79        {
80            // Run this in a separate block to avoid dead lock on the Mutex
81            let server_lock = self.server.lock().await;
82            if server_lock.as_ref().is_some() {
83                return Ok(server_lock);
84            }
85        }
86
87        let mut server_lock = self.server.lock().await;
88        *server_lock = Some(self.new_server().await?);
89
90        Ok(server_lock)
91    }
92
93    /// Drops the current `DmqServer`, if it exists.
94    async fn drop_server(&self) -> StdResult<()> {
95        debug!(
96            self.logger,
97            "Drop existing DMQ server";
98            "socket" => ?self.socket,
99            "network" => ?self.network
100        );
101        let mut server_lock = self.server.try_lock()?;
102        if let Some(server) = server_lock.take() {
103            server.abort().await;
104        }
105
106        Ok(())
107    }
108
109    /// Registers the receiver for DMQ messages (only one receiver is allowed).
110    pub async fn register_receiver(
111        &self,
112        receiver: UnboundedReceiver<DmqMessage>,
113    ) -> StdResult<()> {
114        debug!(self.logger, "Register message receiver for DMQ messages");
115        let mut receiver_guard = self.messages_receiver.lock().await;
116        *receiver_guard = Some(receiver);
117
118        Ok(())
119    }
120
121    /// Receives incoming messages into the DMQ consumer server.
122    async fn receive_incoming_messages(&self) -> StdResult<()> {
123        info!(
124            self.logger,
125            "Receive incoming messages into DMQ consumer server...";
126            "socket" => ?self.socket,
127            "network" => ?self.network
128        );
129
130        let mut stop_rx = self.stop_rx.clone();
131        let mut receiver = self.messages_receiver.lock().await;
132        match *receiver {
133            Some(ref mut receiver) => loop {
134                select! {
135                    _ = stop_rx.changed() => {
136                        warn!(self.logger, "Stopping DMQ consumer server...");
137
138                        return Ok(());
139                    }
140                    message = receiver.recv() => {
141                        if let Some(message) = message {
142                            debug!(self.logger, "Received a message from the DMQ network"; "message" => ?message);
143                            self.messages_buffer.enqueue(message).await;
144                        } else {
145                            warn!(self.logger, "DMQ message receiver channel closed");
146                            return Ok(());
147                        }
148
149                    }
150                }
151            },
152            None => Err(anyhow!("DMQ message receiver is not registered")),
153        }
154    }
155
156    /// Serves incoming messages from the DMQ consumer server.
157    async fn serve_incoming_messages(&self) -> StdResult<()> {
158        info!(
159            self.logger,
160            "Serve incoming messages from DMQ consumer server...";
161            "socket" => ?self.socket,
162            "network" => ?self.network
163        );
164
165        let mut stop_rx = self.stop_rx.clone();
166        loop {
167            select! {
168                _ = stop_rx.changed() => {
169                    warn!(self.logger, "Stopping DMQ consumer server...");
170
171                    return Ok(());
172                }
173                res = self.process_message() => {
174                    match res {
175                        Ok(_) => {
176                            debug!(self.logger, "Processed a message successfully");
177                        }
178                        Err(err) => {
179                            error!(self.logger, "Failed to process message"; "error" => ?err);
180                            if let Err(drop_err) = self.drop_server().await {
181                                error!(self.logger, "Failed to drop DMQ consumer server"; "error" => ?drop_err);
182                            }
183                        }
184                    }
185                }
186            }
187        }
188    }
189}
190
191#[async_trait::async_trait]
192impl DmqConsumerServer for DmqConsumerServerPallas {
193    async fn process_message(&self) -> StdResult<()> {
194        debug!(
195            self.logger,
196            "Waiting for message received from the DMQ network"
197        );
198
199        let mut server_guard = self.get_server().await?;
200        let server = server_guard.as_mut().ok_or(anyhow!("DMQ server does not exist"))?;
201        let request = server
202            .msg_notification()
203            .recv_next_request()
204            .await
205            .map_err(|err| anyhow!("Failed to receive next DMQ message: {}", err))?;
206
207        match request {
208            Request::Blocking => {
209                debug!(
210                    self.logger,
211                    "Blocking notification of messages received from the DMQ network"
212                );
213                let reply_messages = self.messages_buffer.dequeue_blocking(None).await;
214                let reply_messages =
215                    reply_messages.into_iter().map(|msg| msg.into()).collect::<Vec<_>>();
216                server
217                    .msg_notification()
218                    .send_reply_messages_blocking(reply_messages.clone())
219                    .await?;
220                debug!(
221                    self.logger,
222                    "Messages replied to the DMQ notification client: {:?}", reply_messages
223                );
224            }
225            Request::NonBlocking => {
226                debug!(
227                    self.logger,
228                    "Non blocking notification of messages received from the DMQ network"
229                );
230                let reply_messages = self.messages_buffer.dequeue_non_blocking(None).await;
231                let reply_messages =
232                    reply_messages.into_iter().map(|msg| msg.into()).collect::<Vec<_>>();
233                let has_more = !self.messages_buffer.is_empty().await;
234                server
235                    .msg_notification()
236                    .send_reply_messages_non_blocking(reply_messages.clone(), has_more)
237                    .await?;
238                debug!(
239                    self.logger,
240                    "Messages replied to the DMQ notification client: {:?}", reply_messages
241                );
242            }
243        };
244
245        Ok(())
246    }
247
248    async fn run(&self) -> StdResult<()> {
249        info!(
250            self.logger,
251            "Starting DMQ consumer server";
252            "socket" => ?self.socket,
253            "network" => ?self.network
254        );
255
256        let (receive_result, serve_result) = join!(
257            self.receive_incoming_messages(),
258            self.serve_incoming_messages()
259        );
260        receive_result?;
261        serve_result?;
262
263        Ok(())
264    }
265}
266
267impl Drop for DmqConsumerServerPallas {
268    fn drop(&mut self) {
269        tokio::task::block_in_place(|| {
270            tokio::runtime::Handle::current().block_on(async {
271                if let Err(e) = self.drop_server().await {
272                    error!(self.logger, "Failed to drop DMQ consumer server: {}", e);
273                }
274            });
275        });
276    }
277}
278
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use pallas_network::{
        facades::DmqClient,
        miniprotocols::{localmsgnotification, localmsgsubmission::DmqMsg},
    };
    use tokio::sync::{mpsc::unbounded_channel, watch};
    use tokio::time::sleep;

    use mithril_common::{current_function, test::TempDir};

    use crate::test_tools::TestLogger;

    use super::*;

    /// Creates a temporary directory with a short path, suitable for hosting a Unix socket.
    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("dmq_consumer_server", folder_name)
    }

    /// Builds the DMQ message used as a fixture by the tests.
    fn fake_msg() -> DmqMsg {
        DmqMsg {
            msg_id: vec![0, 1],
            msg_body: vec![0, 1, 2],
            block_number: 10,
            ttl: 100,
            kes_signature: vec![0, 1, 2, 3],
            operational_certificate: vec![0, 1, 2, 3, 4],
            kes_period: 10,
        }
    }

    /// Builds a consumer server bound to `socket_path` on the test network with magic id `0`.
    fn build_consumer_server(
        socket_path: PathBuf,
        stop_rx: watch::Receiver<()>,
    ) -> Arc<DmqConsumerServerPallas> {
        Arc::new(DmqConsumerServerPallas::new(
            socket_path,
            CardanoNetwork::TestNet(0),
            stop_rx,
            TestLogger::stdout(),
        ))
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_server_non_blocking_success() {
        let (stop_tx, stop_rx) = watch::channel(());
        let (signature_dmq_tx, signature_dmq_rx) = unbounded_channel::<DmqMessage>();
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let dmq_consumer_server = build_consumer_server(socket_path.clone(), stop_rx);
        dmq_consumer_server.register_receiver(signature_dmq_rx).await.unwrap();
        let message = fake_msg();
        let client = tokio::spawn({
            async move {
                // sleep to avoid refused connection from the server
                sleep(Duration::from_millis(10)).await;

                // client setup
                let mut client = DmqClient::connect(socket_path, 0).await.unwrap();

                // init local msg notification client
                let client_msg = client.msg_notification();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);

                // client sends a non blocking request to server and waits for a reply from the server
                client_msg.send_request_messages_non_blocking().await.unwrap();
                assert_eq!(
                    *client_msg.state(),
                    localmsgnotification::State::BusyNonBlocking
                );

                let reply = client_msg.recv_next_reply().await.unwrap();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);
                let result = match reply {
                    localmsgnotification::Reply(messages, false) => Ok(messages),
                    _ => Err(anyhow::anyhow!(
                        "Failed to receive non blocking reply from DMQ server"
                    )),
                };

                // stop the consumer server
                stop_tx.send(()).unwrap();

                result
            }
        });
        // Only a clone is moved into the task: the original sender stays alive in this scope,
        // so the channel is not closed once the message has been sent.
        let recorder_tx = signature_dmq_tx.clone();
        let recorded_message = message.clone();
        let recorder = tokio::spawn(async move {
            recorder_tx.send(recorded_message.into()).unwrap();
        });

        let (run_res, messages_res, recorder_res) =
            tokio::join!(dmq_consumer_server.run(), client, recorder);
        run_res.unwrap();
        recorder_res.unwrap();
        let messages_received: Vec<_> = messages_res.unwrap().unwrap();
        assert_eq!(vec![message], messages_received);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_server_blocking_success() {
        let (stop_tx, stop_rx) = watch::channel(());
        let (signature_dmq_tx, signature_dmq_rx) = unbounded_channel::<DmqMessage>();
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let dmq_consumer_server = build_consumer_server(socket_path.clone(), stop_rx);
        dmq_consumer_server.register_receiver(signature_dmq_rx).await.unwrap();
        let message = fake_msg();
        let client = tokio::spawn({
            async move {
                // sleep to avoid refused connection from the server
                sleep(Duration::from_millis(10)).await;

                // client setup
                let mut client = DmqClient::connect(socket_path, 0).await.unwrap();

                // init local msg notification client
                let client_msg = client.msg_notification();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);

                // client sends a blocking request to server and waits for a reply from the server
                client_msg.send_request_messages_blocking().await.unwrap();
                assert_eq!(
                    *client_msg.state(),
                    localmsgnotification::State::BusyBlocking
                );

                let reply = client_msg.recv_next_reply().await.unwrap();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);
                let result = match reply {
                    localmsgnotification::Reply(messages, false) => Ok(messages),
                    _ => Err(anyhow::anyhow!(
                        "Failed to receive blocking reply from DMQ server"
                    )),
                };

                // stop the consumer server
                stop_tx.send(()).unwrap();

                result
            }
        });
        // Only a clone is moved into the task: the original sender stays alive in this scope,
        // so the channel is not closed once the message has been sent.
        let recorder_tx = signature_dmq_tx.clone();
        let recorded_message = message.clone();
        let recorder = tokio::spawn(async move {
            recorder_tx.send(recorded_message.into()).unwrap();
        });

        let (run_res, messages_res, recorder_res) =
            tokio::join!(dmq_consumer_server.run(), client, recorder);
        run_res.unwrap();
        recorder_res.unwrap();
        let messages_received: Vec<_> = messages_res.unwrap().unwrap();
        assert_eq!(vec![message], messages_received);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_server_blocking_blocks_when_no_message_available() {
        // The senders are kept alive (but unused) so that neither channel is closed.
        let (_stop_tx, stop_rx) = watch::channel(());
        let (_signature_dmq_tx, signature_dmq_rx) = unbounded_channel::<DmqMessage>();
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let dmq_consumer_server = build_consumer_server(socket_path.clone(), stop_rx);
        dmq_consumer_server.register_receiver(signature_dmq_rx).await.unwrap();
        let client = tokio::spawn({
            async move {
                // sleep to avoid refused connection from the server
                sleep(Duration::from_millis(10)).await;

                // client setup
                let mut client = DmqClient::connect(socket_path, 0).await.unwrap();

                // init local msg notification client
                let client_msg = client.msg_notification();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);

                // client sends a blocking request to server and waits for a reply from the server
                client_msg.send_request_messages_blocking().await.unwrap();
                assert_eq!(
                    *client_msg.state(),
                    localmsgnotification::State::BusyBlocking
                );

                let _ = client_msg.recv_next_reply().await;
            }
        });

        // Neither the server nor the client is expected to complete: the timeout must win.
        let result = tokio::select!(
            _res = sleep(Duration::from_millis(100)) => {Err(anyhow!("Timeout"))},
            _res =  dmq_consumer_server.run()  => {Ok(())},
            _res =  client  => {Ok(())},
        );

        result.expect_err("Should have timed out");
    }
}