mithril_cardano_node_chain/chain_reader/
pallas_chain_reader.rs1use std::path::{Path, PathBuf};
2
3use anyhow::Context;
4use async_trait::async_trait;
5use pallas_network::{
6 facades::NodeClient,
7 miniprotocols::chainsync::{BlockContent, NextResponse},
8};
9use pallas_traverse::MultiEraBlock;
10use slog::{Logger, debug};
11
12use mithril_common::StdResult;
13use mithril_common::entities::CardanoNetwork;
14use mithril_common::logging::LoggerExtensions;
15
16use crate::entities::{ChainBlockNextAction, RawCardanoPoint, ScannedBlock};
17
18use super::ChainBlockReader;
19
20pub struct PallasChainReader {
22 socket: PathBuf,
23 network: CardanoNetwork,
24 client: Option<NodeClient>,
25 logger: Logger,
26}
27
28impl PallasChainReader {
29 pub fn new(socket: &Path, network: CardanoNetwork, logger: Logger) -> Self {
31 Self {
32 socket: socket.to_owned(),
33 network,
34 client: None,
35 logger: logger.new_with_component_name::<Self>(),
36 }
37 }
38
39 async fn new_client(&self) -> StdResult<NodeClient> {
41 let magic = self.network.magic_id();
42 NodeClient::connect(&self.socket, magic)
43 .await
44 .with_context(|| "PallasChainReader failed to create a new client")
45 }
46
47 async fn get_client(&mut self) -> StdResult<&mut NodeClient> {
49 if self.client.is_none() {
50 self.client = Some(self.new_client().await?);
51 debug!(self.logger, "Connected to a new client");
52 }
53
54 self.client
55 .as_mut()
56 .with_context(|| "PallasChainReader failed to get a client")
57 }
58
59 #[cfg(all(test, unix))]
60 fn has_client(&self) -> bool {
62 self.client.is_some()
63 }
64
65 fn drop_client(&mut self) {
67 if let Some(client) = self.client.take() {
68 tokio::spawn(async move {
69 let _ = client.abort().await;
70 });
71 }
72 }
73
74 async fn find_intersect_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
76 let logger = self.logger.clone();
77 let client = self.get_client().await?;
78 let chainsync = client.chainsync();
79
80 if chainsync.has_agency() {
81 debug!(logger, "Has agency, finding intersect point..."; "point" => ?point);
82 chainsync.find_intersect(vec![point.to_owned().into()]).await?;
83 } else {
84 debug!(logger, "Doesn't have agency, no need to find intersect point";);
85 }
86
87 Ok(())
88 }
89
90 async fn process_chain_block_next_action(
92 &mut self,
93 next: NextResponse<BlockContent>,
94 ) -> StdResult<Option<ChainBlockNextAction>> {
95 match next {
96 NextResponse::RollForward(raw_block, _forward_tip) => {
97 let multi_era_block = MultiEraBlock::decode(&raw_block)
98 .with_context(|| "PallasChainReader failed to decode raw block")?;
99 let parsed_block = ScannedBlock::convert(multi_era_block);
100 Ok(Some(ChainBlockNextAction::RollForward { parsed_block }))
101 }
102 NextResponse::RollBackward(rollback_point, _) => {
103 Ok(Some(ChainBlockNextAction::RollBackward {
104 rollback_point: RawCardanoPoint::from(rollback_point),
105 }))
106 }
107 NextResponse::Await => Ok(None),
108 }
109 }
110}
111
112impl Drop for PallasChainReader {
113 fn drop(&mut self) {
114 self.drop_client();
115 }
116}
117
118#[async_trait]
119impl ChainBlockReader for PallasChainReader {
120 async fn set_chain_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
121 match self.find_intersect_point(point).await {
122 Ok(()) => Ok(()),
123 Err(err) => {
124 self.drop_client();
125
126 return Err(err);
127 }
128 }
129 }
130
131 async fn get_next_chain_block(&mut self) -> StdResult<Option<ChainBlockNextAction>> {
132 let client = self.get_client().await?;
133 let chainsync = client.chainsync();
134 let next = match chainsync.has_agency() {
135 true => chainsync.request_next().await,
136 false => chainsync.recv_while_must_reply().await,
137 };
138 match next {
139 Ok(next) => self.process_chain_block_next_action(next).await,
140 Err(err) => {
141 self.drop_client();
142
143 return Err(err.into());
144 }
145 }
146 }
147}
148
#[cfg(all(test, unix))]
mod tests {
    use pallas_network::{
        facades::NodeServer,
        miniprotocols::{
            Point,
            chainsync::{BlockContent, Tip},
        },
    };
    use std::fs;
    use tokio::net::UnixListener;

    use mithril_common::{current_function, entities::BlockNumber, test::TempDir};

    use crate::test::TestLogger;

    use super::*;

    /// Last message the fake node server sends to the client.
    enum ServerAction {
        RollBackward,
        RollForward,
    }

    /// Whether the client still has agency when it asks for the next block.
    ///
    /// With `No`, the server answers the client's next request with `AwaitReply` before sending
    /// the final roll forward/backward message.
    #[derive(Debug, PartialEq)]
    enum HasAgency {
        Yes,
        No,
    }

    /// Point used both as the intersection and as the server tip in every test.
    fn get_fake_specific_point() -> Point {
        Point::Specific(
            1654413,
            hex::decode("7de1f036df5a133ce68a82877d14354d0ba6de7625ab918e75f3e2ecb29771c2")
                .unwrap(),
        )
    }

    /// Block number reported in the server tip.
    fn get_fake_block_number() -> BlockNumber {
        BlockNumber(1337)
    }

    /// Expected rollback point, as converted by the reader.
    fn get_fake_raw_point_backwards() -> RawCardanoPoint {
        RawCardanoPoint::from(get_fake_specific_point())
    }

    /// Creates a test directory with a short path to hold the node socket.
    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("pallas_chain_observer_test", folder_name)
    }

    /// Raw bytes of a Shelley block from the test-lab fixtures (stored hex-encoded).
    fn get_fake_raw_block() -> Vec<u8> {
        let raw_block =
            include_str!("../../../../../mithril-test-lab/test_data/blocks/shelley1.block");

        hex::decode(raw_block).unwrap()
    }

    /// Expected `ScannedBlock` for the fixture block.
    fn get_fake_scanned_block() -> ScannedBlock {
        let raw_block = get_fake_raw_block();
        let multi_era_block = MultiEraBlock::decode(&raw_block).unwrap();

        ScannedBlock::convert(multi_era_block)
    }

    /// Spawns a fake node server that accepts a single connection on `socket_path`.
    ///
    /// Scripted exchange:
    /// 1. Wait for the client's intersection request and answer `IntersectFound` at the fake point.
    /// 2. Wait for the next client request.
    /// 3. If `has_agency` is `No`, answer `AwaitReply` first.
    /// 4. Send the roll message selected by `action`.
    ///
    /// The server is returned from the task so the connection stays open until the caller
    /// drops or aborts it.
    async fn setup_server(
        socket_path: PathBuf,
        action: ServerAction,
        has_agency: HasAgency,
    ) -> tokio::task::JoinHandle<NodeServer> {
        tokio::spawn({
            async move {
                // A leftover socket file from a previous run would make `bind` fail.
                if socket_path.exists() {
                    fs::remove_file(&socket_path).expect("Previous socket removal failed");
                }

                let known_point = get_fake_specific_point();
                let tip_block_number = get_fake_block_number();
                let unix_listener = UnixListener::bind(socket_path.as_path()).unwrap();
                // Magic `10` presumably matches the clients' `CardanoNetwork::TestNet(10)`.
                let mut server = NodeServer::accept(&unix_listener, 10).await.unwrap();

                let chainsync_server = server.chainsync();

                chainsync_server.recv_while_idle().await.unwrap();

                chainsync_server
                    .send_intersect_found(
                        known_point.clone(),
                        Tip(known_point.clone(), *tip_block_number),
                    )
                    .await
                    .unwrap();

                chainsync_server.recv_while_idle().await.unwrap();

                if has_agency == HasAgency::No {
                    chainsync_server.send_await_reply().await.unwrap();
                }

                match action {
                    ServerAction::RollBackward => {
                        chainsync_server
                            .send_roll_backward(
                                known_point.clone(),
                                Tip(known_point.clone(), *tip_block_number),
                            )
                            .await
                            .unwrap();
                    }
                    ServerAction::RollForward => {
                        let block = BlockContent(get_fake_raw_block());
                        chainsync_server
                            .send_roll_forward(block, Tip(known_point.clone(), *tip_block_number))
                            .await
                            .unwrap();
                    }
                }

                server
            }
        })
    }

    #[tokio::test]
    async fn get_next_chain_block_rolls_backward() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollBackward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollBackward { rollback_point } => {
                assert_eq!(rollback_point, get_fake_raw_point_backwards());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    #[tokio::test]
    async fn get_next_chain_block_rolls_forward() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollForward { parsed_block } => {
                assert_eq!(parsed_block, get_fake_scanned_block());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    #[tokio::test]
    async fn get_next_chain_block_has_no_agency() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::No,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            // Drive the protocol by hand: the server answers this request with `AwaitReply`,
            // which leaves the client without agency.
            let client = chain_reader.get_client().await.unwrap();
            client.chainsync().request_next().await.unwrap();

            // Without agency, a raw intersection request must be rejected...
            client
                .chainsync()
                .find_intersect(vec![known_point.clone()])
                .await
                .expect_err("chainsync find_intersect without agency should fail");

            // ...whereas `set_chain_point` succeeds because it skips the request in that state.
            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            // Waits for the reply the server still owes (the roll forward).
            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollForward { parsed_block } => {
                assert_eq!(parsed_block, get_fake_scanned_block());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    #[tokio::test]
    async fn cached_client_is_dropped_when_returning_error() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let socket_path_clone = socket_path.clone();
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path_clone.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap();

            // Hand the reader back so its cached client can be inspected after the server stops.
            chain_reader
        });

        let (server_res, client_res) = tokio::join!(server, client);
        let chain_reader = client_res.expect("Client failed to get chain reader");
        let server = server_res.expect("Server failed to get server");
        // Stopping the server is expected to make the next read on the cached client fail.
        server.abort().await;

        let client = tokio::spawn(async move {
            let mut chain_reader = chain_reader;

            assert!(chain_reader.has_client(), "Client should exist");

            chain_reader
                .get_next_chain_block()
                .await
                .expect_err("Chain reader get_next_chain_block should fail");

            assert!(
                !chain_reader.has_client(),
                "Client should have been dropped after error"
            );

            chain_reader
        });
        client.await.unwrap();
    }
}