mithril_cardano_node_chain/chain_reader/
pallas_chain_reader.rs

1use std::path::{Path, PathBuf};
2
3use anyhow::Context;
4use async_trait::async_trait;
5use pallas_network::{
6    facades::NodeClient,
7    miniprotocols::chainsync::{BlockContent, NextResponse},
8};
9use pallas_traverse::MultiEraBlock;
10use slog::{Logger, debug};
11
12use mithril_common::StdResult;
13use mithril_common::entities::CardanoNetwork;
14use mithril_common::logging::LoggerExtensions;
15
16use crate::entities::{ChainBlockNextAction, RawCardanoPoint, ScannedBlock};
17
18use super::ChainBlockReader;
19
/// [PallasChainReader] reads blocks from a Cardano node with the 'chainsync' mini-protocol.
///
/// The connection to the node is opened lazily and cached between calls. It is discarded
/// whenever a chainsync operation fails, so that the next call reconnects.
pub struct PallasChainReader {
    /// Path of the node socket the client connects to.
    socket: PathBuf,
    /// Network whose magic is used when connecting to the node.
    network: CardanoNetwork,
    /// Cached connection: `None` until first needed, and reset to `None` by `drop_client`.
    client: Option<NodeClient>,
    /// Component logger.
    logger: Logger,
}
27
28impl PallasChainReader {
29    /// Creates a new `PallasChainReader` with the specified socket and network.
30    pub fn new(socket: &Path, network: CardanoNetwork, logger: Logger) -> Self {
31        Self {
32            socket: socket.to_owned(),
33            network,
34            client: None,
35            logger: logger.new_with_component_name::<Self>(),
36        }
37    }
38
39    /// Creates and returns a new `NodeClient` connected to the specified socket.
40    async fn new_client(&self) -> StdResult<NodeClient> {
41        let magic = self.network.magic_id();
42        NodeClient::connect(&self.socket, magic)
43            .await
44            .with_context(|| "PallasChainReader failed to create a new client")
45    }
46
47    /// Returns a mutable reference to the client.
48    async fn get_client(&mut self) -> StdResult<&mut NodeClient> {
49        if self.client.is_none() {
50            self.client = Some(self.new_client().await?);
51            debug!(self.logger, "Connected to a new client");
52        }
53
54        self.client
55            .as_mut()
56            .with_context(|| "PallasChainReader failed to get a client")
57    }
58
59    #[cfg(all(test, unix))]
60    /// Check if the client already exists (test only).
61    fn has_client(&self) -> bool {
62        self.client.is_some()
63    }
64
65    /// Drops the client by aborting the connection and setting it to `None`.
66    fn drop_client(&mut self) {
67        if let Some(client) = self.client.take() {
68            tokio::spawn(async move {
69                let _ = client.abort().await;
70            });
71        }
72    }
73
74    /// Intersects the point of the chain with the given point.
75    async fn find_intersect_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
76        let logger = self.logger.clone();
77        let client = self.get_client().await?;
78        let chainsync = client.chainsync();
79
80        if chainsync.has_agency() {
81            debug!(logger, "Has agency, finding intersect point..."; "point" => ?point);
82            chainsync.find_intersect(vec![point.to_owned().into()]).await?;
83        } else {
84            debug!(logger, "Doesn't have agency, no need to find intersect point";);
85        }
86
87        Ok(())
88    }
89
90    /// Processes a block content next response and returns the appropriate chain block next action.
91    async fn process_chain_block_next_action(
92        &mut self,
93        next: NextResponse<BlockContent>,
94    ) -> StdResult<Option<ChainBlockNextAction>> {
95        match next {
96            NextResponse::RollForward(raw_block, _forward_tip) => {
97                let multi_era_block = MultiEraBlock::decode(&raw_block)
98                    .with_context(|| "PallasChainReader failed to decode raw block")?;
99                let parsed_block = ScannedBlock::convert(multi_era_block);
100                Ok(Some(ChainBlockNextAction::RollForward { parsed_block }))
101            }
102            NextResponse::RollBackward(rollback_point, _) => {
103                Ok(Some(ChainBlockNextAction::RollBackward {
104                    rollback_point: RawCardanoPoint::from(rollback_point),
105                }))
106            }
107            NextResponse::Await => Ok(None),
108        }
109    }
110}
111
112impl Drop for PallasChainReader {
113    fn drop(&mut self) {
114        self.drop_client();
115    }
116}
117
118#[async_trait]
119impl ChainBlockReader for PallasChainReader {
120    async fn set_chain_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
121        match self.find_intersect_point(point).await {
122            Ok(()) => Ok(()),
123            Err(err) => {
124                self.drop_client();
125
126                return Err(err);
127            }
128        }
129    }
130
131    async fn get_next_chain_block(&mut self) -> StdResult<Option<ChainBlockNextAction>> {
132        let client = self.get_client().await?;
133        let chainsync = client.chainsync();
134        let next = match chainsync.has_agency() {
135            true => chainsync.request_next().await,
136            false => chainsync.recv_while_must_reply().await,
137        };
138        match next {
139            Ok(next) => self.process_chain_block_next_action(next).await,
140            Err(err) => {
141                self.drop_client();
142
143                return Err(err.into());
144            }
145        }
146    }
147}
148
// Windows does not support Unix sockets, nor pallas_network::facades::NodeServer
#[cfg(all(test, unix))]
mod tests {
    use pallas_network::{
        facades::NodeServer,
        miniprotocols::{
            Point,
            chainsync::{BlockContent, Tip},
        },
    };
    use std::fs;
    use tokio::net::UnixListener;

    use mithril_common::{current_function, entities::BlockNumber, test::TempDir};

    use crate::test::TestLogger;

    use super::*;

    /// Enum representing the action to be performed by the server.
    ///
    /// The action is sent by `setup_server` once the intersection has been found and the
    /// client has sent its next request.
    enum ServerAction {
        /// Send a roll backward to the fake specific point.
        RollBackward,
        /// Send a roll forward carrying the fake raw block.
        RollForward,
    }

    /// Enum representing whether the chainsync client still has agency when the
    /// [ServerAction] is sent.
    ///
    /// With `HasAgency::No`, the mock server first answers the client's next request with an
    /// await reply. The client is then left waiting for the server's next message.
    #[derive(Debug, PartialEq)]
    enum HasAgency {
        Yes,
        No,
    }

    /// Returns a fake specific point for testing purposes.
    ///
    /// The mock server reports this point both as the found intersection and as the tip.
    fn get_fake_specific_point() -> Point {
        Point::Specific(
            1654413,
            hex::decode("7de1f036df5a133ce68a82877d14354d0ba6de7625ab918e75f3e2ecb29771c2")
                .unwrap(),
        )
    }

    /// Returns a fake block number for testing purposes.
    ///
    /// Used as the block number of the tip reported by the mock server.
    fn get_fake_block_number() -> BlockNumber {
        BlockNumber(1337)
    }

    /// Returns a fake cardano raw point for testing purposes.
    ///
    /// It is the raw counterpart of `get_fake_specific_point`, expected as the rollback point.
    fn get_fake_raw_point_backwards() -> RawCardanoPoint {
        RawCardanoPoint::from(get_fake_specific_point())
    }

    /// Creates a new work directory in the system's temporary folder.
    ///
    /// A short path is used, presumably to stay under the length limit of Unix socket paths.
    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("pallas_chain_observer_test", folder_name)
    }

    /// Returns the bytes of the `shelley1.block` test fixture.
    ///
    /// The fixture is stored hex-encoded and embedded at compile time.
    fn get_fake_raw_block() -> Vec<u8> {
        let raw_block =
            include_str!("../../../../../mithril-test-lab/test_data/blocks/shelley1.block");

        hex::decode(raw_block).unwrap()
    }

    /// Returns the [ScannedBlock] expected from decoding `get_fake_raw_block`.
    fn get_fake_scanned_block() -> ScannedBlock {
        let raw_block = get_fake_raw_block();
        let multi_era_block = MultiEraBlock::decode(&raw_block).unwrap();

        ScannedBlock::convert(multi_era_block)
    }

    /// Sets up a mock server for related tests.
    ///
    /// Use the `action` parameter to specify the action to be performed by the server.
    ///
    /// The spawned task removes any stale socket file, binds `socket_path` and accepts a single
    /// connection. It then scripts one chainsync exchange:
    /// 1. wait for the client's first request and reply that the intersection is found at the
    ///    fake specific point;
    /// 2. wait for the client's next request;
    /// 3. if `has_agency` is `HasAgency::No`, send an await reply first;
    /// 4. perform `action`.
    ///
    /// The task returns the [NodeServer] to the caller, which can then e.g. `abort` it.
    async fn setup_server(
        socket_path: PathBuf,
        action: ServerAction,
        has_agency: HasAgency,
    ) -> tokio::task::JoinHandle<NodeServer> {
        tokio::spawn({
            async move {
                if socket_path.exists() {
                    fs::remove_file(&socket_path).expect("Previous socket removal failed");
                }

                let known_point = get_fake_specific_point();
                let tip_block_number = get_fake_block_number();
                let unix_listener = UnixListener::bind(socket_path.as_path()).unwrap();
                // `10` is presumably the network magic, matching the `CardanoNetwork::TestNet(10)`
                // used by the clients. TODO confirm.
                let mut server = NodeServer::accept(&unix_listener, 10).await.unwrap();

                let chainsync_server = server.chainsync();

                chainsync_server.recv_while_idle().await.unwrap();

                chainsync_server
                    .send_intersect_found(
                        known_point.clone(),
                        Tip(known_point.clone(), *tip_block_number),
                    )
                    .await
                    .unwrap();

                chainsync_server.recv_while_idle().await.unwrap();

                if has_agency == HasAgency::No {
                    chainsync_server.send_await_reply().await.unwrap();
                }

                match action {
                    ServerAction::RollBackward => {
                        chainsync_server
                            .send_roll_backward(
                                known_point.clone(),
                                Tip(known_point.clone(), *tip_block_number),
                            )
                            .await
                            .unwrap();
                    }
                    ServerAction::RollForward => {
                        let block = BlockContent(get_fake_raw_block());
                        chainsync_server
                            .send_roll_forward(block, Tip(known_point.clone(), *tip_block_number))
                            .await
                            .unwrap();
                    }
                }

                server
            }
        })
    }

    /// A roll backward sent by the node is reported as a `RollBackward` to the fake point.
    #[tokio::test]
    async fn get_next_chain_block_rolls_backward() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollBackward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollBackward { rollback_point } => {
                assert_eq!(rollback_point, get_fake_raw_point_backwards());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    /// A roll forward sent by the node is decoded into the expected scanned block.
    #[tokio::test]
    async fn get_next_chain_block_rolls_forward() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollForward { parsed_block } => {
                assert_eq!(parsed_block, get_fake_scanned_block());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    /// Without agency, `set_chain_point` is a harmless no-op and `get_next_chain_block` still
    /// receives the block the node sends after its await reply.
    #[tokio::test]
    async fn get_next_chain_block_has_no_agency() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::No,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            // forces the client to change the chainsync server agency state
            let client = chain_reader.get_client().await.unwrap();
            client.chainsync().request_next().await.unwrap();

            // make sure that the chainsync client returns an error when attempting to find intersection without agency
            client
                .chainsync()
                .find_intersect(vec![known_point.clone()])
                .await
                .expect_err("chainsync find_intersect without agency should fail");

            // make sure that setting the chain point is harmless when the chainsync client does not have agency
            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollForward { parsed_block } => {
                assert_eq!(parsed_block, get_fake_scanned_block());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    /// A failing `get_next_chain_block` drops the cached client. Here the failure is caused by
    /// aborting the server side of the connection.
    #[tokio::test]
    async fn cached_client_is_dropped_when_returning_error() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let socket_path_clone = socket_path.clone();
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::Yes,
        )
        .await;
        // First exchange succeeds, leaving a cached client in the reader.
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path_clone.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap();

            chain_reader
        });

        let (server_res, client_res) = tokio::join!(server, client);
        let chain_reader = client_res.expect("Client failed to get chain reader");
        let server = server_res.expect("Server failed to get server");
        // Tear down the server side so the next request on the cached client fails.
        server.abort().await;

        let client = tokio::spawn(async move {
            let mut chain_reader = chain_reader;

            assert!(chain_reader.has_client(), "Client should exist");

            chain_reader
                .get_next_chain_block()
                .await
                .expect_err("Chain reader get_next_chain_block should fail");

            assert!(
                !chain_reader.has_client(),
                "Client should have been dropped after error"
            );

            chain_reader
        });
        client.await.unwrap();
    }
}
454}