// mithril_cardano_node_chain/chain_reader/pallas_chain_reader.rs

1use std::path::{Path, PathBuf};
2
3use anyhow::{Context, anyhow};
4use async_trait::async_trait;
5use pallas_network::{
6    facades::NodeClient,
7    miniprotocols::chainsync::{BlockContent, NextResponse},
8};
9use pallas_traverse::MultiEraBlock;
10use slog::{Logger, debug};
11
12use mithril_common::StdResult;
13use mithril_common::entities::CardanoNetwork;
14use mithril_common::logging::LoggerExtensions;
15
16use crate::entities::{ChainBlockNextAction, RawCardanoPoint, ScannedBlock};
17
18use super::ChainBlockReader;
19
20/// [PallasChainReader] reads blocks with 'chainsync' mini-protocol
21pub struct PallasChainReader {
22    socket: PathBuf,
23    network: CardanoNetwork,
24    client: Option<NodeClient>,
25    logger: Logger,
26}
27
28impl PallasChainReader {
29    /// Creates a new `PallasChainReader` with the specified socket and network.
30    pub fn new(socket: &Path, network: CardanoNetwork, logger: Logger) -> Self {
31        Self {
32            socket: socket.to_owned(),
33            network,
34            client: None,
35            logger: logger.new_with_component_name::<Self>(),
36        }
37    }
38
39    /// Creates and returns a new `NodeClient` connected to the specified socket.
40    async fn new_client(&self) -> StdResult<NodeClient> {
41        let magic = self.network.code();
42        NodeClient::connect(&self.socket, magic)
43            .await
44            .map_err(|err| anyhow!(err))
45            .with_context(|| "PallasChainReader failed to create a new client")
46    }
47
48    /// Returns a mutable reference to the client.
49    async fn get_client(&mut self) -> StdResult<&mut NodeClient> {
50        if self.client.is_none() {
51            self.client = Some(self.new_client().await?);
52            debug!(self.logger, "Connected to a new client");
53        }
54
55        self.client
56            .as_mut()
57            .with_context(|| "PallasChainReader failed to get a client")
58    }
59
60    #[cfg(all(test, unix))]
61    /// Check if the client already exists (test only).
62    fn has_client(&self) -> bool {
63        self.client.is_some()
64    }
65
66    /// Drops the client by aborting the connection and setting it to `None`.
67    fn drop_client(&mut self) {
68        if let Some(client) = self.client.take() {
69            tokio::spawn(async move {
70                let _ = client.abort().await;
71            });
72        }
73    }
74
75    /// Intersects the point of the chain with the given point.
76    async fn find_intersect_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
77        let logger = self.logger.clone();
78        let client = self.get_client().await?;
79        let chainsync = client.chainsync();
80
81        if chainsync.has_agency() {
82            debug!(logger, "Has agency, finding intersect point..."; "point" => ?point);
83            chainsync.find_intersect(vec![point.to_owned().into()]).await?;
84        } else {
85            debug!(logger, "Doesn't have agency, no need to find intersect point";);
86        }
87
88        Ok(())
89    }
90
91    /// Processes a block content next response and returns the appropriate chain block next action.
92    async fn process_chain_block_next_action(
93        &mut self,
94        next: NextResponse<BlockContent>,
95    ) -> StdResult<Option<ChainBlockNextAction>> {
96        match next {
97            NextResponse::RollForward(raw_block, _forward_tip) => {
98                let multi_era_block = MultiEraBlock::decode(&raw_block)
99                    .with_context(|| "PallasChainReader failed to decode raw block")?;
100                let parsed_block = ScannedBlock::convert(multi_era_block);
101                Ok(Some(ChainBlockNextAction::RollForward { parsed_block }))
102            }
103            NextResponse::RollBackward(rollback_point, _) => {
104                Ok(Some(ChainBlockNextAction::RollBackward {
105                    rollback_point: RawCardanoPoint::from(rollback_point),
106                }))
107            }
108            NextResponse::Await => Ok(None),
109        }
110    }
111}
112
113impl Drop for PallasChainReader {
114    fn drop(&mut self) {
115        self.drop_client();
116    }
117}
118
119#[async_trait]
120impl ChainBlockReader for PallasChainReader {
121    async fn set_chain_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
122        match self.find_intersect_point(point).await {
123            Ok(()) => Ok(()),
124            Err(err) => {
125                self.drop_client();
126
127                return Err(err);
128            }
129        }
130    }
131
132    async fn get_next_chain_block(&mut self) -> StdResult<Option<ChainBlockNextAction>> {
133        let client = self.get_client().await?;
134        let chainsync = client.chainsync();
135        let next = match chainsync.has_agency() {
136            true => chainsync.request_next().await,
137            false => chainsync.recv_while_must_reply().await,
138        };
139        match next {
140            Ok(next) => self.process_chain_block_next_action(next).await,
141            Err(err) => {
142                self.drop_client();
143
144                return Err(err.into());
145            }
146        }
147    }
148}
149
// Windows does not support Unix sockets, nor pallas_network::facades::NodeServer
#[cfg(all(test, unix))]
mod tests {
    use pallas_network::{
        facades::NodeServer,
        miniprotocols::{
            Point,
            chainsync::{BlockContent, Tip},
        },
    };
    use std::fs;
    use tokio::net::UnixListener;

    use mithril_common::{current_function, entities::BlockNumber, test_utils::TempDir};

    use crate::test::TestLogger;

    use super::*;

    /// Enum representing the action to be performed by the server.
    enum ServerAction {
        /// Answer the client's next request with a roll backward to the known point.
        RollBackward,
        /// Answer the client's next request with a roll forward of the fake block.
        RollForward,
    }

    /// Enum representing whether the node has agency or not.
    #[derive(Debug, PartialEq)]
    enum HasAgency {
        /// The server answers the next request directly.
        Yes,
        /// The server first sends an "await reply" before answering.
        No,
    }

    /// Returns a fake specific point for testing purposes.
    fn get_fake_specific_point() -> Point {
        Point::Specific(
            1654413,
            hex::decode("7de1f036df5a133ce68a82877d14354d0ba6de7625ab918e75f3e2ecb29771c2")
                .unwrap(),
        )
    }

    /// Returns a fake block number for testing purposes.
    fn get_fake_block_number() -> BlockNumber {
        BlockNumber(1337)
    }

    /// Returns the raw cardano point expected from a roll backward (the fake specific point).
    fn get_fake_raw_point_backwards() -> RawCardanoPoint {
        RawCardanoPoint::from(get_fake_specific_point())
    }

    /// Creates a work directory for the given test, using a short path.
    ///
    /// NOTE(review): the short path presumably keeps the Unix socket path under the OS
    /// length limit — confirm against `TempDir::create_with_short_path`.
    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("pallas_chain_observer_test", folder_name)
    }

    /// Returns the bytes of a Shelley block read from the test-lab fixtures (hex-encoded on disk).
    fn get_fake_raw_block() -> Vec<u8> {
        let raw_block =
            include_str!("../../../../../mithril-test-lab/test_data/blocks/shelley1.block");

        hex::decode(raw_block).unwrap()
    }

    /// Returns the [ScannedBlock] expected from decoding the fake raw block.
    fn get_fake_scanned_block() -> ScannedBlock {
        let raw_block = get_fake_raw_block();
        let multi_era_block = MultiEraBlock::decode(&raw_block).unwrap();

        ScannedBlock::convert(multi_era_block)
    }

    /// Sets up a mock server for related tests.
    ///
    /// Use the `action` parameter to specify the action to be performed by the server.
    ///
    /// The server task expects this exact sequence from the client:
    /// 1. a find intersect request (sent by `set_chain_point`), answered with "intersect found";
    /// 2. a next request, optionally answered first with "await reply" (`HasAgency::No`),
    ///    then with the requested roll backward / roll forward.
    ///
    /// The spawned task returns the server so callers can keep it alive or abort it.
    async fn setup_server(
        socket_path: PathBuf,
        action: ServerAction,
        has_agency: HasAgency,
    ) -> tokio::task::JoinHandle<NodeServer> {
        tokio::spawn({
            async move {
                // Remove a socket file left by a previous run so the listener can bind.
                if socket_path.exists() {
                    fs::remove_file(&socket_path).expect("Previous socket removal failed");
                }

                let known_point = get_fake_specific_point();
                let tip_block_number = get_fake_block_number();
                let unix_listener = UnixListener::bind(socket_path.as_path()).unwrap();
                // NOTE(review): magic 10 presumably matches `CardanoNetwork::TestNet(10)`
                // used by the clients below — confirm.
                let mut server = NodeServer::accept(&unix_listener, 10).await.unwrap();

                let chainsync_server = server.chainsync();

                // Step 1: receive the client's find intersect request.
                chainsync_server.recv_while_idle().await.unwrap();

                chainsync_server
                    .send_intersect_found(
                        known_point.clone(),
                        Tip(known_point.clone(), *tip_block_number),
                    )
                    .await
                    .unwrap();

                // Step 2: receive the client's next request.
                chainsync_server.recv_while_idle().await.unwrap();

                // Make the client wait, so it no longer has agency until the real answer.
                if has_agency == HasAgency::No {
                    chainsync_server.send_await_reply().await.unwrap();
                }

                match action {
                    ServerAction::RollBackward => {
                        chainsync_server
                            .send_roll_backward(
                                known_point.clone(),
                                Tip(known_point.clone(), *tip_block_number),
                            )
                            .await
                            .unwrap();
                    }
                    ServerAction::RollForward => {
                        let block = BlockContent(get_fake_raw_block());
                        chainsync_server
                            .send_roll_forward(block, Tip(known_point.clone(), *tip_block_number))
                            .await
                            .unwrap();
                    }
                }

                server
            }
        })
    }

    #[tokio::test]
    async fn get_next_chain_block_rolls_backward() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollBackward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollBackward { rollback_point } => {
                assert_eq!(rollback_point, get_fake_raw_point_backwards());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    #[tokio::test]
    async fn get_next_chain_block_rolls_forward() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollForward { parsed_block } => {
                assert_eq!(parsed_block, get_fake_scanned_block());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    // Exercises the "no agency" branches: `set_chain_point` must be a no-op and
    // `get_next_chain_block` must wait for the reply owed by the server.
    #[tokio::test]
    async fn get_next_chain_block_has_no_agency() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::No,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            // forces the client to change the chainsync server agency state
            let client = chain_reader.get_client().await.unwrap();
            client.chainsync().request_next().await.unwrap();

            // make sure that the chainsync client returns an error when attempting to find intersection without agency
            client
                .chainsync()
                .find_intersect(vec![known_point.clone()])
                .await
                .expect_err("chainsync find_intersect without agency should fail");

            // make sure that setting the chain point is harmless when the chainsync client does not have agency
            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollForward { parsed_block } => {
                assert_eq!(parsed_block, get_fake_scanned_block());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    // After a successful exchange, the server is aborted: the next read must fail
    // and the cached client must be discarded.
    #[tokio::test]
    async fn cached_client_is_dropped_when_returning_error() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let socket_path_clone = socket_path.clone();
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path_clone.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap();

            // Hand the reader (and its cached client) back for the second phase.
            chain_reader
        });

        let (server_res, client_res) = tokio::join!(server, client);
        let chain_reader = client_res.expect("Client failed to get chain reader");
        let server = server_res.expect("Server failed to get server");
        // Shut down the server side so the next chainsync call fails.
        server.abort().await;

        let client = tokio::spawn(async move {
            let mut chain_reader = chain_reader;

            assert!(chain_reader.has_client(), "Client should exist");

            chain_reader
                .get_next_chain_block()
                .await
                .expect_err("Chain reader get_next_chain_block should fail");

            assert!(
                !chain_reader.has_client(),
                "Client should have been dropped after error"
            );

            chain_reader
        });
        client.await.unwrap();
    }
}
455}