1use std::{fmt::Debug, marker::PhantomData, path::PathBuf};
2
3use anyhow::Context;
4use pallas_network::{facades::DmqClient, miniprotocols::localmsgnotification::State};
5use slog::{Logger, debug, error, trace};
6use tokio::sync::{Mutex, MutexGuard};
7
8use mithril_common::{
9 StdResult,
10 crypto_helper::{
11 KesPeriod, OpCert, OpCertWithoutColdVerificationKey, TryFromBytes,
12 ed25519::Ed25519VerificationKey,
13 },
14 entities::PartyId,
15 logging::LoggerExtensions,
16};
17
18use crate::{DmqConsumerClient, model::DmqNetwork};
19
/// DMQ consumer client backed by the Pallas network stack (`pallas_network::facades::DmqClient`).
///
/// Message bodies received from the DMQ node are decoded into `M`. The underlying connection is
/// not opened at construction: it is created on first use, kept behind an async mutex so it can
/// be reused by subsequent calls, and dropped after a consumption error so that the next call
/// reconnects.
pub struct DmqConsumerClientPallas<M: TryFromBytes + Debug> {
    /// Path of the DMQ node socket to connect to.
    socket: PathBuf,
    /// DMQ network whose magic id is used when connecting.
    network: DmqNetwork,
    /// Current connection, `None` until first needed or after it has been dropped.
    client: Mutex<Option<DmqClient>>,
    logger: Logger,
    /// Ties the consumer to its decoded message type without storing a value of it.
    phantom: PhantomData<M>,
}
30
31impl<M: TryFromBytes + Debug> DmqConsumerClientPallas<M> {
32 pub fn new(socket: PathBuf, network: DmqNetwork, logger: Logger) -> Self {
34 Self {
35 socket,
36 network,
37 client: Mutex::new(None),
38 logger: logger.new_with_component_name::<Self>(),
39 phantom: PhantomData,
40 }
41 }
42
43 async fn new_client(&self) -> StdResult<DmqClient> {
45 debug!(
46 self.logger,
47 "Create new DMQ client";
48 "socket" => ?self.socket,
49 "network" => ?self.network
50 );
51 DmqClient::connect(&self.socket, self.network.magic_id())
52 .await
53 .with_context(|| "DmqConsumerClientPallas failed to create a new client")
54 }
55
56 async fn get_client(&self) -> StdResult<MutexGuard<'_, Option<DmqClient>>> {
58 {
59 let client_lock = self.client.lock().await;
61 if client_lock.as_ref().is_some() {
62 return Ok(client_lock);
63 }
64 }
65
66 let mut client_lock = self.client.lock().await;
67 *client_lock = Some(self.new_client().await?);
68
69 Ok(client_lock)
70 }
71
72 async fn drop_client(&self) -> StdResult<()> {
74 debug!(
75 self.logger,
76 "Drop existing DMQ client";
77 "socket" => ?self.socket,
78 "network" => ?self.network
79 );
80 let mut client_lock = self.client.try_lock()?;
81 if let Some(client) = client_lock.take() {
82 client.abort().await;
83 }
84
85 Ok(())
86 }
87
88 #[cfg(test)]
89 async fn has_client(&self) -> bool {
91 let client_lock = self.client.lock().await;
92
93 client_lock.as_ref().is_some()
94 }
95
96 async fn consume_messages_internal(&self) -> StdResult<Vec<(M, PartyId)>> {
97 debug!(self.logger, "Waiting for messages from DMQ...");
98 let mut client_guard = self.get_client().await?;
99 let client = client_guard.as_mut().with_context(|| "DMQ client does not exist")?;
100 if *client.msg_notification().state() == State::Idle {
101 client
102 .msg_notification()
103 .send_request_messages_blocking()
104 .await
105 .with_context(|| "Failed to request notifications from DMQ server: {}")?;
106 }
107
108 let reply = client
109 .msg_notification()
110 .recv_next_reply()
111 .await
112 .with_context(|| "Failed to receive notifications from DMQ server")?;
113 debug!(self.logger, "Received single signatures from DMQ"; "total_messages" => reply.0.len());
114 trace!(self.logger, "Received single signatures from DMQ"; "messages" => ?reply);
115
116 reply
117 .0
118 .into_iter()
119 .map(|dmq_message| {
120 debug!(
121 self.logger,
122 "Received DMQ message";
123 "msg_id" => hex::encode(&dmq_message.msg_payload.msg_id),
124 "kes_period" => dmq_message.msg_payload.kes_period,
125 "expires_at" => dmq_message.msg_payload.expires_at,
126 );
127 let opcert_without_verification_key = OpCertWithoutColdVerificationKey::try_new(
128 &dmq_message.operational_certificate.kes_vk,
129 dmq_message.operational_certificate.issue_number,
130 KesPeriod(dmq_message.operational_certificate.start_kes_period),
131 &dmq_message.operational_certificate.cert_sig,
132 )
133 .with_context(|| "Failed to parse operational certificate")?;
134 let cold_verification_key =
135 Ed25519VerificationKey::from_bytes(&dmq_message.cold_verification_key)
136 .with_context(|| "Failed to parse cold verification key")?
137 .into_inner();
138 let opcert: OpCert =
139 (opcert_without_verification_key, cold_verification_key).into();
140 let party_id = opcert.compute_protocol_party_id()?;
141 let payload = M::try_from_bytes(&dmq_message.msg_payload.msg_body)
142 .with_context(|| "Failed to parse DMQ message body")?;
143
144 Ok((payload, party_id))
145 })
146 .collect::<StdResult<Vec<_>>>()
147 .with_context(|| "Failed to parse DMQ messages")
148 }
149}
150
151#[async_trait::async_trait]
152impl<M: TryFromBytes + Debug + Sync + Send> DmqConsumerClient<M> for DmqConsumerClientPallas<M> {
153 async fn consume_messages(&self) -> StdResult<Vec<(M, PartyId)>> {
154 let messages = self.consume_messages_internal().await;
155 if messages.is_err() {
156 self.drop_client().await?;
157 }
158
159 messages
160 }
161}
162
163impl<M: TryFromBytes + Debug> Drop for DmqConsumerClientPallas<M> {
164 fn drop(&mut self) {
165 tokio::task::block_in_place(|| {
166 tokio::runtime::Handle::current().block_on(async {
167 if let Err(e) = self.drop_client().await {
168 error!(self.logger, "Failed to drop DMQ consumer client: {}", e);
169 }
170 });
171 });
172 }
173}
174
#[cfg(all(test, unix))]
mod tests {
    use anyhow::anyhow;
    use pallas_network::{
        facades::DmqServer,
        miniprotocols::{
            localmsgnotification,
            localmsgsubmission::{DmqMsg, DmqMsgOperationalCertificate, DmqMsgPayload},
        },
    };
    use std::{fs, future, time::Duration, vec};
    use tokio::{net::UnixListener, task::JoinHandle, time::sleep};

    use mithril_common::{crypto_helper::TryToBytes, current_function, test::TempDir};

    use crate::test::{TestLogger, payload::DmqMessageTestPayload};

    use super::*;

    /// Creates a dedicated temporary directory for a test.
    // NOTE(review): a short path is presumably needed to stay under the Unix socket path length
    // limit — confirm.
    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("dmq_consumer", folder_name)
    }

    /// Builds two DMQ messages with distinct payloads (`msg_1`, `msg_2`) and distinct cold
    /// verification keys. They share the same KES verification key and certificate signature.
    fn fake_msgs() -> Vec<DmqMsg> {
        vec![
            DmqMsg {
                msg_payload: DmqMsgPayload {
                    msg_id: vec![0, 1],
                    msg_body: DmqMessageTestPayload::new(b"msg_1").to_bytes_vec().unwrap(),
                    kes_period: 10,
                    expires_at: 100,
                },
                kes_signature: vec![0, 1, 2, 3],
                operational_certificate: DmqMsgOperationalCertificate {
                    kes_vk: vec![
                        50, 45, 160, 42, 80, 78, 184, 20, 210, 77, 140, 152, 63, 49, 165, 168, 5,
                        131, 101, 152, 110, 242, 144, 157, 176, 210, 5, 10, 166, 91, 196, 168,
                    ],
                    issue_number: 0,
                    start_kes_period: 0,
                    cert_sig: vec![
                        207, 135, 144, 168, 238, 41, 179, 216, 245, 74, 164, 231, 4, 158, 234, 141,
                        5, 19, 166, 11, 78, 34, 210, 211, 183, 72, 127, 83, 185, 156, 107, 55, 160,
                        190, 73, 251, 204, 47, 197, 86, 174, 231, 13, 49, 7, 83, 173, 177, 27, 53,
                        209, 66, 24, 203, 226, 152, 3, 91, 66, 56, 244, 206, 79, 0,
                    ],
                },
                cold_verification_key: vec![
                    32, 253, 186, 201, 177, 11, 117, 135, 187, 167, 181, 188, 22, 59, 206, 105,
                    231, 150, 215, 30, 78, 212, 76, 16, 252, 180, 72, 134, 137, 247, 161, 68,
                ],
            },
            DmqMsg {
                msg_payload: DmqMsgPayload {
                    msg_id: vec![1, 2],
                    msg_body: DmqMessageTestPayload::new(b"msg_2").to_bytes_vec().unwrap(),
                    kes_period: 11,
                    expires_at: 101,
                },
                kes_signature: vec![1, 2, 3, 4],
                operational_certificate: DmqMsgOperationalCertificate {
                    kes_vk: vec![
                        50, 45, 160, 42, 80, 78, 184, 20, 210, 77, 140, 152, 63, 49, 165, 168, 5,
                        131, 101, 152, 110, 242, 144, 157, 176, 210, 5, 10, 166, 91, 196, 168,
                    ],
                    issue_number: 0,
                    start_kes_period: 0,
                    cert_sig: vec![
                        207, 135, 144, 168, 238, 41, 179, 216, 245, 74, 164, 231, 4, 158, 234, 141,
                        5, 19, 166, 11, 78, 34, 210, 211, 183, 72, 127, 83, 185, 156, 107, 55, 160,
                        190, 73, 251, 204, 47, 197, 86, 174, 231, 13, 49, 7, 83, 173, 177, 27, 53,
                        209, 66, 24, 203, 226, 152, 3, 91, 66, 56, 244, 206, 79, 0,
                    ],
                },
                cold_verification_key: vec![
                    77, 75, 24, 6, 47, 133, 2, 89, 141, 224, 69, 202, 123, 105, 240, 103, 245, 159,
                    147, 177, 110, 58, 248, 115, 58, 152, 138, 220, 35, 65, 245, 200,
                ],
            },
        ]
    }

    /// Spawns a DMQ server bound to `socket_path` that accepts a single client.
    ///
    /// The server expects one blocking request. It then replies with `reply_messages`, or never
    /// replies when `reply_messages` is empty. The server is returned so the caller can abort it.
    fn setup_dmq_server(
        socket_path: PathBuf,
        reply_messages: Vec<DmqMsg>,
    ) -> JoinHandle<DmqServer> {
        tokio::spawn({
            async move {
                // Remove a stale socket file left by a previous run, otherwise bind fails.
                if socket_path.exists() {
                    fs::remove_file(socket_path.clone()).unwrap();
                }
                let listener = UnixListener::bind(socket_path).unwrap();
                let mut server = pallas_network::facades::DmqServer::accept(&listener, 0)
                    .await
                    .unwrap();

                let server_msg = server.msg_notification();

                // The consumer must issue a blocking request for messages.
                let request = server_msg.recv_next_request().await.unwrap();
                assert_eq!(request, localmsgnotification::Request::Blocking);

                if !reply_messages.is_empty() {
                    server_msg.send_reply_messages_blocking(reply_messages).await.unwrap();
                    // After replying, the protocol must be back to idle.
                    assert_eq!(*server_msg.state(), localmsgnotification::State::Idle);
                } else {
                    // No messages: never reply, so the consumer stays blocked.
                    future::pending().await
                }

                server
            }
        })
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_client_succeeds_when_messages_are_available() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = fake_msgs();
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            // Give the server time to bind the socket before connecting.
            tokio::time::sleep(Duration::from_millis(10)).await;

            let consumer = DmqConsumerClientPallas::new(
                socket_path,
                DmqNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            consumer.consume_messages().await.unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let messages = client_res.unwrap();

        // Each payload is paired with the party id derived from its cold verification key.
        assert_eq!(
            vec![
                (
                    DmqMessageTestPayload::new(b"msg_1"),
                    "pool1mxyec46067n3querj9cxkk0g0zlag93pf3ya9vuyr3wgkq2e6t7".to_string()
                ),
                (
                    DmqMessageTestPayload::new(b"msg_2"),
                    "pool17sln0evyk5tfj6zh2qrlk9vttgy6264sfe2fkec5mheasnlx3yd".to_string()
                ),
            ],
            messages
        );
    }

    #[tokio::test]
    async fn pallas_dmq_consumer_client_blocks_when_no_message_available() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = vec![];
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            // Give the server time to bind the socket before connecting.
            tokio::time::sleep(Duration::from_millis(10)).await;

            let consumer = DmqConsumerClientPallas::<DmqMessageTestPayload>::new(
                socket_path,
                DmqNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            consumer.consume_messages().await.unwrap();
        });

        // Neither the client nor the server should complete: only the timeout can win the race.
        let result = tokio::select!(
            _res = sleep(Duration::from_millis(100)) => {Err(anyhow!("Timeout"))},
            _res = client => {Ok(())},
            _res = server => {Ok(())},
        );

        result.expect_err("Should have timed out");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_client_is_dropped_when_returning_error() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = fake_msgs();
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            // Give the server time to bind the socket before connecting.
            tokio::time::sleep(Duration::from_millis(10)).await;

            let consumer = DmqConsumerClientPallas::<DmqMessageTestPayload>::new(
                socket_path,
                DmqNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            // First consumption succeeds and leaves a connected client behind.
            consumer.consume_messages().await.unwrap();

            consumer
        });

        let (server_res, client_res) = tokio::join!(server, client);
        let consumer = client_res.unwrap();
        let server = server_res.unwrap();
        // Shut the server down so that the next consumption fails.
        server.abort().await;

        let client = tokio::spawn(async move {
            assert!(consumer.has_client().await, "Client should exist");

            consumer
                .consume_messages()
                .await
                .expect_err("Consuming messages should fail");

            assert!(
                !consumer.has_client().await,
                "Client should have been dropped after error"
            );

            consumer
        });
        client.await.unwrap();
    }
}