1use std::{fmt::Debug, marker::PhantomData, path::PathBuf};
2
3use anyhow::{Context, anyhow};
4use pallas_network::{facades::DmqClient, miniprotocols::localmsgnotification::State};
5use slog::{Logger, debug, error};
6use tokio::sync::{Mutex, MutexGuard};
7
8use mithril_common::{
9 StdResult,
10 crypto_helper::{
11 OpCert, OpCertWithoutColdVerificationKey, TryFromBytes, ed25519::Ed25519VerificationKey,
12 },
13 entities::PartyId,
14 logging::LoggerExtensions,
15};
16
17use crate::{DmqConsumerClient, model::DmqNetwork};
18
19pub struct DmqConsumerClientPallas<M: TryFromBytes + Debug> {
23 socket: PathBuf,
24 network: DmqNetwork,
25 client: Mutex<Option<DmqClient>>,
26 logger: Logger,
27 phantom: PhantomData<M>,
28}
29
30impl<M: TryFromBytes + Debug> DmqConsumerClientPallas<M> {
31 pub fn new(socket: PathBuf, network: DmqNetwork, logger: Logger) -> Self {
33 Self {
34 socket,
35 network,
36 client: Mutex::new(None),
37 logger: logger.new_with_component_name::<Self>(),
38 phantom: PhantomData,
39 }
40 }
41
42 async fn new_client(&self) -> StdResult<DmqClient> {
44 debug!(
45 self.logger,
46 "Create new DMQ client";
47 "socket" => ?self.socket,
48 "network" => ?self.network
49 );
50 DmqClient::connect(&self.socket, self.network.magic_id())
51 .await
52 .with_context(|| "DmqConsumerClientPallas failed to create a new client")
53 }
54
55 async fn get_client(&self) -> StdResult<MutexGuard<'_, Option<DmqClient>>> {
57 {
58 let client_lock = self.client.lock().await;
60 if client_lock.as_ref().is_some() {
61 return Ok(client_lock);
62 }
63 }
64
65 let mut client_lock = self.client.lock().await;
66 *client_lock = Some(self.new_client().await?);
67
68 Ok(client_lock)
69 }
70
71 async fn drop_client(&self) -> StdResult<()> {
73 debug!(
74 self.logger,
75 "Drop existing DMQ client";
76 "socket" => ?self.socket,
77 "network" => ?self.network
78 );
79 let mut client_lock = self.client.try_lock()?;
80 if let Some(client) = client_lock.take() {
81 client.abort().await;
82 }
83
84 Ok(())
85 }
86
87 #[cfg(test)]
88 async fn has_client(&self) -> bool {
90 let client_lock = self.client.lock().await;
91
92 client_lock.as_ref().is_some()
93 }
94
95 async fn consume_messages_internal(&self) -> StdResult<Vec<(M, PartyId)>> {
96 debug!(self.logger, "Waiting for messages from DMQ...");
97 let mut client_guard = self.get_client().await?;
98 let client = client_guard.as_mut().ok_or(anyhow!("DMQ client does not exist"))?;
99 if *client.msg_notification().state() == State::Idle {
100 client
101 .msg_notification()
102 .send_request_messages_blocking()
103 .await
104 .with_context(|| "Failed to request notifications from DMQ server: {}")?;
105 }
106
107 let reply = client
108 .msg_notification()
109 .recv_next_reply()
110 .await
111 .with_context(|| "Failed to receive notifications from DMQ server")?;
112 debug!(self.logger, "Received single signatures from DMQ"; "messages" => ?reply);
113
114 reply
115 .0
116 .into_iter()
117 .map(|dmq_message| {
118 let opcert_without_verification_key = OpCertWithoutColdVerificationKey::try_new(
119 &dmq_message.operational_certificate.kes_vk,
120 dmq_message.operational_certificate.issue_number,
121 dmq_message.operational_certificate.start_kes_period,
122 &dmq_message.operational_certificate.cert_sig,
123 )
124 .with_context(|| "Failed to parse operational certificate")?;
125 let cold_verification_key =
126 Ed25519VerificationKey::from_bytes(&dmq_message.cold_verification_key)
127 .with_context(|| "Failed to parse cold verification key")?
128 .into_inner();
129 let opcert: OpCert =
130 (opcert_without_verification_key, cold_verification_key).into();
131 let party_id = opcert.compute_protocol_party_id()?;
132 let payload = M::try_from_bytes(&dmq_message.msg_payload.msg_body)
133 .with_context(|| "Failed to parse DMQ message body")?;
134
135 Ok((payload, party_id))
136 })
137 .collect::<StdResult<Vec<_>>>()
138 .with_context(|| "Failed to parse DMQ messages")
139 }
140}
141
142#[async_trait::async_trait]
143impl<M: TryFromBytes + Debug + Sync + Send> DmqConsumerClient<M> for DmqConsumerClientPallas<M> {
144 async fn consume_messages(&self) -> StdResult<Vec<(M, PartyId)>> {
145 let messages = self.consume_messages_internal().await;
146 if messages.is_err() {
147 self.drop_client().await?;
148 }
149
150 messages
151 }
152}
153
154impl<M: TryFromBytes + Debug> Drop for DmqConsumerClientPallas<M> {
155 fn drop(&mut self) {
156 tokio::task::block_in_place(|| {
157 tokio::runtime::Handle::current().block_on(async {
158 if let Err(e) = self.drop_client().await {
159 error!(self.logger, "Failed to drop DMQ consumer client: {}", e);
160 }
161 });
162 });
163 }
164}
165
#[cfg(all(test, unix))]
mod tests {

    use std::{fs, future, time::Duration, vec};

    use mithril_common::{crypto_helper::TryToBytes, current_function, test::TempDir};
    use pallas_network::{
        facades::DmqServer,
        miniprotocols::{
            localmsgnotification,
            localmsgsubmission::{DmqMsg, DmqMsgOperationalCertificate, DmqMsgPayload},
        },
    };
    use tokio::{net::UnixListener, task::JoinHandle, time::sleep};

    use crate::test::{TestLogger, payload::DmqMessageTestPayload};

    use super::*;

    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("dmq_consumer", folder_name)
    }

    /// Two DMQ messages signed with valid operational certificates but distinct cold keys,
    /// so they map to two distinct pool ids.
    fn fake_msgs() -> Vec<DmqMsg> {
        vec![
            DmqMsg {
                msg_payload: DmqMsgPayload {
                    msg_id: vec![0, 1],
                    msg_body: DmqMessageTestPayload::new(b"msg_1").to_bytes_vec().unwrap(),
                    kes_period: 10,
                    expires_at: 100,
                },
                kes_signature: vec![0, 1, 2, 3],
                operational_certificate: DmqMsgOperationalCertificate {
                    kes_vk: vec![
                        50, 45, 160, 42, 80, 78, 184, 20, 210, 77, 140, 152, 63, 49, 165, 168, 5,
                        131, 101, 152, 110, 242, 144, 157, 176, 210, 5, 10, 166, 91, 196, 168,
                    ],
                    issue_number: 0,
                    start_kes_period: 0,
                    cert_sig: vec![
                        207, 135, 144, 168, 238, 41, 179, 216, 245, 74, 164, 231, 4, 158, 234, 141,
                        5, 19, 166, 11, 78, 34, 210, 211, 183, 72, 127, 83, 185, 156, 107, 55, 160,
                        190, 73, 251, 204, 47, 197, 86, 174, 231, 13, 49, 7, 83, 173, 177, 27, 53,
                        209, 66, 24, 203, 226, 152, 3, 91, 66, 56, 244, 206, 79, 0,
                    ],
                },
                cold_verification_key: vec![
                    32, 253, 186, 201, 177, 11, 117, 135, 187, 167, 181, 188, 22, 59, 206, 105,
                    231, 150, 215, 30, 78, 212, 76, 16, 252, 180, 72, 134, 137, 247, 161, 68,
                ],
            },
            DmqMsg {
                msg_payload: DmqMsgPayload {
                    msg_id: vec![1, 2],
                    msg_body: DmqMessageTestPayload::new(b"msg_2").to_bytes_vec().unwrap(),
                    kes_period: 11,
                    expires_at: 101,
                },
                kes_signature: vec![1, 2, 3, 4],
                operational_certificate: DmqMsgOperationalCertificate {
                    kes_vk: vec![
                        50, 45, 160, 42, 80, 78, 184, 20, 210, 77, 140, 152, 63, 49, 165, 168, 5,
                        131, 101, 152, 110, 242, 144, 157, 176, 210, 5, 10, 166, 91, 196, 168,
                    ],
                    issue_number: 0,
                    start_kes_period: 0,
                    cert_sig: vec![
                        207, 135, 144, 168, 238, 41, 179, 216, 245, 74, 164, 231, 4, 158, 234, 141,
                        5, 19, 166, 11, 78, 34, 210, 211, 183, 72, 127, 83, 185, 156, 107, 55, 160,
                        190, 73, 251, 204, 47, 197, 86, 174, 231, 13, 49, 7, 83, 173, 177, 27, 53,
                        209, 66, 24, 203, 226, 152, 3, 91, 66, 56, 244, 206, 79, 0,
                    ],
                },
                cold_verification_key: vec![
                    77, 75, 24, 6, 47, 133, 2, 89, 141, 224, 69, 202, 123, 105, 240, 103, 245, 159,
                    147, 177, 110, 58, 248, 115, 58, 152, 138, 220, 35, 65, 245, 200,
                ],
            },
        ]
    }

    /// Starts a fake DMQ server on `socket_path` that serves a single blocking request.
    ///
    /// The listener is bound *before* the server task is spawned, so the socket exists as
    /// soon as this function returns. Clients can therefore connect immediately, without
    /// racing the server start-up. When `reply_messages` is empty the server never replies,
    /// which simulates a DMQ node with no message available.
    ///
    /// Must be called from within a Tokio runtime (binding registers with its reactor).
    fn setup_dmq_server(
        socket_path: PathBuf,
        reply_messages: Vec<DmqMsg>,
    ) -> JoinHandle<DmqServer> {
        if socket_path.exists() {
            fs::remove_file(&socket_path).unwrap();
        }
        let listener = UnixListener::bind(&socket_path).unwrap();

        tokio::spawn(async move {
            let mut server = DmqServer::accept(&listener, 0).await.unwrap();

            let server_msg = server.msg_notification();

            let request = server_msg.recv_next_request().await.unwrap();
            assert_eq!(request, localmsgnotification::Request::Blocking);

            if !reply_messages.is_empty() {
                server_msg.send_reply_messages_blocking(reply_messages).await.unwrap();
                assert_eq!(*server_msg.state(), localmsgnotification::State::Idle);
            } else {
                future::pending().await
            }

            server
        })
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_client_succeeds_when_messages_are_available() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = fake_msgs();
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            let consumer = DmqConsumerClientPallas::new(
                socket_path,
                DmqNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            consumer.consume_messages().await.unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let messages = client_res.unwrap();

        assert_eq!(
            vec![
                (
                    DmqMessageTestPayload::new(b"msg_1"),
                    "pool1mxyec46067n3querj9cxkk0g0zlag93pf3ya9vuyr3wgkq2e6t7".to_string()
                ),
                (
                    DmqMessageTestPayload::new(b"msg_2"),
                    "pool17sln0evyk5tfj6zh2qrlk9vttgy6264sfe2fkec5mheasnlx3yd".to_string()
                ),
            ],
            messages
        );
    }

    #[tokio::test]
    async fn pallas_dmq_consumer_client_blocks_when_no_message_available() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = vec![];
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            let consumer = DmqConsumerClientPallas::<DmqMessageTestPayload>::new(
                socket_path,
                DmqNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            consumer.consume_messages().await.unwrap();
        });

        // Neither side is expected to complete: only the timeout branch should fire.
        let result = tokio::select!(
            _res = sleep(Duration::from_millis(100)) => {Err(anyhow!("Timeout"))},
            _res = client => {Ok(())},
            _res = server => {Ok(())},
        );

        result.expect_err("Should have timed out");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pallas_dmq_consumer_client_is_dropped_when_returning_error() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let reply_messages = fake_msgs();
        let server = setup_dmq_server(socket_path.clone(), reply_messages);
        let client = tokio::spawn(async move {
            let consumer = DmqConsumerClientPallas::<DmqMessageTestPayload>::new(
                socket_path,
                DmqNetwork::TestNet(0),
                TestLogger::stdout(),
            );

            consumer.consume_messages().await.unwrap();

            consumer
        });

        let (server_res, client_res) = tokio::join!(server, client);
        let consumer = client_res.unwrap();
        let server = server_res.unwrap();
        // Kill the server side so that the next consumption attempt fails.
        server.abort().await;

        let client = tokio::spawn(async move {
            assert!(consumer.has_client().await, "Client should exist");

            consumer
                .consume_messages()
                .await
                .expect_err("Consuming messages should fail");

            assert!(
                !consumer.has_client().await,
                "Client should have been dropped after error"
            );

            consumer
        });
        client.await.unwrap();
    }
}