mithril_dmq/consumer/server/
pallas.rs1use std::{fs, path::PathBuf};
2
3use anyhow::{Context, anyhow};
4use pallas_network::{facades::DmqServer, miniprotocols::localmsgnotification::Request};
5use tokio::{
6 join,
7 net::UnixListener,
8 select,
9 sync::{Mutex, MutexGuard, mpsc::UnboundedReceiver, watch::Receiver},
10};
11
12use slog::{Logger, debug, error, info, warn};
13
14use mithril_common::{StdResult, logging::LoggerExtensions};
15
16use crate::{DmqConsumerServer, DmqMessage, DmqNetwork};
17
18use super::queue::MessageQueue;
19
20pub struct DmqConsumerServerPallas {
22 socket: PathBuf,
23 network: DmqNetwork,
24 server: Mutex<Option<DmqServer>>,
25 messages_receiver: Mutex<Option<UnboundedReceiver<DmqMessage>>>,
26 messages_buffer: MessageQueue,
27 stop_rx: Receiver<()>,
28 logger: Logger,
29}
30
31impl DmqConsumerServerPallas {
32 pub fn new(
34 socket: PathBuf,
35 network: DmqNetwork,
36 stop_rx: Receiver<()>,
37 logger: Logger,
38 ) -> Self {
39 Self {
40 socket,
41 network,
42 server: Mutex::new(None),
43 messages_receiver: Mutex::new(None),
44 messages_buffer: MessageQueue::default(),
45 stop_rx,
46 logger: logger.new_with_component_name::<Self>(),
47 }
48 }
49
50 async fn new_server(&self) -> StdResult<DmqServer> {
52 info!(
53 self.logger,
54 "Creating a new DMQ consumer server";
55 "socket" => ?self.socket,
56 "network" => ?self.network
57 );
58 let magic = self.network.magic_id();
59 if self.socket.exists() {
60 fs::remove_file(self.socket.clone())?;
61 }
62 let listener = UnixListener::bind(&self.socket).with_context(|| {
63 format!(
64 "DmqConsumerServerPallas failed to bind Unix socket at {}",
65 self.socket.display()
66 )
67 })?;
68
69 DmqServer::accept(&listener, magic)
70 .await
71 .with_context(|| "DmqConsumerServerPallas failed to create a new server")
72 }
73
74 async fn get_server(&self) -> StdResult<MutexGuard<'_, Option<DmqServer>>> {
76 {
77 let server_lock = self.server.lock().await;
79 if server_lock.as_ref().is_some() {
80 return Ok(server_lock);
81 }
82 }
83
84 let mut server_lock = self.server.lock().await;
85 *server_lock = Some(self.new_server().await?);
86
87 Ok(server_lock)
88 }
89
90 async fn drop_server(&self) -> StdResult<()> {
92 debug!(
93 self.logger,
94 "Drop existing DMQ server";
95 "socket" => ?self.socket,
96 "network" => ?self.network
97 );
98 let mut server_lock = self.server.try_lock()?;
99 if let Some(server) = server_lock.take() {
100 server.abort().await;
101 }
102
103 Ok(())
104 }
105
106 pub async fn register_receiver(
108 &self,
109 receiver: UnboundedReceiver<DmqMessage>,
110 ) -> StdResult<()> {
111 debug!(self.logger, "Register message receiver for DMQ messages");
112 let mut receiver_guard = self.messages_receiver.lock().await;
113 *receiver_guard = Some(receiver);
114
115 Ok(())
116 }
117
118 async fn receive_incoming_messages(&self) -> StdResult<()> {
120 info!(
121 self.logger,
122 "Receive incoming messages into DMQ consumer server...";
123 "socket" => ?self.socket,
124 "network" => ?self.network
125 );
126
127 let mut stop_rx = self.stop_rx.clone();
128 let mut receiver = self.messages_receiver.lock().await;
129 match *receiver {
130 Some(ref mut receiver) => loop {
131 select! {
132 _ = stop_rx.changed() => {
133 warn!(self.logger, "Stopping DMQ consumer server...");
134
135 return Ok(());
136 }
137 message = receiver.recv() => {
138 if let Some(message) = message {
139 debug!(self.logger, "Received a message from the DMQ network"; "message" => ?message);
140 self.messages_buffer.enqueue(message).await;
141 } else {
142 warn!(self.logger, "DMQ message receiver channel closed");
143 return Ok(());
144 }
145
146 }
147 }
148 },
149 None => Err(anyhow!("DMQ message receiver is not registered")),
150 }
151 }
152
153 async fn serve_incoming_messages(&self) -> StdResult<()> {
155 info!(
156 self.logger,
157 "Serve incoming messages from DMQ consumer server...";
158 "socket" => ?self.socket,
159 "network" => ?self.network
160 );
161
162 let mut stop_rx = self.stop_rx.clone();
163 loop {
164 select! {
165 _ = stop_rx.changed() => {
166 warn!(self.logger, "Stopping DMQ consumer server...");
167
168 return Ok(());
169 }
170 res = self.process_message() => {
171 match res {
172 Ok(_) => {
173 debug!(self.logger, "Processed a message successfully");
174 }
175 Err(err) => {
176 error!(self.logger, "Failed to process message"; "error" => ?err);
177 if let Err(drop_err) = self.drop_server().await {
178 error!(self.logger, "Failed to drop DMQ consumer server"; "error" => ?drop_err);
179 }
180 }
181 }
182 }
183 }
184 }
185 }
186}
187
188#[async_trait::async_trait]
189impl DmqConsumerServer for DmqConsumerServerPallas {
190 async fn process_message(&self) -> StdResult<()> {
191 debug!(
192 self.logger,
193 "Waiting for message received from the DMQ network"
194 );
195
196 let mut server_guard = self.get_server().await?;
197 let server = server_guard.as_mut().with_context(|| "DMQ server does not exist")?;
198 let request = server
199 .msg_notification()
200 .recv_next_request()
201 .await
202 .with_context(|| "Failed to receive next DMQ message")?;
203
204 match request {
205 Request::Blocking => {
206 debug!(
207 self.logger,
208 "Blocking notification of messages received from the DMQ network"
209 );
210 let reply_messages = self.messages_buffer.dequeue_blocking(None).await;
211 let reply_messages =
212 reply_messages.into_iter().map(|msg| msg.into()).collect::<Vec<_>>();
213 server
214 .msg_notification()
215 .send_reply_messages_blocking(reply_messages.clone())
216 .await?;
217 debug!(
218 self.logger,
219 "Messages replied to the DMQ notification client: {:?}", reply_messages
220 );
221 }
222 Request::NonBlocking => {
223 debug!(
224 self.logger,
225 "Non blocking notification of messages received from the DMQ network"
226 );
227 let reply_messages = self.messages_buffer.dequeue_non_blocking(None).await;
228 let reply_messages =
229 reply_messages.into_iter().map(|msg| msg.into()).collect::<Vec<_>>();
230 let has_more = !self.messages_buffer.is_empty().await;
231 server
232 .msg_notification()
233 .send_reply_messages_non_blocking(reply_messages.clone(), has_more)
234 .await?;
235 debug!(
236 self.logger,
237 "Messages replied to the DMQ notification client: {:?}", reply_messages
238 );
239 }
240 };
241
242 Ok(())
243 }
244
245 async fn run(&self) -> StdResult<()> {
246 info!(
247 self.logger,
248 "Starting DMQ consumer server";
249 "socket" => ?self.socket,
250 "network" => ?self.network
251 );
252
253 let (receive_result, serve_result) = join!(
254 self.receive_incoming_messages(),
255 self.serve_incoming_messages()
256 );
257 receive_result?;
258 serve_result?;
259
260 Ok(())
261 }
262}
263
264impl Drop for DmqConsumerServerPallas {
265 fn drop(&mut self) {
266 tokio::task::block_in_place(|| {
267 tokio::runtime::Handle::current().block_on(async {
268 if let Err(e) = self.drop_server().await {
269 error!(self.logger, "Failed to drop DMQ consumer server: {}", e);
270 }
271 });
272 });
273 }
274}
275
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use pallas_network::{
        facades::DmqClient,
        miniprotocols::{localmsgnotification, localmsgsubmission::DmqMsg},
    };
    use tokio::sync::{mpsc::unbounded_channel, watch};
    use tokio::time::sleep;

    use mithril_common::{current_function, test::TempDir};

    use crate::{test::fake_message::compute_fake_msg, test_tools::TestLogger};

    use super::*;

    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("dmq_consumer_server", folder_name)
    }

    /// Creates a server on `node.socket` inside a per-test temporary directory, with
    /// `messages_rx` registered as its message source.
    ///
    /// Returns the server together with its socket path.
    async fn create_server(
        folder_name: &str,
        stop_rx: Receiver<()>,
        messages_rx: UnboundedReceiver<DmqMessage>,
    ) -> (Arc<DmqConsumerServerPallas>, PathBuf) {
        let socket_path = create_temp_dir(folder_name).join("node.socket");
        let server = Arc::new(DmqConsumerServerPallas::new(
            socket_path.clone(),
            DmqNetwork::TestNet(0),
            stop_rx,
            TestLogger::stdout(),
        ));
        server.register_receiver(messages_rx).await.unwrap();

        (server, socket_path)
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_server_non_blocking_success() {
        let current_function_name = current_function!();
        let (stop_tx, stop_rx) = watch::channel(());
        let (messages_tx, messages_rx) = unbounded_channel::<DmqMessage>();
        let (dmq_consumer_server, socket_path) =
            create_server(current_function_name, stop_rx, messages_rx).await;
        let message: DmqMsg = compute_fake_msg(b"test", current_function_name).await.into();
        let client = tokio::spawn(async move {
            // Give the server time to bind the socket before connecting.
            sleep(Duration::from_millis(10)).await;

            let mut client = DmqClient::connect(socket_path, 0).await.unwrap();
            let client_msg = client.msg_notification();
            assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);

            client_msg.send_request_messages_non_blocking().await.unwrap();
            assert_eq!(
                *client_msg.state(),
                localmsgnotification::State::BusyNonBlocking
            );

            let reply = client_msg.recv_next_reply().await.unwrap();
            assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);
            let result = match reply {
                localmsgnotification::Reply(messages, false) => Ok(messages),
                _ => Err(anyhow!(
                    "Failed to receive non blocking reply from DMQ server"
                )),
            };

            stop_tx.send(()).unwrap();

            result
        });
        let recorder = tokio::spawn({
            let messages_tx = messages_tx.clone();
            let message = message.clone();
            async move {
                messages_tx.send(message.into()).unwrap();
            }
        });

        let (_, messages_res, _) = tokio::join!(dmq_consumer_server.run(), client, recorder);
        let messages_received: Vec<_> = messages_res.unwrap().unwrap();
        assert_eq!(vec![message], messages_received);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_server_blocking_success() {
        let current_function_name = current_function!();
        let (stop_tx, stop_rx) = watch::channel(());
        let (messages_tx, messages_rx) = unbounded_channel::<DmqMessage>();
        let (dmq_consumer_server, socket_path) =
            create_server(current_function_name, stop_rx, messages_rx).await;
        let message: DmqMsg = compute_fake_msg(b"test", current_function_name).await.into();
        let client = tokio::spawn(async move {
            // Give the server time to bind the socket before connecting.
            sleep(Duration::from_millis(10)).await;

            let mut client = DmqClient::connect(socket_path, 0).await.unwrap();
            let client_msg = client.msg_notification();
            assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);

            client_msg.send_request_messages_blocking().await.unwrap();
            assert_eq!(
                *client_msg.state(),
                localmsgnotification::State::BusyBlocking
            );

            let reply = client_msg.recv_next_reply().await.unwrap();
            assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);
            let result = match reply {
                localmsgnotification::Reply(messages, false) => Ok(messages),
                _ => Err(anyhow!("Failed to receive blocking reply from DMQ server")),
            };

            stop_tx.send(()).unwrap();

            result
        });
        let recorder = tokio::spawn({
            let messages_tx = messages_tx.clone();
            let message = message.clone();
            async move {
                messages_tx.send(message.into()).unwrap();
            }
        });

        let (_, messages_res, _) = tokio::join!(dmq_consumer_server.run(), client, recorder);
        let messages_received: Vec<_> = messages_res.unwrap().unwrap();
        assert_eq!(vec![message], messages_received);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_server_blocking_blocks_when_no_message_available() {
        // Both senders are bound (not `_`) so they stay alive: dropping them would close the
        // stop signal / message channel and end the server loops early.
        let (_stop_tx, stop_rx) = watch::channel(());
        let (_messages_tx, messages_rx) = unbounded_channel::<DmqMessage>();
        let (dmq_consumer_server, socket_path) =
            create_server(current_function!(), stop_rx, messages_rx).await;
        let client = tokio::spawn(async move {
            // Give the server time to bind the socket before connecting.
            sleep(Duration::from_millis(10)).await;

            let mut client = DmqClient::connect(socket_path, 0).await.unwrap();
            let client_msg = client.msg_notification();
            assert_eq!(*client_msg.state(), localmsgnotification::State::Idle);

            client_msg.send_request_messages_blocking().await.unwrap();
            assert_eq!(
                *client_msg.state(),
                localmsgnotification::State::BusyBlocking
            );

            let _ = client_msg.recv_next_reply().await;
        });

        let result = tokio::select!(
            _res = sleep(Duration::from_millis(100)) => { Err(anyhow!("Timeout")) },
            _res = dmq_consumer_server.run() => { Ok(()) },
            _res = client => { Ok(()) },
        );

        result.expect_err("Should have timed out");
    }
}