1use std::{fmt::Debug, marker::PhantomData, path::PathBuf};
2
3use anyhow::{Context, anyhow};
4use pallas_network::{facades::DmqClient, miniprotocols::localmsgnotification::State};
5use slog::{Logger, debug, error};
6use tokio::sync::{Mutex, MutexGuard};
7
8use mithril_common::{
9 CardanoNetwork, StdResult,
10 crypto_helper::{OpCert, TryFromBytes},
11 entities::PartyId,
12 logging::LoggerExtensions,
13};
14
15use crate::DmqConsumerClient;
16
/// DMQ consumer client built on top of the `pallas_network` [`DmqClient`] facade.
///
/// `M` is the payload type decoded from the body of every received message.
pub struct DmqConsumerClientPallas<M: TryFromBytes + Debug> {
    /// Path of the socket handed to `DmqClient::connect`.
    socket: PathBuf,
    /// Cardano network whose magic id is handed to `DmqClient::connect`.
    network: CardanoNetwork,
    /// Cached client: `None` until a client is created, and again after it has been dropped.
    client: Mutex<Option<DmqClient>>,
    /// Logger scoped to this component.
    logger: Logger,
    /// Ties the payload type `M` to the struct without storing a value of that type.
    phantom: PhantomData<M>,
}
27
28impl<M: TryFromBytes + Debug> DmqConsumerClientPallas<M> {
29 pub fn new(socket: PathBuf, network: CardanoNetwork, logger: Logger) -> Self {
31 Self {
32 socket,
33 network,
34 client: Mutex::new(None),
35 logger: logger.new_with_component_name::<Self>(),
36 phantom: PhantomData,
37 }
38 }
39
40 async fn new_client(&self) -> StdResult<DmqClient> {
42 debug!(
43 self.logger,
44 "Create new DMQ client";
45 "socket" => ?self.socket,
46 "network" => ?self.network
47 );
48 DmqClient::connect(&self.socket, self.network.magic_id())
49 .await
50 .with_context(|| "DmqConsumerClientPallas failed to create a new client")
51 }
52
53 async fn get_client(&self) -> StdResult<MutexGuard<'_, Option<DmqClient>>> {
55 {
56 let client_lock = self.client.lock().await;
58 if client_lock.as_ref().is_some() {
59 return Ok(client_lock);
60 }
61 }
62
63 let mut client_lock = self.client.lock().await;
64 *client_lock = Some(self.new_client().await?);
65
66 Ok(client_lock)
67 }
68
69 async fn drop_client(&self) -> StdResult<()> {
71 debug!(
72 self.logger,
73 "Drop existing DMQ client";
74 "socket" => ?self.socket,
75 "network" => ?self.network
76 );
77 let mut client_lock = self.client.try_lock()?;
78 if let Some(client) = client_lock.take() {
79 client.abort().await;
80 }
81
82 Ok(())
83 }
84
85 #[cfg(test)]
86 async fn has_client(&self) -> bool {
88 let client_lock = self.client.lock().await;
89
90 client_lock.as_ref().is_some()
91 }
92
93 async fn consume_messages_internal(&self) -> StdResult<Vec<(M, PartyId)>> {
94 debug!(self.logger, "Waiting for messages from DMQ...");
95 let mut client_guard = self.get_client().await?;
96 let client = client_guard.as_mut().ok_or(anyhow!("DMQ client does not exist"))?;
97 if *client.msg_notification().state() == State::Idle {
98 client
99 .msg_notification()
100 .send_request_messages_blocking()
101 .await
102 .with_context(|| "Failed to request notifications from DMQ server: {}")?;
103 }
104
105 let reply = client
106 .msg_notification()
107 .recv_next_reply()
108 .await
109 .with_context(|| "Failed to receive notifications from DMQ server")?;
110 debug!(self.logger, "Received single signatures from DMQ"; "messages" => ?reply);
111
112 reply
113 .0
114 .into_iter()
115 .map(|dmq_message| {
116 let opcert = OpCert::try_from_bytes(&dmq_message.operational_certificate)
117 .with_context(|| "Failed to parse operational certificate")?;
118 let party_id = opcert.compute_protocol_party_id()?;
119 let payload = M::try_from_bytes(&dmq_message.msg_body)
120 .with_context(|| "Failed to parse DMQ message body")?;
121
122 Ok((payload, party_id))
123 })
124 .collect::<StdResult<Vec<_>>>()
125 .with_context(|| "Failed to parse DMQ messages")
126 }
127}
128
129#[async_trait::async_trait]
130impl<M: TryFromBytes + Debug + Sync + Send> DmqConsumerClient<M> for DmqConsumerClientPallas<M> {
131 async fn consume_messages(&self) -> StdResult<Vec<(M, PartyId)>> {
132 let messages = self.consume_messages_internal().await;
133 if messages.is_err() {
134 self.drop_client().await?;
135 }
136
137 messages
138 }
139}
140
141impl<M: TryFromBytes + Debug> Drop for DmqConsumerClientPallas<M> {
142 fn drop(&mut self) {
143 tokio::task::block_in_place(|| {
144 tokio::runtime::Handle::current().block_on(async {
145 if let Err(e) = self.drop_client().await {
146 error!(self.logger, "Failed to drop DMQ consumer client: {}", e);
147 }
148 });
149 });
150 }
151}
152
153#[cfg(all(test, unix))]
154mod tests {
155
156 use std::{fs, future, time::Duration, vec};
157
158 use mithril_common::{crypto_helper::TryToBytes, current_function, test::TempDir};
159 use pallas_network::{
160 facades::DmqServer,
161 miniprotocols::{localmsgnotification, localmsgsubmission::DmqMsg},
162 };
163 use tokio::{net::UnixListener, task::JoinHandle, time::sleep};
164
165 use crate::test::{TestLogger, payload::DmqMessageTestPayload};
166
167 use super::*;
168
    /// Returns the path given by `TempDir::create_with_short_path` for the `dmq_consumer`
    /// group and the given folder name (presumably the directory is created on disk).
    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("dmq_consumer", folder_name)
    }
172
173 fn fake_msgs() -> Vec<DmqMsg> {
174 vec![
175 DmqMsg {
176 msg_id: vec![0, 1],
177 msg_body: DmqMessageTestPayload::new(b"msg_1").to_bytes_vec().unwrap(),
178 block_number: 10,
179 ttl: 100,
180 kes_signature: vec![0, 1, 2, 3],
181 operational_certificate: vec![
182 130, 132, 88, 32, 230, 80, 215, 83, 21, 9, 187, 108, 255, 215, 153, 140, 40,
183 198, 142, 78, 200, 250, 98, 26, 9, 82, 32, 110, 161, 30, 176, 63, 205, 125,
184 203, 41, 0, 0, 88, 64, 212, 171, 206, 39, 218, 5, 255, 3, 193, 52, 44, 198,
185 171, 83, 19, 80, 114, 225, 186, 191, 156, 192, 84, 146, 245, 159, 31, 240, 9,
186 247, 4, 87, 170, 168, 98, 199, 21, 139, 19, 190, 12, 251, 65, 215, 169, 26, 86,
187 37, 137, 188, 17, 14, 178, 205, 175, 93, 39, 86, 4, 138, 187, 234, 95, 5, 88,
188 32, 32, 253, 186, 201, 177, 11, 117, 135, 187, 167, 181, 188, 22, 59, 206, 105,
189 231, 150, 215, 30, 78, 212, 76, 16, 252, 180, 72, 134, 137, 247, 161, 68,
190 ],
191 kes_period: 10,
192 },
193 DmqMsg {
194 msg_id: vec![1, 2],
195 msg_body: DmqMessageTestPayload::new(b"msg_2").to_bytes_vec().unwrap(),
196 block_number: 11,
197 ttl: 100,
198 kes_signature: vec![1, 2, 3, 4],
199 operational_certificate: vec![
200 130, 132, 88, 32, 230, 80, 215, 83, 21, 9, 187, 108, 255, 215, 153, 140, 40,
201 198, 142, 78, 200, 250, 98, 26, 9, 82, 32, 110, 161, 30, 176, 63, 205, 125,
202 203, 41, 0, 0, 88, 64, 132, 4, 199, 39, 190, 173, 88, 102, 121, 117, 55, 62,
203 39, 189, 113, 96, 175, 24, 171, 240, 74, 42, 139, 202, 128, 185, 44, 130, 209,
204 77, 191, 122, 196, 224, 33, 158, 187, 156, 203, 190, 173, 150, 247, 87, 172,
205 58, 153, 185, 157, 87, 128, 14, 187, 107, 187, 215, 105, 195, 107, 135, 172,
206 43, 173, 9, 88, 32, 77, 75, 24, 6, 47, 133, 2, 89, 141, 224, 69, 202, 123, 105,
207 240, 103, 245, 159, 147, 177, 110, 58, 248, 115, 58, 152, 138, 220, 35, 65,
208 245, 200,
209 ],
210 kes_period: 11,
211 },
212 ]
213 }
214
215 fn setup_dmq_server(
216 socket_path: PathBuf,
217 reply_messages: Vec<DmqMsg>,
218 ) -> JoinHandle<DmqServer> {
219 tokio::spawn({
220 async move {
221 if socket_path.exists() {
223 fs::remove_file(socket_path.clone()).unwrap();
224 }
225 let listener = UnixListener::bind(socket_path).unwrap();
226 let mut server = pallas_network::facades::DmqServer::accept(&listener, 0)
227 .await
228 .unwrap();
229
230 let server_msg = server.msg_notification();
232
233 let request = server_msg.recv_next_request().await.unwrap();
235 assert_eq!(request, localmsgnotification::Request::Blocking);
236
237 if !reply_messages.is_empty() {
238 server_msg.send_reply_messages_blocking(reply_messages).await.unwrap();
240 assert_eq!(*server_msg.state(), localmsgnotification::State::Idle);
241 } else {
242 future::pending().await
244 }
245
246 server
247 }
248 })
249 }
250
251 #[tokio::test(flavor = "multi_thread")]
252 async fn pallas_dmq_consumer_client_succeeds_when_messages_are_available() {
253 let socket_path = create_temp_dir(current_function!()).join("node.socket");
254 let reply_messages = fake_msgs();
255 let server = setup_dmq_server(socket_path.clone(), reply_messages);
256 let client = tokio::spawn(async move {
257 tokio::time::sleep(Duration::from_millis(10)).await;
259
260 let consumer = DmqConsumerClientPallas::new(
261 socket_path,
262 CardanoNetwork::TestNet(0),
263 TestLogger::stdout(),
264 );
265
266 consumer.consume_messages().await.unwrap()
267 });
268
269 let (_, client_res) = tokio::join!(server, client);
270 let messages = client_res.unwrap();
271
272 assert_eq!(
273 vec![
274 (
275 DmqMessageTestPayload::new(b"msg_1"),
276 "pool1mxyec46067n3querj9cxkk0g0zlag93pf3ya9vuyr3wgkq2e6t7".to_string()
277 ),
278 (
279 DmqMessageTestPayload::new(b"msg_2"),
280 "pool17sln0evyk5tfj6zh2qrlk9vttgy6264sfe2fkec5mheasnlx3yd".to_string()
281 ),
282 ],
283 messages
284 );
285 }
286
287 #[tokio::test]
288 async fn pallas_dmq_consumer_client_blocks_when_no_message_available() {
289 let socket_path = create_temp_dir(current_function!()).join("node.socket");
290 let reply_messages = vec![];
291 let server = setup_dmq_server(socket_path.clone(), reply_messages);
292 let client = tokio::spawn(async move {
293 tokio::time::sleep(Duration::from_millis(10)).await;
295
296 let consumer = DmqConsumerClientPallas::<DmqMessageTestPayload>::new(
297 socket_path,
298 CardanoNetwork::TestNet(0),
299 TestLogger::stdout(),
300 );
301
302 consumer.consume_messages().await.unwrap();
303 });
304
305 let result = tokio::select!(
306 _res = sleep(Duration::from_millis(100)) => {Err(anyhow!("Timeout"))},
307 _res = client => {Ok(())},
308 _res = server => {Ok(())},
309 );
310
311 result.expect_err("Should have timed out");
312 }
313
314 #[tokio::test(flavor = "multi_thread")]
315 async fn pallas_dmq_consumer_client_is_dropped_when_returning_error() {
316 let socket_path = create_temp_dir(current_function!()).join("node.socket");
317 let reply_messages = fake_msgs();
318 let server = setup_dmq_server(socket_path.clone(), reply_messages);
319 let client = tokio::spawn(async move {
320 tokio::time::sleep(Duration::from_millis(10)).await;
322
323 let consumer = DmqConsumerClientPallas::<DmqMessageTestPayload>::new(
324 socket_path,
325 CardanoNetwork::TestNet(0),
326 TestLogger::stdout(),
327 );
328
329 consumer.consume_messages().await.unwrap();
330
331 consumer
332 });
333
334 let (server_res, client_res) = tokio::join!(server, client);
335 let consumer = client_res.unwrap();
336 let server = server_res.unwrap();
337 server.abort().await;
338
339 let client = tokio::spawn(async move {
340 assert!(consumer.has_client().await, "Client should exist");
341
342 consumer
343 .consume_messages()
344 .await
345 .expect_err("Consuming messages should fail");
346
347 assert!(
348 !consumer.has_client().await,
349 "Client should have been dropped after error"
350 );
351
352 consumer
353 });
354 client.await.unwrap();
355 }
356}