mithril_cardano_node_chain/chain_reader/
pallas_chain_reader.rs1use std::path::{Path, PathBuf};
2
3use anyhow::{Context, anyhow};
4use async_trait::async_trait;
5use pallas_network::{
6 facades::NodeClient,
7 miniprotocols::chainsync::{BlockContent, NextResponse},
8};
9use pallas_traverse::MultiEraBlock;
10use slog::{Logger, debug};
11
12use mithril_common::StdResult;
13use mithril_common::entities::CardanoNetwork;
14use mithril_common::logging::LoggerExtensions;
15
16use crate::entities::{ChainBlockNextAction, RawCardanoPoint, ScannedBlock};
17
18use super::ChainBlockReader;
19
20pub struct PallasChainReader {
22 socket: PathBuf,
23 network: CardanoNetwork,
24 client: Option<NodeClient>,
25 logger: Logger,
26}
27
28impl PallasChainReader {
29 pub fn new(socket: &Path, network: CardanoNetwork, logger: Logger) -> Self {
31 Self {
32 socket: socket.to_owned(),
33 network,
34 client: None,
35 logger: logger.new_with_component_name::<Self>(),
36 }
37 }
38
39 async fn new_client(&self) -> StdResult<NodeClient> {
41 let magic = self.network.code();
42 NodeClient::connect(&self.socket, magic)
43 .await
44 .map_err(|err| anyhow!(err))
45 .with_context(|| "PallasChainReader failed to create a new client")
46 }
47
48 async fn get_client(&mut self) -> StdResult<&mut NodeClient> {
50 if self.client.is_none() {
51 self.client = Some(self.new_client().await?);
52 debug!(self.logger, "Connected to a new client");
53 }
54
55 self.client
56 .as_mut()
57 .with_context(|| "PallasChainReader failed to get a client")
58 }
59
60 #[cfg(all(test, unix))]
61 fn has_client(&self) -> bool {
63 self.client.is_some()
64 }
65
66 fn drop_client(&mut self) {
68 if let Some(client) = self.client.take() {
69 tokio::spawn(async move {
70 let _ = client.abort().await;
71 });
72 }
73 }
74
75 async fn find_intersect_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
77 let logger = self.logger.clone();
78 let client = self.get_client().await?;
79 let chainsync = client.chainsync();
80
81 if chainsync.has_agency() {
82 debug!(logger, "Has agency, finding intersect point..."; "point" => ?point);
83 chainsync.find_intersect(vec![point.to_owned().into()]).await?;
84 } else {
85 debug!(logger, "Doesn't have agency, no need to find intersect point";);
86 }
87
88 Ok(())
89 }
90
91 async fn process_chain_block_next_action(
93 &mut self,
94 next: NextResponse<BlockContent>,
95 ) -> StdResult<Option<ChainBlockNextAction>> {
96 match next {
97 NextResponse::RollForward(raw_block, _forward_tip) => {
98 let multi_era_block = MultiEraBlock::decode(&raw_block)
99 .with_context(|| "PallasChainReader failed to decode raw block")?;
100 let parsed_block = ScannedBlock::convert(multi_era_block);
101 Ok(Some(ChainBlockNextAction::RollForward { parsed_block }))
102 }
103 NextResponse::RollBackward(rollback_point, _) => {
104 Ok(Some(ChainBlockNextAction::RollBackward {
105 rollback_point: RawCardanoPoint::from(rollback_point),
106 }))
107 }
108 NextResponse::Await => Ok(None),
109 }
110 }
111}
112
113impl Drop for PallasChainReader {
114 fn drop(&mut self) {
115 self.drop_client();
116 }
117}
118
119#[async_trait]
120impl ChainBlockReader for PallasChainReader {
121 async fn set_chain_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
122 match self.find_intersect_point(point).await {
123 Ok(()) => Ok(()),
124 Err(err) => {
125 self.drop_client();
126
127 return Err(err);
128 }
129 }
130 }
131
132 async fn get_next_chain_block(&mut self) -> StdResult<Option<ChainBlockNextAction>> {
133 let client = self.get_client().await?;
134 let chainsync = client.chainsync();
135 let next = match chainsync.has_agency() {
136 true => chainsync.request_next().await,
137 false => chainsync.recv_while_must_reply().await,
138 };
139 match next {
140 Ok(next) => self.process_chain_block_next_action(next).await,
141 Err(err) => {
142 self.drop_client();
143
144 return Err(err.into());
145 }
146 }
147 }
148}
149
#[cfg(all(test, unix))]
mod tests {
    use pallas_network::{
        facades::NodeServer,
        miniprotocols::{
            Point,
            chainsync::{BlockContent, Tip},
        },
    };
    use std::fs;
    use tokio::net::UnixListener;

    use mithril_common::{current_function, entities::BlockNumber, test_utils::TempDir};

    use crate::test::TestLogger;

    use super::*;

    /// Response sent by the fake node to the block request that follows the intersection.
    enum ServerAction {
        RollBackward,
        RollForward,
    }

    /// With `No`, the fake node first answers the block request with an "await" reply, which
    /// leaves the client without agency until the real answer arrives.
    #[derive(Debug, PartialEq)]
    enum HasAgency {
        Yes,
        No,
    }

    fn get_fake_specific_point() -> Point {
        let header_hash =
            hex::decode("7de1f036df5a133ce68a82877d14354d0ba6de7625ab918e75f3e2ecb29771c2")
                .unwrap();

        Point::Specific(1654413, header_hash)
    }

    fn get_fake_block_number() -> BlockNumber {
        BlockNumber(1337)
    }

    fn get_fake_raw_point_backwards() -> RawCardanoPoint {
        get_fake_specific_point().into()
    }

    /// Tip advertised by the fake node in all of its responses.
    fn get_fake_tip() -> Tip {
        Tip(get_fake_specific_point(), *get_fake_block_number())
    }

    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("pallas_chain_observer_test", folder_name)
    }

    fn get_fake_raw_block() -> Vec<u8> {
        let hex_block =
            include_str!("../../../../../mithril-test-lab/test_data/blocks/shelley1.block");

        hex::decode(hex_block).unwrap()
    }

    fn get_fake_scanned_block() -> ScannedBlock {
        let raw_block = get_fake_raw_block();

        ScannedBlock::convert(MultiEraBlock::decode(&raw_block).unwrap())
    }

    /// Builds a chain reader targeting `socket_path` on the test network.
    fn new_chain_reader(socket_path: &Path) -> PallasChainReader {
        PallasChainReader::new(socket_path, CardanoNetwork::TestNet(10), TestLogger::stdout())
    }

    /// Positions `chain_reader` on the fake node's known point, failing the test on error.
    async fn set_known_chain_point(chain_reader: &mut PallasChainReader) {
        chain_reader
            .set_chain_point(&RawCardanoPoint::from(get_fake_specific_point()))
            .await
            .unwrap();
    }

    /// Spawns a fake node listening on `socket_path` that serves one chainsync exchange:
    /// - it answers the intersection request with "intersect found";
    /// - it then answers the next block request according to `action`, preceded by an
    ///   "await" reply when `has_agency` is `HasAgency::No`.
    ///
    /// The returned handle resolves to the server once the exchange is done.
    async fn setup_server(
        socket_path: PathBuf,
        action: ServerAction,
        has_agency: HasAgency,
    ) -> tokio::task::JoinHandle<NodeServer> {
        tokio::spawn(async move {
            if socket_path.exists() {
                fs::remove_file(&socket_path).expect("Previous socket removal failed");
            }

            let listener = UnixListener::bind(socket_path.as_path()).unwrap();
            let mut server = NodeServer::accept(&listener, 10).await.unwrap();
            let chainsync = server.chainsync();

            // Intersection request.
            chainsync.recv_while_idle().await.unwrap();
            chainsync
                .send_intersect_found(get_fake_specific_point(), get_fake_tip())
                .await
                .unwrap();

            // Next block request.
            chainsync.recv_while_idle().await.unwrap();
            if matches!(has_agency, HasAgency::No) {
                chainsync.send_await_reply().await.unwrap();
            }

            match action {
                ServerAction::RollBackward => chainsync
                    .send_roll_backward(get_fake_specific_point(), get_fake_tip())
                    .await
                    .unwrap(),
                ServerAction::RollForward => chainsync
                    .send_roll_forward(BlockContent(get_fake_raw_block()), get_fake_tip())
                    .await
                    .unwrap(),
            }

            server
        })
    }

    #[tokio::test]
    async fn get_next_chain_block_rolls_backward() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollBackward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = new_chain_reader(&socket_path);
            set_known_chain_point(&mut chain_reader).await;

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        let ChainBlockNextAction::RollBackward { rollback_point } = chain_block else {
            panic!("Unexpected chain block action");
        };
        assert_eq!(rollback_point, get_fake_raw_point_backwards());
    }

    #[tokio::test]
    async fn get_next_chain_block_rolls_forward() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = new_chain_reader(&socket_path);
            set_known_chain_point(&mut chain_reader).await;

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        let ChainBlockNextAction::RollForward { parsed_block } = chain_block else {
            panic!("Unexpected chain block action");
        };
        assert_eq!(parsed_block, get_fake_scanned_block());
    }

    #[tokio::test]
    async fn get_next_chain_block_has_no_agency() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::No,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = new_chain_reader(&socket_path);
            set_known_chain_point(&mut chain_reader).await;

            // Request the next block by hand: the node answers "await", so the client loses
            // agency and a raw intersection request must now be rejected...
            let node_client = chain_reader.get_client().await.unwrap();
            node_client.chainsync().request_next().await.unwrap();
            node_client
                .chainsync()
                .find_intersect(vec![get_fake_specific_point()])
                .await
                .expect_err("chainsync find_intersect without agency should fail");

            // ...whereas the chain reader copes with it and still reads the pending block.
            set_known_chain_point(&mut chain_reader).await;

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        let ChainBlockNextAction::RollForward { parsed_block } = chain_block else {
            panic!("Unexpected chain block action");
        };
        assert_eq!(parsed_block, get_fake_scanned_block());
    }

    #[tokio::test]
    async fn cached_client_is_dropped_when_returning_error() {
        let socket_path = create_temp_dir(current_function!()).join("node.socket");
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = new_chain_reader(&socket_path);
            set_known_chain_point(&mut chain_reader).await;
            chain_reader.get_next_chain_block().await.unwrap().unwrap();

            chain_reader
        });

        let (server_res, client_res) = tokio::join!(server, client);
        let mut chain_reader = client_res.expect("Client failed to get chain reader");
        let server = server_res.expect("Server failed to get server");
        // Shut the fake node down so that the next read through the cached client fails.
        server.abort().await;

        let client = tokio::spawn(async move {
            assert!(chain_reader.has_client(), "Client should exist");

            chain_reader
                .get_next_chain_block()
                .await
                .expect_err("Chain reader get_next_chain_block should fail");

            assert!(
                !chain_reader.has_client(),
                "Client should have been dropped after error"
            );

            chain_reader
        });
        client.await.unwrap();
    }
}