// mithril_dmq/consumer/server/pallas.rs

1use std::{fs, path::PathBuf};
2
3use anyhow::{Context, anyhow};
4use pallas_network::{facades::DmqServer, miniprotocols::localmsgnotification::Request};
5use tokio::{
6    join,
7    net::UnixListener,
8    select,
9    sync::{Mutex, MutexGuard, mpsc::UnboundedReceiver, watch::Receiver},
10};
11
12use slog::{Logger, debug, error, info, warn};
13
14use mithril_common::{StdResult, logging::LoggerExtensions};
15
16use crate::{DmqConsumerServer, DmqMessage, DmqNetwork};
17
18use super::queue::MessageQueue;
19
/// A DMQ server implementation for messages notification from a DMQ node.
///
/// Messages received on the registered channel are buffered, then served to the client
/// connected on the Unix socket through the local message notification mini-protocol.
pub struct DmqConsumerServerPallas {
    /// Path of the Unix socket the server binds to.
    socket: PathBuf,
    /// DMQ network, whose magic identifier is used when accepting a client connection.
    network: DmqNetwork,
    /// Server of the currently connected client: created lazily, and dropped after a failure
    /// to process a message.
    server: Mutex<Option<DmqServer>>,
    /// Receiver of the messages to serve; it must be registered before the server runs.
    messages_receiver: Mutex<Option<UnboundedReceiver<DmqMessage>>>,
    /// Messages received but not yet sent to the client.
    messages_buffer: MessageQueue,
    /// Stop signal shared by the receiving and serving loops.
    stop_rx: Receiver<()>,
    /// Logger scoped to this component.
    logger: Logger,
}
30
31impl DmqConsumerServerPallas {
32    /// Creates a new instance of [DmqConsumerServerPallas].
33    pub fn new(
34        socket: PathBuf,
35        network: DmqNetwork,
36        stop_rx: Receiver<()>,
37        logger: Logger,
38    ) -> Self {
39        Self {
40            socket,
41            network,
42            server: Mutex::new(None),
43            messages_receiver: Mutex::new(None),
44            messages_buffer: MessageQueue::default(),
45            stop_rx,
46            logger: logger.new_with_component_name::<Self>(),
47        }
48    }
49
50    /// Creates and returns a new `DmqServer` connected to the specified socket.
51    async fn new_server(&self) -> StdResult<DmqServer> {
52        info!(
53            self.logger,
54            "Creating a new DMQ consumer server";
55            "socket" => ?self.socket,
56            "network" => ?self.network
57        );
58        let magic = self.network.magic_id();
59        if self.socket.exists() {
60            fs::remove_file(self.socket.clone())?;
61        }
62        let listener = UnixListener::bind(&self.socket).with_context(|| {
63            format!(
64                "DmqConsumerServerPallas failed to bind Unix socket at {}",
65                self.socket.display()
66            )
67        })?;
68
69        DmqServer::accept(&listener, magic)
70            .await
71            .with_context(|| "DmqConsumerServerPallas failed to create a new server")
72    }
73
74    /// Gets the cached `DmqServer`, creating a new one if it does not exist.
75    async fn get_server(&self) -> StdResult<MutexGuard<'_, Option<DmqServer>>> {
76        {
77            // Run this in a separate block to avoid dead lock on the Mutex
78            let server_lock = self.server.lock().await;
79            if server_lock.as_ref().is_some() {
80                return Ok(server_lock);
81            }
82        }
83
84        let mut server_lock = self.server.lock().await;
85        *server_lock = Some(self.new_server().await?);
86
87        Ok(server_lock)
88    }
89
90    /// Drops the current `DmqServer`, if it exists.
91    async fn drop_server(&self) -> StdResult<()> {
92        debug!(
93            self.logger,
94            "Drop existing DMQ server";
95            "socket" => ?self.socket,
96            "network" => ?self.network
97        );
98        let mut server_lock = self.server.try_lock()?;
99        if let Some(server) = server_lock.take() {
100            server.abort().await;
101        }
102
103        Ok(())
104    }
105
106    /// Registers the receiver for DMQ messages (only one receiver is allowed).
107    pub async fn register_receiver(
108        &self,
109        receiver: UnboundedReceiver<DmqMessage>,
110    ) -> StdResult<()> {
111        debug!(self.logger, "Register message receiver for DMQ messages");
112        let mut receiver_guard = self.messages_receiver.lock().await;
113        *receiver_guard = Some(receiver);
114
115        Ok(())
116    }
117
118    /// Receives incoming messages into the DMQ consumer server.
119    async fn receive_incoming_messages(&self) -> StdResult<()> {
120        info!(
121            self.logger,
122            "Receive incoming messages into DMQ consumer server...";
123            "socket" => ?self.socket,
124            "network" => ?self.network
125        );
126
127        let mut stop_rx = self.stop_rx.clone();
128        let mut receiver = self.messages_receiver.lock().await;
129        match *receiver {
130            Some(ref mut receiver) => loop {
131                select! {
132                    _ = stop_rx.changed() => {
133                        warn!(self.logger, "Stopping DMQ consumer server...");
134
135                        return Ok(());
136                    }
137                    message = receiver.recv() => {
138                        if let Some(message) = message {
139                            debug!(self.logger, "Received a message from the DMQ network"; "message" => ?message);
140                            self.messages_buffer.enqueue(message).await;
141                        } else {
142                            warn!(self.logger, "DMQ message receiver channel closed");
143                            return Ok(());
144                        }
145
146                    }
147                }
148            },
149            None => Err(anyhow!("DMQ message receiver is not registered")),
150        }
151    }
152
153    /// Serves incoming messages from the DMQ consumer server.
154    async fn serve_incoming_messages(&self) -> StdResult<()> {
155        info!(
156            self.logger,
157            "Serve incoming messages from DMQ consumer server...";
158            "socket" => ?self.socket,
159            "network" => ?self.network
160        );
161
162        let mut stop_rx = self.stop_rx.clone();
163        loop {
164            select! {
165                _ = stop_rx.changed() => {
166                    warn!(self.logger, "Stopping DMQ consumer server...");
167
168                    return Ok(());
169                }
170                res = self.process_message() => {
171                    match res {
172                        Ok(_) => {
173                            debug!(self.logger, "Processed a message successfully");
174                        }
175                        Err(err) => {
176                            error!(self.logger, "Failed to process message"; "error" => ?err);
177                            if let Err(drop_err) = self.drop_server().await {
178                                error!(self.logger, "Failed to drop DMQ consumer server"; "error" => ?drop_err);
179                            }
180                        }
181                    }
182                }
183            }
184        }
185    }
186}
187
188#[async_trait::async_trait]
189impl DmqConsumerServer for DmqConsumerServerPallas {
190    async fn process_message(&self) -> StdResult<()> {
191        debug!(
192            self.logger,
193            "Waiting for message received from the DMQ network"
194        );
195
196        let mut server_guard = self.get_server().await?;
197        let server = server_guard.as_mut().with_context(|| "DMQ server does not exist")?;
198        let request = server
199            .msg_notification()
200            .recv_next_request()
201            .await
202            .with_context(|| "Failed to receive next DMQ message")?;
203
204        match request {
205            Request::Blocking => {
206                debug!(
207                    self.logger,
208                    "Blocking notification of messages received from the DMQ network"
209                );
210                let reply_messages = self.messages_buffer.dequeue_blocking(None).await;
211                let reply_messages =
212                    reply_messages.into_iter().map(|msg| msg.into()).collect::<Vec<_>>();
213                server
214                    .msg_notification()
215                    .send_reply_messages_blocking(reply_messages.clone())
216                    .await?;
217                debug!(
218                    self.logger,
219                    "Messages replied to the DMQ notification client: {:?}", reply_messages
220                );
221            }
222            Request::NonBlocking => {
223                debug!(
224                    self.logger,
225                    "Non blocking notification of messages received from the DMQ network"
226                );
227                let reply_messages = self.messages_buffer.dequeue_non_blocking(None).await;
228                let reply_messages =
229                    reply_messages.into_iter().map(|msg| msg.into()).collect::<Vec<_>>();
230                let has_more = !self.messages_buffer.is_empty().await;
231                server
232                    .msg_notification()
233                    .send_reply_messages_non_blocking(reply_messages.clone(), has_more)
234                    .await?;
235                debug!(
236                    self.logger,
237                    "Messages replied to the DMQ notification client: {:?}", reply_messages
238                );
239            }
240        };
241
242        Ok(())
243    }
244
245    async fn run(&self) -> StdResult<()> {
246        info!(
247            self.logger,
248            "Starting DMQ consumer server";
249            "socket" => ?self.socket,
250            "network" => ?self.network
251        );
252
253        let (receive_result, serve_result) = join!(
254            self.receive_incoming_messages(),
255            self.serve_incoming_messages()
256        );
257        receive_result?;
258        serve_result?;
259
260        Ok(())
261    }
262}
263
264impl Drop for DmqConsumerServerPallas {
265    fn drop(&mut self) {
266        tokio::task::block_in_place(|| {
267            tokio::runtime::Handle::current().block_on(async {
268                if let Err(e) = self.drop_server().await {
269                    error!(self.logger, "Failed to drop DMQ consumer server: {}", e);
270                }
271            });
272        });
273    }
274}
275
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use pallas_network::{
        facades::DmqClient,
        miniprotocols::{localmsgnotification, localmsgsubmission::DmqMsg},
    };
    use tokio::sync::{mpsc::unbounded_channel, watch};
    use tokio::time::sleep;

    use mithril_common::{current_function, test::TempDir};

    use crate::{test::fake_message::compute_fake_msg, test_tools::TestLogger};

    use super::*;

    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("dmq_consumer_server", folder_name)
    }

    /// Creates a DMQ consumer server on the test network listening on `socket_path`, with
    /// `messages_receiver` registered as its message receiver.
    async fn create_dmq_consumer_server(
        socket_path: PathBuf,
        stop_rx: watch::Receiver<()>,
        messages_receiver: UnboundedReceiver<DmqMessage>,
    ) -> Arc<DmqConsumerServerPallas> {
        let dmq_consumer_server = Arc::new(DmqConsumerServerPallas::new(
            socket_path,
            DmqNetwork::TestNet(0),
            stop_rx,
            TestLogger::stdout(),
        ));
        dmq_consumer_server.register_receiver(messages_receiver).await.unwrap();

        dmq_consumer_server
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_server_non_blocking_success() {
        let current_function_name = current_function!();
        let (stop_tx, stop_rx) = watch::channel(());
        let (signature_dmq_tx, signature_dmq_rx) = unbounded_channel::<DmqMessage>();
        let socket_path = create_temp_dir(current_function_name).join("node.socket");
        let dmq_consumer_server =
            create_dmq_consumer_server(socket_path.clone(), stop_rx, signature_dmq_rx).await;
        let message: DmqMsg = compute_fake_msg(b"test", current_function_name).await.into();
        let client = tokio::spawn({
            async move {
                // sleep to avoid refused connection from the server
                tokio::time::sleep(Duration::from_millis(10)).await;

                // client setup
                let mut client = DmqClient::connect(socket_path.clone(), 0).await.unwrap();

                // init local msg notification client
                let client_msg = client.msg_notification();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);

                // client sends a non blocking request to server and waits for a reply from the server
                client_msg.send_request_messages_non_blocking().await.unwrap();
                assert_eq!(
                    *client_msg.state(),
                    localmsgnotification::State::BusyNonBlocking
                );

                let reply = client_msg.recv_next_reply().await.unwrap();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);
                let result = match reply {
                    localmsgnotification::Reply(messages, false) => Ok(messages),
                    _ => Err(anyhow::anyhow!(
                        "Failed to receive non blocking reply from DMQ server"
                    )),
                };

                // stop the consumer server
                stop_tx.send(()).unwrap();

                result
            }
        });
        let message_clone = message.clone();
        // Send through a clone: `signature_dmq_tx` stays alive until the end of the test, so
        // the channel is not closed once the message has been sent.
        let signature_dmq_tx_clone = signature_dmq_tx.clone();
        let recorder = tokio::spawn(async move {
            signature_dmq_tx_clone.send(message_clone.into()).unwrap();
        });

        let (_, messages_res, _) = tokio::join!(dmq_consumer_server.run(), client, recorder);
        let messages_received: Vec<_> = messages_res.unwrap().unwrap();
        assert_eq!(vec![message], messages_received);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_server_blocking_success() {
        let current_function_name = current_function!();
        let (stop_tx, stop_rx) = watch::channel(());
        let (signature_dmq_tx, signature_dmq_rx) = unbounded_channel::<DmqMessage>();
        let socket_path = create_temp_dir(current_function_name).join("node.socket");
        let dmq_consumer_server =
            create_dmq_consumer_server(socket_path.clone(), stop_rx, signature_dmq_rx).await;
        let message: DmqMsg = compute_fake_msg(b"test", current_function_name).await.into();
        let client = tokio::spawn({
            async move {
                // sleep to avoid refused connection from the server
                tokio::time::sleep(Duration::from_millis(10)).await;

                // client setup
                let mut client = DmqClient::connect(socket_path.clone(), 0).await.unwrap();

                // init local msg notification client
                let client_msg = client.msg_notification();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);

                // client sends a blocking request to server and waits for a reply from the server
                client_msg.send_request_messages_blocking().await.unwrap();
                assert_eq!(
                    *client_msg.state(),
                    localmsgnotification::State::BusyBlocking
                );

                let reply = client_msg.recv_next_reply().await.unwrap();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);
                let result = match reply {
                    localmsgnotification::Reply(messages, false) => Ok(messages),
                    _ => Err(anyhow::anyhow!(
                        "Failed to receive blocking reply from DMQ server"
                    )),
                };

                // stop the consumer server
                stop_tx.send(()).unwrap();

                result
            }
        });
        let message_clone = message.clone();
        // Send through a clone: `signature_dmq_tx` stays alive until the end of the test, so
        // the channel is not closed once the message has been sent.
        let signature_dmq_tx_clone = signature_dmq_tx.clone();
        let recorder = tokio::spawn(async move {
            signature_dmq_tx_clone.send(message_clone.into()).unwrap();
        });

        let (_, messages_res, _) = tokio::join!(dmq_consumer_server.run(), client, recorder);
        let messages_received: Vec<_> = messages_res.unwrap().unwrap();
        assert_eq!(vec![message], messages_received);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_server_blocking_blocks_when_no_message_available() {
        // Keep both senders alive so that neither the stop signal nor a closed channel ends
        // the server before the timeout.
        let (_stop_tx, stop_rx) = watch::channel(());
        let (_signature_dmq_tx, signature_dmq_rx) = unbounded_channel::<DmqMessage>();
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let dmq_consumer_server =
            create_dmq_consumer_server(socket_path.clone(), stop_rx, signature_dmq_rx).await;
        let client = tokio::spawn({
            async move {
                // sleep to avoid refused connection from the server
                tokio::time::sleep(Duration::from_millis(10)).await;

                // client setup
                let mut client = DmqClient::connect(socket_path.clone(), 0).await.unwrap();

                // init local msg notification client
                let client_msg = client.msg_notification();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);

                // client sends a blocking request to server and waits for a reply from the server
                client_msg.send_request_messages_blocking().await.unwrap();
                assert_eq!(
                    *client_msg.state(),
                    localmsgnotification::State::BusyBlocking
                );

                let _ = client_msg.recv_next_reply().await;
            }
        });

        let result = tokio::select!(
            _res = sleep(Duration::from_millis(100)) => {Err(anyhow!("Timeout"))},
            _res =  dmq_consumer_server.run()  => {Ok(())},
            _res =  client  => {Ok(())},
        );

        result.expect_err("Should have timed out");
    }
}