mithril_common/chain_reader/
pallas_chain_reader.rs

1use std::path::{Path, PathBuf};
2
3use anyhow::{anyhow, Context};
4use async_trait::async_trait;
5use pallas_network::{
6    facades::NodeClient,
7    miniprotocols::chainsync::{BlockContent, NextResponse},
8};
9use pallas_traverse::MultiEraBlock;
10use slog::{debug, Logger};
11
12use crate::logging::LoggerExtensions;
13use crate::{
14    cardano_block_scanner::{RawCardanoPoint, ScannedBlock},
15    CardanoNetwork, StdResult,
16};
17
18use super::{ChainBlockNextAction, ChainBlockReader};
19
20/// [PallasChainReader] reads blocks with 'chainsync' mini-protocol
21pub struct PallasChainReader {
22    socket: PathBuf,
23    network: CardanoNetwork,
24    client: Option<NodeClient>,
25    logger: Logger,
26}
27
28impl PallasChainReader {
29    /// Creates a new `PallasChainReader` with the specified socket and network.
30    pub fn new(socket: &Path, network: CardanoNetwork, logger: Logger) -> Self {
31        Self {
32            socket: socket.to_owned(),
33            network,
34            client: None,
35            logger: logger.new_with_component_name::<Self>(),
36        }
37    }
38
39    /// Creates and returns a new `NodeClient` connected to the specified socket.
40    async fn new_client(&self) -> StdResult<NodeClient> {
41        let magic = self.network.code();
42        NodeClient::connect(&self.socket, magic)
43            .await
44            .map_err(|err| anyhow!(err))
45            .with_context(|| "PallasChainReader failed to create a new client")
46    }
47
48    /// Returns a mutable reference to the client.
49    async fn get_client(&mut self) -> StdResult<&mut NodeClient> {
50        if self.client.is_none() {
51            self.client = Some(self.new_client().await?);
52            debug!(self.logger, "Connected to a new client");
53        }
54
55        self.client
56            .as_mut()
57            .with_context(|| "PallasChainReader failed to get a client")
58    }
59
60    /// Intersects the point of the chain with the given point.
61    async fn find_intersect_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
62        let logger = self.logger.clone();
63        let client = self.get_client().await?;
64        let chainsync = client.chainsync();
65
66        if chainsync.has_agency() {
67            debug!(logger, "Has agency, finding intersect point..."; "point" => ?point);
68            chainsync
69                .find_intersect(vec![point.to_owned().into()])
70                .await?;
71        } else {
72            debug!(logger, "Doesn't have agency, no need to find intersect point";);
73        }
74
75        Ok(())
76    }
77
78    /// Processes a block content next response and returns the appropriate chain block next action.
79    async fn process_chain_block_next_action(
80        &mut self,
81        next: NextResponse<BlockContent>,
82    ) -> StdResult<Option<ChainBlockNextAction>> {
83        match next {
84            NextResponse::RollForward(raw_block, _forward_tip) => {
85                let multi_era_block = MultiEraBlock::decode(&raw_block)
86                    .with_context(|| "PallasChainReader failed to decode raw block")?;
87                let parsed_block = ScannedBlock::convert(multi_era_block);
88                Ok(Some(ChainBlockNextAction::RollForward { parsed_block }))
89            }
90            NextResponse::RollBackward(rollback_point, _) => {
91                Ok(Some(ChainBlockNextAction::RollBackward {
92                    rollback_point: RawCardanoPoint::from(rollback_point),
93                }))
94            }
95            NextResponse::Await => Ok(None),
96        }
97    }
98}
99
100impl Drop for PallasChainReader {
101    fn drop(&mut self) {
102        if let Some(client) = self.client.take() {
103            tokio::spawn(async move {
104                let _ = client.abort().await;
105            });
106        }
107    }
108}
109
110#[async_trait]
111impl ChainBlockReader for PallasChainReader {
112    async fn set_chain_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
113        self.find_intersect_point(point).await
114    }
115
116    async fn get_next_chain_block(&mut self) -> StdResult<Option<ChainBlockNextAction>> {
117        let client = self.get_client().await?;
118        let chainsync = client.chainsync();
119
120        let next = match chainsync.has_agency() {
121            true => chainsync.request_next().await?,
122            false => chainsync.recv_while_must_reply().await?,
123        };
124
125        self.process_chain_block_next_action(next).await
126    }
127}
128
#[cfg(test)]
mod tests {
    use std::fs;

    use pallas_network::{
        facades::NodeServer,
        miniprotocols::{
            chainsync::{BlockContent, Tip},
            Point,
        },
    };
    use tokio::net::UnixListener;

    use super::*;

    use crate::test_utils::TestLogger;
    use crate::{entities::BlockNumber, test_utils::TempDir};

    /// Action performed by the mock server when the client requests the next block.
    enum ServerAction {
        RollBackward,
        RollForward,
    }

    /// Whether the client keeps agency when asking for the next block.
    ///
    /// With `No`, the mock server first answers the client's next-block request with an
    /// "await" reply, so the server holds agency until it sends the block message.
    #[derive(Debug, PartialEq)]
    enum HasAgency {
        Yes,
        No,
    }

    /// Returns a fake specific point for testing purposes.
    fn get_fake_specific_point() -> Point {
        Point::Specific(
            1654413,
            hex::decode("7de1f036df5a133ce68a82877d14354d0ba6de7625ab918e75f3e2ecb29771c2")
                .unwrap(),
        )
    }

    /// Returns a fake block number for testing purposes.
    fn get_fake_block_number() -> BlockNumber {
        BlockNumber(1337)
    }

    /// Returns a fake cardano raw point for testing purposes.
    fn get_fake_raw_point_backwards() -> RawCardanoPoint {
        RawCardanoPoint::from(get_fake_specific_point())
    }

    /// Creates a new work directory in the system's temporary folder.
    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("pallas_chain_observer_test", folder_name)
    }

    /// Returns the raw bytes of the shelley test block.
    fn get_fake_raw_block() -> Vec<u8> {
        let raw_block = include_str!("../../../mithril-test-lab/test_data/blocks/shelley1.block");

        hex::decode(raw_block).unwrap()
    }

    /// Returns the scanned block expected from decoding [get_fake_raw_block].
    fn get_fake_scanned_block() -> ScannedBlock {
        let raw_block = get_fake_raw_block();
        let multi_era_block = MultiEraBlock::decode(&raw_block).unwrap();

        ScannedBlock::convert(multi_era_block)
    }

    /// Creates a chain reader for the given socket on a test network (no connection yet).
    fn new_chain_reader(socket_path: &Path) -> PallasChainReader {
        PallasChainReader::new(socket_path, CardanoNetwork::TestNet(10), TestLogger::stdout())
    }

    /// Sets up a mock server for related tests.
    ///
    /// Use the `action` parameter to specify the action to be performed by the server.
    ///
    /// The socket is bound before the server task is spawned. A client that connects right
    /// after this function returns therefore always finds a listening socket, whatever order
    /// the scheduler runs the tasks in.
    async fn setup_server(
        socket_path: PathBuf,
        action: ServerAction,
        has_agency: HasAgency,
    ) -> tokio::task::JoinHandle<()> {
        if socket_path.exists() {
            fs::remove_file(&socket_path).expect("Previous socket removal failed");
        }
        let unix_listener = UnixListener::bind(socket_path.as_path()).unwrap();

        tokio::spawn(async move {
            let known_point = get_fake_specific_point();
            let tip_block_number = get_fake_block_number();
            let mut server = NodeServer::accept(&unix_listener, 10).await.unwrap();

            let chainsync_server = server.chainsync();

            // Answer the client's find-intersect request.
            chainsync_server.recv_while_idle().await.unwrap();

            chainsync_server
                .send_intersect_found(
                    known_point.clone(),
                    Tip(known_point.clone(), *tip_block_number),
                )
                .await
                .unwrap();

            // Wait for the client's next-block request.
            chainsync_server.recv_while_idle().await.unwrap();

            if has_agency == HasAgency::No {
                chainsync_server.send_await_reply().await.unwrap();
            }

            match action {
                ServerAction::RollBackward => {
                    chainsync_server
                        .send_roll_backward(
                            known_point.clone(),
                            Tip(known_point.clone(), *tip_block_number),
                        )
                        .await
                        .unwrap();
                }
                ServerAction::RollForward => {
                    let block = BlockContent(get_fake_raw_block());
                    chainsync_server
                        .send_roll_forward(block, Tip(known_point.clone(), *tip_block_number))
                        .await
                        .unwrap();
                }
            }
        })
    }

    #[tokio::test]
    async fn get_next_chain_block_rolls_backward() {
        let socket_path =
            create_temp_dir("get_next_chain_block_rolls_backward").join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollBackward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = new_chain_reader(&socket_path);

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (server_res, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        server_res.expect("Server failed to serve the chainsync exchange");
        match chain_block {
            ChainBlockNextAction::RollBackward { rollback_point } => {
                assert_eq!(rollback_point, get_fake_raw_point_backwards());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    #[tokio::test]
    async fn get_next_chain_block_rolls_forward() {
        let socket_path = create_temp_dir("get_next_chain_block_rolls_forward").join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = new_chain_reader(&socket_path);

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (server_res, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        server_res.expect("Server failed to serve the chainsync exchange");
        match chain_block {
            ChainBlockNextAction::RollForward { parsed_block } => {
                assert_eq!(parsed_block, get_fake_scanned_block());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    #[tokio::test]
    async fn get_next_chain_block_has_no_agency() {
        let socket_path = create_temp_dir("get_next_chain_block_has_no_agency").join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::No,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = new_chain_reader(&socket_path);

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            // hand agency over to the server: request the next block, which the mock
            // server answers with an "await" reply
            let client = chain_reader.get_client().await.unwrap();
            client.chainsync().request_next().await.unwrap();

            // make sure that the chainsync client returns an error when attempting to find intersection without agency
            client
                .chainsync()
                .find_intersect(vec![known_point.clone()])
                .await
                .expect_err("chainsync find_intersect without agency should fail");

            // make sure that setting the chain point is harmless when the chainsync client does not have agency
            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (server_res, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        server_res.expect("Server failed to serve the chainsync exchange");
        match chain_block {
            ChainBlockNextAction::RollForward { parsed_block } => {
                assert_eq!(parsed_block, get_fake_scanned_block());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }
}