1use std::{fmt::Debug, marker::PhantomData, path::PathBuf};
2
3use anyhow::{Context, anyhow};
4use pallas_network::facades::DmqClient;
5use slog::{Logger, debug, error};
6use tokio::sync::{Mutex, MutexGuard};
7
8use mithril_common::{
9 CardanoNetwork, StdResult,
10 crypto_helper::{OpCert, TryFromBytes},
11 entities::PartyId,
12 logging::LoggerExtensions,
13};
14
15use crate::DmqConsumer;
16
17pub struct DmqConsumerPallas<M: TryFromBytes + Debug> {
21 socket: PathBuf,
22 network: CardanoNetwork,
23 client: Mutex<Option<DmqClient>>,
24 logger: Logger,
25 phantom: PhantomData<M>,
26}
27
28impl<M: TryFromBytes + Debug> DmqConsumerPallas<M> {
29 pub fn new(socket: PathBuf, network: CardanoNetwork, logger: Logger) -> Self {
31 Self {
32 socket,
33 network,
34 client: Mutex::new(None),
35 logger: logger.new_with_component_name::<Self>(),
36 phantom: PhantomData,
37 }
38 }
39
40 async fn new_client(&self) -> StdResult<DmqClient> {
42 debug!(
43 self.logger,
44 "Create new DMQ client";
45 "socket" => ?self.socket,
46 "network" => ?self.network
47 );
48 DmqClient::connect(&self.socket, self.network.code())
49 .await
50 .with_context(|| "DmqConsumerPallas failed to create a new client")
51 }
52
53 async fn get_client(&self) -> StdResult<MutexGuard<Option<DmqClient>>> {
55 {
56 let client_lock = self.client.lock().await;
58 if client_lock.as_ref().is_some() {
59 return Ok(client_lock);
60 }
61 }
62
63 let mut client_lock = self.client.lock().await;
64 *client_lock = Some(self.new_client().await?);
65
66 Ok(client_lock)
67 }
68
69 async fn drop_client(&self) -> StdResult<()> {
71 debug!(
72 self.logger,
73 "Drop existing DMQ client";
74 "socket" => ?self.socket,
75 "network" => ?self.network
76 );
77 let mut client_lock = self.client.lock().await;
78 if let Some(client) = client_lock.take() {
79 client.abort().await;
80 }
81
82 Ok(())
83 }
84
85 #[cfg(test)]
86 async fn has_client(&self) -> bool {
88 let client_lock = self.client.lock().await;
89
90 client_lock.as_ref().is_some()
91 }
92
93 async fn consume_messages_internal(&self) -> StdResult<Vec<(M, PartyId)>> {
94 debug!(self.logger, "Waiting for messages from DMQ...");
95 let mut client_guard = self.get_client().await?;
96 let client = client_guard.as_mut().ok_or(anyhow!("DMQ client does not exist"))?;
97 client
98 .msg_notification()
99 .send_request_messages_blocking()
100 .await
101 .with_context(|| "Failed to request notifications from DMQ server: {}")?;
102
103 let reply = client
104 .msg_notification()
105 .recv_next_reply()
106 .await
107 .with_context(|| "Failed to receive notifications from DMQ server")?;
108 debug!(self.logger, "Received single signatures from DMQ"; "messages" => ?reply);
109 if let Err(e) = client.msg_notification().send_done().await {
110 error!(self.logger, "Failed to send Done"; "error" => ?e);
111 }
112
113 reply
114 .0
115 .into_iter()
116 .map(|dmq_message| {
117 let opcert = OpCert::try_from_bytes(&dmq_message.operational_certificate)
118 .with_context(|| "Failed to parse operational certificate")?;
119 let party_id = opcert.compute_protocol_party_id()?;
120 let payload = M::try_from_bytes(&dmq_message.msg_body)
121 .with_context(|| "Failed to parse DMQ message body")?;
122
123 Ok((payload, party_id))
124 })
125 .collect::<StdResult<Vec<_>>>()
126 .with_context(|| "Failed to parse DMQ messages")
127 }
128}
129
130#[async_trait::async_trait]
131impl<M: TryFromBytes + Debug + Sync + Send> DmqConsumer<M> for DmqConsumerPallas<M> {
132 async fn consume_messages(&self) -> StdResult<Vec<(M, PartyId)>> {
133 let messages = self.consume_messages_internal().await;
134 if messages.is_err() {
135 self.drop_client().await?;
136 }
137
138 messages
139 }
140}
141
#[cfg(all(test, unix))]
mod tests {
    use std::{fs, future, time::Duration};

    use mithril_common::{crypto_helper::TryToBytes, current_function, test_utils::TempDir};
    use pallas_network::{
        facades::DmqServer,
        miniprotocols::{localmsgnotification, localmsgsubmission::DmqMsg},
    };
    use tokio::{net::UnixListener, task::JoinHandle, time::sleep};

    use crate::{test::payload::DmqMessageTestPayload, test_tools::TestLogger};

    use super::*;

    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("dmq_consumer", folder_name)
    }

    /// Two DMQ messages whose operational certificates map to the pool ids asserted in
    /// `pallas_dmq_consumer_publisher_succeeds_when_messages_are_available`.
    fn fake_msgs() -> Vec<DmqMsg> {
        vec![
            DmqMsg {
                msg_id: vec![0, 1],
                msg_body: DmqMessageTestPayload::new(b"msg_1").to_bytes_vec().unwrap(),
                block_number: 10,
                ttl: 100,
                kes_signature: vec![0, 1, 2, 3],
                operational_certificate: vec![
                    130, 132, 88, 32, 230, 80, 215, 83, 21, 9, 187, 108, 255, 215, 153, 140, 40,
                    198, 142, 78, 200, 250, 98, 26, 9, 82, 32, 110, 161, 30, 176, 63, 205, 125,
                    203, 41, 0, 0, 88, 64, 212, 171, 206, 39, 218, 5, 255, 3, 193, 52, 44, 198,
                    171, 83, 19, 80, 114, 225, 186, 191, 156, 192, 84, 146, 245, 159, 31, 240, 9,
                    247, 4, 87, 170, 168, 98, 199, 21, 139, 19, 190, 12, 251, 65, 215, 169, 26, 86,
                    37, 137, 188, 17, 14, 178, 205, 175, 93, 39, 86, 4, 138, 187, 234, 95, 5, 88,
                    32, 32, 253, 186, 201, 177, 11, 117, 135, 187, 167, 181, 188, 22, 59, 206, 105,
                    231, 150, 215, 30, 78, 212, 76, 16, 252, 180, 72, 134, 137, 247, 161, 68,
                ],
                kes_period: 10,
            },
            DmqMsg {
                msg_id: vec![1, 2],
                msg_body: DmqMessageTestPayload::new(b"msg_2").to_bytes_vec().unwrap(),
                block_number: 11,
                ttl: 100,
                kes_signature: vec![1, 2, 3, 4],
                operational_certificate: vec![
                    130, 132, 88, 32, 230, 80, 215, 83, 21, 9, 187, 108, 255, 215, 153, 140, 40,
                    198, 142, 78, 200, 250, 98, 26, 9, 82, 32, 110, 161, 30, 176, 63, 205, 125,
                    203, 41, 0, 0, 88, 64, 132, 4, 199, 39, 190, 173, 88, 102, 121, 117, 55, 62,
                    39, 189, 113, 96, 175, 24, 171, 240, 74, 42, 139, 202, 128, 185, 44, 130, 209,
                    77, 191, 122, 196, 224, 33, 158, 187, 156, 203, 190, 173, 150, 247, 87, 172,
                    58, 153, 185, 157, 87, 128, 14, 187, 107, 187, 215, 105, 195, 107, 135, 172,
                    43, 173, 9, 88, 32, 77, 75, 24, 6, 47, 133, 2, 89, 141, 224, 69, 202, 123, 105,
                    240, 103, 245, 159, 147, 177, 110, 58, 248, 115, 58, 152, 138, 220, 35, 65,
                    245, 200,
                ],
                kes_period: 11,
            },
        ]
    }

    /// Spawns a fake DMQ server on `socket_path`.
    ///
    /// The server accepts a single client, expects one blocking request and answers it with
    /// `reply_messages`. It then expects the client to send `Done`.
    ///
    /// When `reply_messages` is empty, the server never replies, which lets tests check that
    /// the consumer blocks. The spawned task panics if the exchange deviates from the expected
    /// sequence, so callers must check the `JoinHandle` result to surface those failures.
    fn setup_dmq_server(
        socket_path: PathBuf,
        reply_messages: Vec<DmqMsg>,
    ) -> JoinHandle<DmqServer> {
        tokio::spawn(async move {
            // Remove a socket file left over by a previous run: binding a Unix socket fails
            // when the path already exists.
            if socket_path.exists() {
                fs::remove_file(&socket_path).unwrap();
            }
            let listener = UnixListener::bind(&socket_path).unwrap();
            let mut server = DmqServer::accept(&listener, 0).await.unwrap();

            let server_msg = server.msg_notification();

            let request = server_msg.recv_next_request().await.unwrap();
            assert_eq!(request, localmsgnotification::Request::Blocking);

            if !reply_messages.is_empty() {
                server_msg.send_reply_messages_blocking(reply_messages).await.unwrap();
                assert_eq!(*server_msg.state(), localmsgnotification::State::Idle);

                server_msg.recv_done().await.unwrap();
                assert_eq!(*server_msg.state(), localmsgnotification::State::Done);
            } else {
                // Never reply: the consumer is expected to stay blocked.
                future::pending().await
            }

            server
        })
    }

    #[tokio::test]
    async fn pallas_dmq_consumer_publisher_succeeds_when_messages_are_available() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = fake_msgs();
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            let consumer = DmqConsumerPallas::new(
                socket_path,
                CardanoNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            consumer.consume_messages().await.unwrap()
        });

        let (server_res, client_res) = tokio::join!(server, client);
        // Surface panics from the server task. Its protocol assertions (e.g. that `Done` was
        // received) would otherwise be silently ignored.
        let _server = server_res.expect("DMQ server task should complete successfully");
        let messages = client_res.unwrap();

        assert_eq!(
            vec![
                (
                    DmqMessageTestPayload::new(b"msg_1"),
                    "pool1mxyec46067n3querj9cxkk0g0zlag93pf3ya9vuyr3wgkq2e6t7".to_string()
                ),
                (
                    DmqMessageTestPayload::new(b"msg_2"),
                    "pool17sln0evyk5tfj6zh2qrlk9vttgy6264sfe2fkec5mheasnlx3yd".to_string()
                ),
            ],
            messages
        );
    }

    #[tokio::test]
    async fn pallas_dmq_consumer_publisher_blocks_when_no_message_available() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = vec![];
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            let consumer = DmqConsumerPallas::<DmqMessageTestPayload>::new(
                socket_path,
                CardanoNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            consumer.consume_messages().await.unwrap();
        });

        let result = tokio::select!(
            _res = sleep(Duration::from_millis(100)) => {Err(anyhow!("Timeout"))},
            _res = client => {Ok(())},
            _res = server => {Ok(())},
        );

        result.expect_err("Should have timed out");
    }

    #[tokio::test]
    async fn pallas_dmq_consumer_client_is_dropped_when_returning_error() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = fake_msgs();
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            let consumer = DmqConsumerPallas::<DmqMessageTestPayload>::new(
                socket_path,
                CardanoNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            consumer.consume_messages().await.unwrap();

            consumer
        });

        let (server_res, client_res) = tokio::join!(server, client);
        let consumer = client_res.unwrap();
        let server = server_res.unwrap();
        // Abort the server so that the next consumption attempt fails.
        server.abort().await;

        let client = tokio::spawn(async move {
            assert!(consumer.has_client().await, "Client should exist");

            consumer
                .consume_messages()
                .await
                .expect_err("Consuming messages should fail");

            assert!(
                !consumer.has_client().await,
                "Client should have been dropped after error"
            );

            consumer
        });
        client.await.unwrap();
    }
}