mithril_cardano_node_chain/chain_reader/
pallas_chain_reader.rs

1use std::path::{Path, PathBuf};
2use std::time::Duration;
3
4use anyhow::Context;
5use async_trait::async_trait;
6use pallas_network::{
7    facades::NodeClient,
8    miniprotocols::chainsync::{BlockContent, NextResponse},
9};
10use pallas_traverse::MultiEraBlock;
11use slog::{Logger, debug, warn};
12use tokio::time::timeout;
13
14use mithril_common::StdResult;
15use mithril_common::entities::CardanoNetwork;
16use mithril_common::logging::LoggerExtensions;
17
18use crate::entities::{ChainBlockNextAction, RawCardanoPoint, ScannedBlock};
19
20use super::ChainBlockReader;
21
/// Timeout applied by default to each Pallas `chainsync` operation (one minute).
///
/// It bounds connecting to the node, finding an intersection point and waiting for the
/// next chain block; override it with `PallasChainReader::with_timeout`.
const DEFAULT_CHAINSYNC_TIMEOUT: Duration = Duration::from_millis(60 * 1_000);
24
/// [PallasChainReader] reads blocks with the `chainsync` mini-protocol.
///
/// It connects to a Cardano node through its socket and keeps a single [NodeClient]
/// cached between calls; the cached client is discarded whenever an operation fails
/// or times out, so that the next operation reconnects.
pub struct PallasChainReader {
    // Path of the Cardano node socket to connect to.
    socket: PathBuf,
    // Cardano network; its magic id is used when connecting to the node.
    network: CardanoNetwork,
    // Lazily created node connection: `None` before the first chainsync operation and
    // after the connection has been dropped because of an error or a timeout.
    client: Option<NodeClient>,
    // Upper bound applied to each connect / find-intersect / next-block operation.
    chainsync_timeout: Duration,
    // Logger tagged with this component's name.
    logger: Logger,
}
33
34impl PallasChainReader {
35    /// Creates a new `PallasChainReader` with the specified socket and network.
36    pub fn new(socket: &Path, network: CardanoNetwork, logger: Logger) -> Self {
37        Self {
38            socket: socket.to_owned(),
39            network,
40            client: None,
41            chainsync_timeout: DEFAULT_CHAINSYNC_TIMEOUT,
42            logger: logger.new_with_component_name::<Self>(),
43        }
44    }
45
46    /// Sets the timeout duration for chainsync operations.
47    pub fn with_timeout(mut self, timeout: Duration) -> Self {
48        self.chainsync_timeout = timeout;
49        self
50    }
51
52    /// Creates and returns a new `NodeClient` connected to the specified socket.
53    async fn new_client(&self) -> StdResult<NodeClient> {
54        let magic = self.network.magic_id();
55        timeout(
56            self.chainsync_timeout,
57            NodeClient::connect(&self.socket, magic),
58        )
59        .await
60        .map_err(|_| {
61            warn!(
62                self.logger, "Timeout elapsed while connecting to the Cardano node";
63                "timeout" => ?self.chainsync_timeout, "socket" => self.socket.display().to_string()
64            );
65            anyhow::anyhow!("PallasChainReader timed out connecting to the Cardano node")
66        })?
67        .with_context(|| "PallasChainReader failed to create a new client")
68    }
69
70    /// Returns a mutable reference to the client.
71    async fn get_client(&mut self) -> StdResult<&mut NodeClient> {
72        if self.client.is_none() {
73            self.client = Some(self.new_client().await?);
74            debug!(self.logger, "Connected to a new client");
75        }
76
77        self.client
78            .as_mut()
79            .with_context(|| "PallasChainReader failed to get a client")
80    }
81
82    #[cfg(all(test, unix))]
83    /// Check if the client already exists (test only).
84    fn has_client(&self) -> bool {
85        self.client.is_some()
86    }
87
88    /// Drops the client by aborting the connection and setting it to `None`.
89    fn drop_client(&mut self) {
90        if let Some(client) = self.client.take() {
91            tokio::spawn(async move {
92                let _ = client.abort().await;
93            });
94        }
95    }
96
97    /// Intersects the point of the chain with the given point.
98    async fn find_intersect_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
99        let logger = self.logger.clone();
100        let chainsync_timeout = self.chainsync_timeout;
101        let client = self.get_client().await?;
102        let chainsync = client.chainsync();
103
104        if chainsync.has_agency() {
105            debug!(logger, "Has agency, finding intersect point..."; "point" => ?point);
106            let result = timeout(
107                chainsync_timeout,
108                chainsync.find_intersect(vec![point.to_owned().into()]),
109            )
110            .await;
111            match result {
112                Ok(Ok(_)) => {}
113                Ok(Err(err)) => {
114                    self.drop_client();
115
116                    return Err(anyhow::anyhow!(err)
117                        .context("PallasChainReader failed to find intersect point"));
118                }
119                Err(_elapsed) => {
120                    warn!(
121                        logger, "Timeout elapsed while finding intersect point, dropping connection";
122                        "timeout" => ?chainsync_timeout, "point" => ?point
123                    );
124                    self.drop_client();
125
126                    return Err(anyhow::anyhow!(
127                        "PallasChainReader timed out finding intersect point"
128                    ));
129                }
130            }
131        } else {
132            debug!(logger, "Doesn't have agency, no need to find intersect point";);
133        }
134
135        Ok(())
136    }
137
138    /// Processes a block content next response and returns the appropriate chain block next action.
139    async fn process_chain_block_next_action(
140        &mut self,
141        next: NextResponse<BlockContent>,
142    ) -> StdResult<Option<ChainBlockNextAction>> {
143        match next {
144            NextResponse::RollForward(raw_block, _forward_tip) => {
145                let multi_era_block = MultiEraBlock::decode(&raw_block)
146                    .with_context(|| "PallasChainReader failed to decode raw block")?;
147                let parsed_block = ScannedBlock::convert(multi_era_block);
148                Ok(Some(ChainBlockNextAction::RollForward { parsed_block }))
149            }
150            NextResponse::RollBackward(rollback_point, _) => {
151                Ok(Some(ChainBlockNextAction::RollBackward {
152                    rollback_point: RawCardanoPoint::from(rollback_point),
153                }))
154            }
155            NextResponse::Await => Ok(None),
156        }
157    }
158}
159
160impl Drop for PallasChainReader {
161    fn drop(&mut self) {
162        self.drop_client();
163    }
164}
165
166#[async_trait]
167impl ChainBlockReader for PallasChainReader {
168    async fn set_chain_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
169        self.find_intersect_point(point).await
170    }
171
172    async fn get_next_chain_block(&mut self) -> StdResult<Option<ChainBlockNextAction>> {
173        let chainsync_timeout = self.chainsync_timeout;
174        let logger = self.logger.clone();
175        let client = self.get_client().await?;
176        let chainsync = client.chainsync();
177        let next = match chainsync.has_agency() {
178            true => timeout(chainsync_timeout, chainsync.request_next()).await,
179            false => timeout(chainsync_timeout, chainsync.recv_while_must_reply()).await,
180        };
181        match next {
182            Ok(Ok(response)) => self.process_chain_block_next_action(response).await,
183            Ok(Err(err)) => {
184                self.drop_client();
185
186                Err(err.into())
187            }
188            Err(_elapsed) => {
189                warn!(
190                    logger,
191                    "Timeout elapsed while waiting for next chain block from the Cardano node, dropping connection";
192                    "timeout" => ?chainsync_timeout
193                );
194                self.drop_client();
195
196                Err(anyhow::anyhow!(
197                    "PallasChainReader timed out waiting for next chain block from the Cardano node"
198                ))
199            }
200        }
201    }
202}
203
// Windows does not support Unix sockets, nor pallas_network::facades::NodeServer
#[cfg(all(test, unix))]
mod tests {
    // These tests run a `PallasChainReader` against an in-process pallas `NodeServer`
    // bound to a Unix socket inside a temporary directory, scripting the server side
    // of the chainsync exchange step by step.
    use pallas_network::{
        facades::NodeServer,
        miniprotocols::{
            Point,
            chainsync::{BlockContent, Tip},
        },
    };
    use std::fs;
    use tokio::net::UnixListener;

    use mithril_common::{current_function, entities::BlockNumber, test::TempDir};

    use crate::test::TestLogger;

    use super::*;

    /// Enum representing the action to be performed by the server.
    enum ServerAction {
        /// Answer the next-block request with a roll backward to the known point.
        RollBackward,
        /// Answer the next-block request with a roll forward carrying the fixture block.
        RollForward,
    }

    /// Enum representing whether the node has agency or not.
    #[derive(Debug, PartialEq)]
    enum HasAgency {
        /// The server answers the next-block request directly.
        Yes,
        /// The server first replies `Await`, leaving the client without agency, and only
        /// then sends the actual roll backward / roll forward.
        No,
    }

    /// Returns a fake specific point for testing purposes.
    ///
    /// It is used both as the intersection point and as the tip announced by the server.
    fn get_fake_specific_point() -> Point {
        Point::Specific(
            1654413,
            hex::decode("7de1f036df5a133ce68a82877d14354d0ba6de7625ab918e75f3e2ecb29771c2")
                .unwrap(),
        )
    }

    /// Returns a fake block number for testing purposes.
    fn get_fake_block_number() -> BlockNumber {
        BlockNumber(1337)
    }

    /// Returns a fake cardano raw point for testing purposes.
    ///
    /// This is [get_fake_specific_point] converted to a [RawCardanoPoint], i.e. the
    /// rollback point expected when the server rolls backward.
    fn get_fake_raw_point_backwards() -> RawCardanoPoint {
        RawCardanoPoint::from(get_fake_specific_point())
    }

    /// Creates a new work directory in the system's temporary folder.
    ///
    /// A short path is requested, presumably to stay within the OS limit on Unix socket
    /// path length (NOTE(review): confirm the intent of `create_with_short_path`).
    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("pallas_chain_observer_test", folder_name)
    }

    /// Returns the raw bytes of the Shelley block fixture.
    ///
    /// The fixture file holds a hex-encoded block and is embedded at compile time.
    fn get_fake_raw_block() -> Vec<u8> {
        let raw_block =
            include_str!("../../../../../mithril-test-lab/test_data/blocks/shelley1.block");

        hex::decode(raw_block).unwrap()
    }

    /// Returns the [ScannedBlock] expected when the reader decodes [get_fake_raw_block].
    fn get_fake_scanned_block() -> ScannedBlock {
        let raw_block = get_fake_raw_block();
        let multi_era_block = MultiEraBlock::decode(&raw_block).unwrap();

        ScannedBlock::convert(multi_era_block)
    }

    /// Sets up a mock server for related tests.
    ///
    /// Use the `action` parameter to specify the action to be performed by the server.
    ///
    /// The spawned server accepts a single connection on `socket_path` (removing any
    /// stale socket file first) and then, in order:
    /// 1. waits for the client's find-intersect request and answers "intersect found" at
    ///    [get_fake_specific_point];
    /// 2. waits for the client's next-block request;
    /// 3. if `has_agency` is [HasAgency::No], replies `Await` first;
    /// 4. sends the roll backward / roll forward selected by `action`.
    ///
    /// The task returns the `NodeServer`, so the caller controls when the connection
    /// is closed.
    async fn setup_server(
        socket_path: PathBuf,
        action: ServerAction,
        has_agency: HasAgency,
    ) -> tokio::task::JoinHandle<NodeServer> {
        tokio::spawn({
            async move {
                if socket_path.exists() {
                    fs::remove_file(&socket_path).expect("Previous socket removal failed");
                }

                let known_point = get_fake_specific_point();
                let tip_block_number = get_fake_block_number();
                let unix_listener = UnixListener::bind(socket_path.as_path()).unwrap();
                let mut server = NodeServer::accept(&unix_listener, 10).await.unwrap();

                let chainsync_server = server.chainsync();

                // Step 1: find-intersect request from the client.
                chainsync_server.recv_while_idle().await.unwrap();

                chainsync_server
                    .send_intersect_found(
                        known_point.clone(),
                        Tip(known_point.clone(), *tip_block_number),
                    )
                    .await
                    .unwrap();

                // Step 2: next-block request from the client.
                chainsync_server.recv_while_idle().await.unwrap();

                // Step 3: optionally make the client wait, taking agency away from it.
                if has_agency == HasAgency::No {
                    chainsync_server.send_await_reply().await.unwrap();
                }

                // Step 4: the actual answer to the next-block request.
                match action {
                    ServerAction::RollBackward => {
                        chainsync_server
                            .send_roll_backward(
                                known_point.clone(),
                                Tip(known_point.clone(), *tip_block_number),
                            )
                            .await
                            .unwrap();
                    }
                    ServerAction::RollForward => {
                        let block = BlockContent(get_fake_raw_block());
                        chainsync_server
                            .send_roll_forward(block, Tip(known_point.clone(), *tip_block_number))
                            .await
                            .unwrap();
                    }
                }

                server
            }
        })
    }

    /// A roll backward sent by the node is returned as `RollBackward` to the known point.
    #[tokio::test]
    async fn get_next_chain_block_rolls_backward() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollBackward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollBackward { rollback_point } => {
                assert_eq!(rollback_point, get_fake_raw_point_backwards());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    /// A roll forward sent by the node is decoded into the expected `ScannedBlock`.
    #[tokio::test]
    async fn get_next_chain_block_rolls_forward() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollForward { parsed_block } => {
                assert_eq!(parsed_block, get_fake_scanned_block());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    /// Without agency (after an `Await` reply), setting the chain point is a no-op and the
    /// next block is still received by waiting for the node's pending reply.
    #[tokio::test]
    async fn get_next_chain_block_has_no_agency() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::No,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            // forces the client to change the chainsync server agency state
            let client = chain_reader.get_client().await.unwrap();
            client.chainsync().request_next().await.unwrap();

            // make sure that the chainsync client returns an error when attempting to find intersection without agency
            client
                .chainsync()
                .find_intersect(vec![known_point.clone()])
                .await
                .expect_err("chainsync find_intersect without agency should fail");

            // make sure that setting the chain point is harmless when the chainsync client does not have agency
            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollForward { parsed_block } => {
                assert_eq!(parsed_block, get_fake_scanned_block());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    /// Once the server connection is closed, `get_next_chain_block` fails and the cached
    /// client is discarded.
    #[tokio::test]
    async fn cached_client_is_dropped_when_returning_error() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let socket_path_clone = socket_path.clone();
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path_clone.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap();

            chain_reader
        });

        let (server_res, client_res) = tokio::join!(server, client);
        let chain_reader = client_res.expect("Client failed to get chain reader");
        let server = server_res.expect("Server failed to get server");
        // Close the server side so that the next read on the cached client fails.
        server.abort().await;

        let client = tokio::spawn(async move {
            let mut chain_reader = chain_reader;

            assert!(chain_reader.has_client(), "Client should exist");

            chain_reader
                .get_next_chain_block()
                .await
                .expect_err("Chain reader get_next_chain_block should fail");

            assert!(
                !chain_reader.has_client(),
                "Client should have been dropped after error"
            );

            chain_reader
        });
        client.await.unwrap();
    }

    /// A node that never answers the next-block request makes `get_next_chain_block`
    /// time out, and the cached client is discarded.
    #[tokio::test]
    async fn cached_client_is_dropped_when_get_next_chain_block_times_out() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let known_point = get_fake_specific_point();
        let tip_block_number = get_fake_block_number();

        let server = tokio::spawn({
            let socket_path = socket_path.clone();
            async move {
                if socket_path.exists() {
                    fs::remove_file(&socket_path).expect("Previous socket removal failed");
                }

                let unix_listener = UnixListener::bind(socket_path.as_path()).unwrap();
                let mut server = NodeServer::accept(&unix_listener, 10).await.unwrap();
                let chainsync_server = server.chainsync();

                chainsync_server.recv_while_idle().await.unwrap();
                chainsync_server
                    .send_intersect_found(
                        known_point.clone(),
                        Tip(known_point.clone(), *tip_block_number),
                    )
                    .await
                    .unwrap();

                // Receive the request_next but never respond — simulates a stuck Cardano node
                chainsync_server.recv_while_idle().await.unwrap();

                // Keep server alive until aborted so the socket does not close
                std::future::pending::<()>().await;
                server
            }
        });

        let client = tokio::spawn({
            let socket_path = socket_path.clone();
            async move {
                // Short timeout so the test stays fast.
                let mut chain_reader = PallasChainReader::new(
                    socket_path.as_path(),
                    CardanoNetwork::TestNet(10),
                    TestLogger::stdout(),
                )
                .with_timeout(Duration::from_millis(200));

                chain_reader
                    .set_chain_point(&RawCardanoPoint::from(get_fake_specific_point()))
                    .await
                    .unwrap();

                assert!(
                    chain_reader.has_client(),
                    "Client should exist before timeout"
                );

                let result = chain_reader.get_next_chain_block().await;
                assert!(result.is_err(), "Expected timeout error");

                assert!(
                    !chain_reader.has_client(),
                    "Client should have been dropped after timeout"
                );

                chain_reader
            }
        });

        // Abort the never-ending server before propagating any client panic.
        let client_result = client.await;
        server.abort();
        client_result.unwrap();
    }

    /// A node that never answers the find-intersect request makes `set_chain_point`
    /// time out, and the freshly created client is discarded.
    #[tokio::test]
    async fn cached_client_is_dropped_when_set_chain_point_times_out() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");

        let server = tokio::spawn({
            let socket_path = socket_path.clone();
            async move {
                if socket_path.exists() {
                    fs::remove_file(&socket_path).expect("Previous socket removal failed");
                }

                let unix_listener = UnixListener::bind(socket_path.as_path()).unwrap();
                let mut server = NodeServer::accept(&unix_listener, 10).await.unwrap();
                let chainsync_server = server.chainsync();

                // Receive the find_intersect request but never respond
                chainsync_server.recv_while_idle().await.unwrap();

                // Keep server alive until aborted so the socket does not close
                std::future::pending::<()>().await;
                server
            }
        });

        let client = tokio::spawn({
            let socket_path = socket_path.clone();
            async move {
                // Short timeout so the test stays fast.
                let mut chain_reader = PallasChainReader::new(
                    socket_path.as_path(),
                    CardanoNetwork::TestNet(10),
                    TestLogger::stdout(),
                )
                .with_timeout(Duration::from_millis(200));

                assert!(
                    !chain_reader.has_client(),
                    "Client should not exist before connection"
                );

                let result = chain_reader
                    .set_chain_point(&RawCardanoPoint::from(get_fake_specific_point()))
                    .await;
                assert!(result.is_err(), "Expected timeout error");

                assert!(
                    !chain_reader.has_client(),
                    "Client should have been dropped after timeout"
                );

                chain_reader
            }
        });

        // Abort the never-ending server before propagating any client panic.
        let client_result = client.await;
        server.abort();
        client_result.unwrap();
    }
}