mithril_dmq/consumer/server/
pallas.rs

1use std::{fs, path::PathBuf};
2
3use anyhow::{Context, anyhow};
4use pallas_network::{facades::DmqServer, miniprotocols::localmsgnotification::Request};
5use tokio::{
6    join,
7    net::UnixListener,
8    select,
9    sync::{Mutex, MutexGuard, mpsc::UnboundedReceiver, watch::Receiver},
10};
11
12use slog::{Logger, debug, error, info, warn};
13
14use mithril_common::{StdResult, logging::LoggerExtensions};
15
16use crate::{DmqConsumerServer, DmqMessage, DmqNetwork};
17
18use super::queue::MessageQueue;
19
/// A DMQ consumer server, built on `pallas_network`'s [DmqServer], that serves the local
/// message notification mini-protocol to a client connected over a Unix socket.
///
/// Messages pushed through the registered receiver (see
/// [DmqConsumerServerPallas::register_receiver]) are buffered until a client requests them.
pub struct DmqConsumerServerPallas {
    /// Path of the Unix socket the server binds (any file already there is removed first).
    socket: PathBuf,
    /// DMQ network whose magic id is passed to [DmqServer::accept].
    network: DmqNetwork,
    /// Lazily created server, cached across requests and reset after a processing error.
    server: Mutex<Option<DmqServer>>,
    /// Receiver for incoming DMQ messages, set by `register_receiver`.
    messages_receiver: Mutex<Option<UnboundedReceiver<DmqMessage>>>,
    /// Messages waiting to be delivered to the notification client.
    messages_buffer: MessageQueue,
    /// Stop signal: a change notification (or a closed channel) stops the run loops.
    stop_rx: Receiver<()>,
    /// Logger scoped to this component.
    logger: Logger,
}
30
31impl DmqConsumerServerPallas {
32    /// Creates a new instance of [DmqConsumerServerPallas].
33    pub fn new(
34        socket: PathBuf,
35        network: DmqNetwork,
36        stop_rx: Receiver<()>,
37        logger: Logger,
38    ) -> Self {
39        Self {
40            socket,
41            network,
42            server: Mutex::new(None),
43            messages_receiver: Mutex::new(None),
44            messages_buffer: MessageQueue::default(),
45            stop_rx,
46            logger: logger.new_with_component_name::<Self>(),
47        }
48    }
49
50    /// Creates and returns a new `DmqServer` connected to the specified socket.
51    async fn new_server(&self) -> StdResult<DmqServer> {
52        info!(
53            self.logger,
54            "Creating a new DMQ consumer server";
55            "socket" => ?self.socket,
56            "network" => ?self.network
57        );
58        let magic = self.network.magic_id();
59        if self.socket.exists() {
60            fs::remove_file(self.socket.clone())?;
61        }
62        let listener = UnixListener::bind(&self.socket)
63            .map_err(|err| anyhow!(err))
64            .with_context(|| {
65                format!(
66                    "DmqConsumerServerPallas failed to bind Unix socket at {}",
67                    self.socket.display()
68                )
69            })?;
70
71        DmqServer::accept(&listener, magic)
72            .await
73            .map_err(|err| anyhow!(err))
74            .with_context(|| "DmqConsumerServerPallas failed to create a new server")
75    }
76
77    /// Gets the cached `DmqServer`, creating a new one if it does not exist.
78    async fn get_server(&self) -> StdResult<MutexGuard<'_, Option<DmqServer>>> {
79        {
80            // Run this in a separate block to avoid dead lock on the Mutex
81            let server_lock = self.server.lock().await;
82            if server_lock.as_ref().is_some() {
83                return Ok(server_lock);
84            }
85        }
86
87        let mut server_lock = self.server.lock().await;
88        *server_lock = Some(self.new_server().await?);
89
90        Ok(server_lock)
91    }
92
93    /// Drops the current `DmqServer`, if it exists.
94    async fn drop_server(&self) -> StdResult<()> {
95        debug!(
96            self.logger,
97            "Drop existing DMQ server";
98            "socket" => ?self.socket,
99            "network" => ?self.network
100        );
101        let mut server_lock = self.server.try_lock()?;
102        if let Some(server) = server_lock.take() {
103            server.abort().await;
104        }
105
106        Ok(())
107    }
108
109    /// Registers the receiver for DMQ messages (only one receiver is allowed).
110    pub async fn register_receiver(
111        &self,
112        receiver: UnboundedReceiver<DmqMessage>,
113    ) -> StdResult<()> {
114        debug!(self.logger, "Register message receiver for DMQ messages");
115        let mut receiver_guard = self.messages_receiver.lock().await;
116        *receiver_guard = Some(receiver);
117
118        Ok(())
119    }
120
121    /// Receives incoming messages into the DMQ consumer server.
122    async fn receive_incoming_messages(&self) -> StdResult<()> {
123        info!(
124            self.logger,
125            "Receive incoming messages into DMQ consumer server...";
126            "socket" => ?self.socket,
127            "network" => ?self.network
128        );
129
130        let mut stop_rx = self.stop_rx.clone();
131        let mut receiver = self.messages_receiver.lock().await;
132        match *receiver {
133            Some(ref mut receiver) => loop {
134                select! {
135                    _ = stop_rx.changed() => {
136                        warn!(self.logger, "Stopping DMQ consumer server...");
137
138                        return Ok(());
139                    }
140                    message = receiver.recv() => {
141                        if let Some(message) = message {
142                            debug!(self.logger, "Received a message from the DMQ network"; "message" => ?message);
143                            self.messages_buffer.enqueue(message).await;
144                        } else {
145                            warn!(self.logger, "DMQ message receiver channel closed");
146                            return Ok(());
147                        }
148
149                    }
150                }
151            },
152            None => Err(anyhow!("DMQ message receiver is not registered")),
153        }
154    }
155
156    /// Serves incoming messages from the DMQ consumer server.
157    async fn serve_incoming_messages(&self) -> StdResult<()> {
158        info!(
159            self.logger,
160            "Serve incoming messages from DMQ consumer server...";
161            "socket" => ?self.socket,
162            "network" => ?self.network
163        );
164
165        let mut stop_rx = self.stop_rx.clone();
166        loop {
167            select! {
168                _ = stop_rx.changed() => {
169                    warn!(self.logger, "Stopping DMQ consumer server...");
170
171                    return Ok(());
172                }
173                res = self.process_message() => {
174                    match res {
175                        Ok(_) => {
176                            debug!(self.logger, "Processed a message successfully");
177                        }
178                        Err(err) => {
179                            error!(self.logger, "Failed to process message"; "error" => ?err);
180                            if let Err(drop_err) = self.drop_server().await {
181                                error!(self.logger, "Failed to drop DMQ consumer server"; "error" => ?drop_err);
182                            }
183                        }
184                    }
185                }
186            }
187        }
188    }
189}
190
191#[async_trait::async_trait]
192impl DmqConsumerServer for DmqConsumerServerPallas {
193    async fn process_message(&self) -> StdResult<()> {
194        debug!(
195            self.logger,
196            "Waiting for message received from the DMQ network"
197        );
198
199        let mut server_guard = self.get_server().await?;
200        let server = server_guard.as_mut().ok_or(anyhow!("DMQ server does not exist"))?;
201        let request = server
202            .msg_notification()
203            .recv_next_request()
204            .await
205            .map_err(|err| anyhow!("Failed to receive next DMQ message: {}", err))?;
206
207        match request {
208            Request::Blocking => {
209                debug!(
210                    self.logger,
211                    "Blocking notification of messages received from the DMQ network"
212                );
213                let reply_messages = self.messages_buffer.dequeue_blocking(None).await;
214                let reply_messages =
215                    reply_messages.into_iter().map(|msg| msg.into()).collect::<Vec<_>>();
216                server
217                    .msg_notification()
218                    .send_reply_messages_blocking(reply_messages.clone())
219                    .await?;
220                debug!(
221                    self.logger,
222                    "Messages replied to the DMQ notification client: {:?}", reply_messages
223                );
224            }
225            Request::NonBlocking => {
226                debug!(
227                    self.logger,
228                    "Non blocking notification of messages received from the DMQ network"
229                );
230                let reply_messages = self.messages_buffer.dequeue_non_blocking(None).await;
231                let reply_messages =
232                    reply_messages.into_iter().map(|msg| msg.into()).collect::<Vec<_>>();
233                let has_more = !self.messages_buffer.is_empty().await;
234                server
235                    .msg_notification()
236                    .send_reply_messages_non_blocking(reply_messages.clone(), has_more)
237                    .await?;
238                debug!(
239                    self.logger,
240                    "Messages replied to the DMQ notification client: {:?}", reply_messages
241                );
242            }
243        };
244
245        Ok(())
246    }
247
248    async fn run(&self) -> StdResult<()> {
249        info!(
250            self.logger,
251            "Starting DMQ consumer server";
252            "socket" => ?self.socket,
253            "network" => ?self.network
254        );
255
256        let (receive_result, serve_result) = join!(
257            self.receive_incoming_messages(),
258            self.serve_incoming_messages()
259        );
260        receive_result?;
261        serve_result?;
262
263        Ok(())
264    }
265}
266
267impl Drop for DmqConsumerServerPallas {
268    fn drop(&mut self) {
269        tokio::task::block_in_place(|| {
270            tokio::runtime::Handle::current().block_on(async {
271                if let Err(e) = self.drop_server().await {
272                    error!(self.logger, "Failed to drop DMQ consumer server: {}", e);
273                }
274            });
275        });
276    }
277}
278
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use pallas_network::{
        facades::DmqClient,
        miniprotocols::{localmsgnotification, localmsgsubmission::DmqMsg},
    };
    use tokio::sync::{mpsc::unbounded_channel, watch};
    use tokio::time::sleep;

    use mithril_common::{current_function, test::TempDir};

    use crate::{test::fake_message::compute_fake_msg, test_tools::TestLogger};

    use super::*;

    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("dmq_consumer_server", folder_name)
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_server_non_blocking_success() {
        let current_function_name = current_function!();
        let (stop_tx, stop_rx) = watch::channel(());
        let (signature_dmq_tx, signature_dmq_rx) = unbounded_channel::<DmqMessage>();
        let socket_path = create_temp_dir(current_function_name).join("node.socket");
        let dmq_network = DmqNetwork::TestNet(0);
        let dmq_consumer_server = Arc::new(DmqConsumerServerPallas::new(
            socket_path.to_path_buf(),
            dmq_network.to_owned(),
            stop_rx,
            TestLogger::stdout(),
        ));
        dmq_consumer_server.register_receiver(signature_dmq_rx).await.unwrap();
        let message: DmqMsg = compute_fake_msg(b"test", current_function_name).await.into();
        let client = tokio::spawn({
            async move {
                // give the server time to bind the socket, to avoid a refused connection
                tokio::time::sleep(Duration::from_millis(10)).await;

                // client setup
                let mut client = DmqClient::connect(socket_path.clone(), 0).await.unwrap();

                // init local msg notification client
                let client_msg = client.msg_notification();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);

                // client sends a non blocking request to server and waits for a reply from the server
                client_msg.send_request_messages_non_blocking().await.unwrap();
                assert_eq!(
                    *client_msg.state(),
                    localmsgnotification::State::BusyNonBlocking
                );

                let reply = client_msg.recv_next_reply().await.unwrap();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);
                let result = match reply {
                    localmsgnotification::Reply(messages, false) => Ok(messages),
                    _ => Err(anyhow::anyhow!(
                        "Failed to receive non blocking reply from DMQ server"
                    )),
                };

                // stop the consumer server
                stop_tx.send(()).unwrap();

                result
            }
        });
        let message_clone = message.clone();
        let signature_dmq_tx_clone = signature_dmq_tx.clone();
        let recorder = tokio::spawn(async move {
            signature_dmq_tx_clone.send(message_clone.into()).unwrap();
        });

        let (run_res, messages_res, _) =
            tokio::join!(dmq_consumer_server.run(), client, recorder);
        run_res.expect("DMQ consumer server should stop without error");
        let messages_received: Vec<_> = messages_res.unwrap().unwrap();
        assert_eq!(vec![message], messages_received);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_server_blocking_success() {
        let current_function_name = current_function!();
        let (stop_tx, stop_rx) = watch::channel(());
        let (signature_dmq_tx, signature_dmq_rx) = unbounded_channel::<DmqMessage>();
        let socket_path = create_temp_dir(current_function_name).join("node.socket");
        let dmq_network = DmqNetwork::TestNet(0);
        let dmq_consumer_server = Arc::new(DmqConsumerServerPallas::new(
            socket_path.to_path_buf(),
            dmq_network.to_owned(),
            stop_rx,
            TestLogger::stdout(),
        ));
        dmq_consumer_server.register_receiver(signature_dmq_rx).await.unwrap();
        let message: DmqMsg = compute_fake_msg(b"test", current_function_name).await.into();
        let client = tokio::spawn({
            async move {
                // give the server time to bind the socket, to avoid a refused connection
                tokio::time::sleep(Duration::from_millis(10)).await;

                // client setup
                let mut client = DmqClient::connect(socket_path.clone(), 0).await.unwrap();

                // init local msg notification client
                let client_msg = client.msg_notification();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);

                // client sends a blocking request to server and waits for a reply from the server
                client_msg.send_request_messages_blocking().await.unwrap();
                assert_eq!(
                    *client_msg.state(),
                    localmsgnotification::State::BusyBlocking
                );

                let reply = client_msg.recv_next_reply().await.unwrap();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);
                let result = match reply {
                    localmsgnotification::Reply(messages, false) => Ok(messages),
                    _ => Err(anyhow::anyhow!(
                        "Failed to receive blocking reply from DMQ server"
                    )),
                };

                // stop the consumer server
                stop_tx.send(()).unwrap();

                result
            }
        });
        let message_clone = message.clone();
        let signature_dmq_tx_clone = signature_dmq_tx.clone();
        let recorder = tokio::spawn(async move {
            signature_dmq_tx_clone.send(message_clone.into()).unwrap();
        });

        let (run_res, messages_res, _) =
            tokio::join!(dmq_consumer_server.run(), client, recorder);
        run_res.expect("DMQ consumer server should stop without error");
        let messages_received: Vec<_> = messages_res.unwrap().unwrap();
        assert_eq!(vec![message], messages_received);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_server_blocking_blocks_when_no_message_available() {
        // keep the stop sender and the message sender alive for the whole test
        let (_stop_tx, stop_rx) = watch::channel(());
        let (_signature_dmq_tx, signature_dmq_rx) = unbounded_channel::<DmqMessage>();
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let dmq_network = DmqNetwork::TestNet(0);
        let dmq_consumer_server = Arc::new(DmqConsumerServerPallas::new(
            socket_path.to_path_buf(),
            dmq_network.to_owned(),
            stop_rx,
            TestLogger::stdout(),
        ));
        dmq_consumer_server.register_receiver(signature_dmq_rx).await.unwrap();
        let client = tokio::spawn({
            async move {
                // give the server time to bind the socket, to avoid a refused connection
                tokio::time::sleep(Duration::from_millis(10)).await;

                // client setup
                let mut client = DmqClient::connect(socket_path.clone(), 0).await.unwrap();

                // init local msg notification client
                let client_msg = client.msg_notification();
                assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);

                // client sends a blocking request to server and waits for a reply from the server
                client_msg.send_request_messages_blocking().await.unwrap();
                assert_eq!(
                    *client_msg.state(),
                    localmsgnotification::State::BusyBlocking
                );

                let _ = client_msg.recv_next_reply().await;
            }
        });

        let result = tokio::select!(
            _res = sleep(Duration::from_millis(100)) => {Err(anyhow!("Timeout"))},
            _res =  dmq_consumer_server.run()  => {Ok(())},
            _res =  client  => {Ok(())},
        );

        result.expect_err("Should have timed out");
    }
}