mithril_common/chain_reader/
pallas_chain_reader.rs1use std::path::{Path, PathBuf};
2
3use anyhow::{anyhow, Context};
4use async_trait::async_trait;
5use pallas_network::{
6 facades::NodeClient,
7 miniprotocols::chainsync::{BlockContent, NextResponse},
8};
9use pallas_traverse::MultiEraBlock;
10use slog::{debug, Logger};
11
12use crate::logging::LoggerExtensions;
13use crate::{
14 cardano_block_scanner::{RawCardanoPoint, ScannedBlock},
15 CardanoNetwork, StdResult,
16};
17
18use super::{ChainBlockNextAction, ChainBlockReader};
19
20pub struct PallasChainReader {
22 socket: PathBuf,
23 network: CardanoNetwork,
24 client: Option<NodeClient>,
25 logger: Logger,
26}
27
28impl PallasChainReader {
29 pub fn new(socket: &Path, network: CardanoNetwork, logger: Logger) -> Self {
31 Self {
32 socket: socket.to_owned(),
33 network,
34 client: None,
35 logger: logger.new_with_component_name::<Self>(),
36 }
37 }
38
39 async fn new_client(&self) -> StdResult<NodeClient> {
41 let magic = self.network.code();
42 NodeClient::connect(&self.socket, magic)
43 .await
44 .map_err(|err| anyhow!(err))
45 .with_context(|| "PallasChainReader failed to create a new client")
46 }
47
48 async fn get_client(&mut self) -> StdResult<&mut NodeClient> {
50 if self.client.is_none() {
51 self.client = Some(self.new_client().await?);
52 debug!(self.logger, "Connected to a new client");
53 }
54
55 self.client
56 .as_mut()
57 .with_context(|| "PallasChainReader failed to get a client")
58 }
59
60 async fn find_intersect_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
62 let logger = self.logger.clone();
63 let client = self.get_client().await?;
64 let chainsync = client.chainsync();
65
66 if chainsync.has_agency() {
67 debug!(logger, "Has agency, finding intersect point..."; "point" => ?point);
68 chainsync
69 .find_intersect(vec![point.to_owned().into()])
70 .await?;
71 } else {
72 debug!(logger, "Doesn't have agency, no need to find intersect point";);
73 }
74
75 Ok(())
76 }
77
78 async fn process_chain_block_next_action(
80 &mut self,
81 next: NextResponse<BlockContent>,
82 ) -> StdResult<Option<ChainBlockNextAction>> {
83 match next {
84 NextResponse::RollForward(raw_block, _forward_tip) => {
85 let multi_era_block = MultiEraBlock::decode(&raw_block)
86 .with_context(|| "PallasChainReader failed to decode raw block")?;
87 let parsed_block = ScannedBlock::convert(multi_era_block);
88 Ok(Some(ChainBlockNextAction::RollForward { parsed_block }))
89 }
90 NextResponse::RollBackward(rollback_point, _) => {
91 Ok(Some(ChainBlockNextAction::RollBackward {
92 rollback_point: RawCardanoPoint::from(rollback_point),
93 }))
94 }
95 NextResponse::Await => Ok(None),
96 }
97 }
98}
99
100impl Drop for PallasChainReader {
101 fn drop(&mut self) {
102 if let Some(client) = self.client.take() {
103 tokio::spawn(async move {
104 let _ = client.abort().await;
105 });
106 }
107 }
108}
109
110#[async_trait]
111impl ChainBlockReader for PallasChainReader {
112 async fn set_chain_point(&mut self, point: &RawCardanoPoint) -> StdResult<()> {
113 self.find_intersect_point(point).await
114 }
115
116 async fn get_next_chain_block(&mut self) -> StdResult<Option<ChainBlockNextAction>> {
117 let client = self.get_client().await?;
118 let chainsync = client.chainsync();
119
120 let next = match chainsync.has_agency() {
121 true => chainsync.request_next().await?,
122 false => chainsync.recv_while_must_reply().await?,
123 };
124
125 self.process_chain_block_next_action(next).await
126 }
127}
128
#[cfg(test)]
mod tests {
    use std::fs;

    use pallas_network::{
        facades::NodeServer,
        miniprotocols::{
            chainsync::{BlockContent, Tip},
            Point,
        },
    };
    use tokio::net::UnixListener;

    use super::*;

    use crate::test_utils::TestLogger;
    use crate::{entities::BlockNumber, test_utils::TempDir};

    /// Reply sent by the fake node server when the client asks for the next block.
    enum ServerAction {
        RollBackward,
        RollForward,
    }

    /// Whether the client still holds the chain sync agency when the server replies.
    ///
    /// With `No`, the server first answers the next-block request with an "await" reply,
    /// then sends the actual roll forward/backward message.
    #[derive(Debug, PartialEq)]
    enum HasAgency {
        Yes,
        No,
    }

    /// Point used as the intersection, the tip and the rollback target by the fake server.
    fn get_fake_specific_point() -> Point {
        Point::Specific(
            1654413,
            hex::decode("7de1f036df5a133ce68a82877d14354d0ba6de7625ab918e75f3e2ecb29771c2")
                .unwrap(),
        )
    }

    /// Block number of the tip reported by the fake server.
    fn get_fake_block_number() -> BlockNumber {
        BlockNumber(1337)
    }

    /// Rollback point the chain reader is expected to return.
    fn get_fake_raw_point_backwards() -> RawCardanoPoint {
        RawCardanoPoint::from(get_fake_specific_point())
    }

    /// Creates a temporary directory through `TempDir::create_with_short_path`, presumably to
    /// keep the socket path under the Unix socket path length limit.
    fn create_temp_dir(folder_name: &str) -> PathBuf {
        TempDir::create_with_short_path("pallas_chain_observer_test", folder_name)
    }

    /// Hex-decoded Shelley block taken from the test lab fixtures.
    fn get_fake_raw_block() -> Vec<u8> {
        let raw_block = include_str!("../../../mithril-test-lab/test_data/blocks/shelley1.block");

        hex::decode(raw_block).unwrap()
    }

    /// Expected [ScannedBlock] for the block returned by `get_fake_raw_block`.
    fn get_fake_scanned_block() -> ScannedBlock {
        let raw_block = get_fake_raw_block();
        let multi_era_block = MultiEraBlock::decode(&raw_block).unwrap();

        ScannedBlock::convert(multi_era_block)
    }

    /// Starts a fake Cardano node serving the chain sync mini-protocol on `socket_path`.
    ///
    /// The socket is bound *before* the server task is spawned. When this function returns,
    /// the socket already exists and a client may connect right away, whatever order the
    /// runtime schedules the server and client tasks in.
    ///
    /// The server answers the first request with "intersect found" on the fake point. It then
    /// waits for the next-block request, optionally replies "await" (when `has_agency` is
    /// `No`), and finally sends the reply selected by `action`.
    async fn setup_server(
        socket_path: PathBuf,
        action: ServerAction,
        has_agency: HasAgency,
    ) -> tokio::task::JoinHandle<()> {
        if socket_path.exists() {
            fs::remove_file(&socket_path).expect("Previous socket removal failed");
        }
        let unix_listener = UnixListener::bind(socket_path.as_path()).unwrap();

        tokio::spawn(async move {
            let known_point = get_fake_specific_point();
            let tip_block_number = get_fake_block_number();
            let mut server = NodeServer::accept(&unix_listener, 10).await.unwrap();

            let chainsync_server = server.chainsync();

            // Intersection request.
            chainsync_server.recv_while_idle().await.unwrap();

            chainsync_server
                .send_intersect_found(
                    known_point.clone(),
                    Tip(known_point.clone(), *tip_block_number),
                )
                .await
                .unwrap();

            // Next-block request.
            chainsync_server.recv_while_idle().await.unwrap();

            if has_agency == HasAgency::No {
                chainsync_server.send_await_reply().await.unwrap();
            }

            match action {
                ServerAction::RollBackward => {
                    chainsync_server
                        .send_roll_backward(
                            known_point.clone(),
                            Tip(known_point.clone(), *tip_block_number),
                        )
                        .await
                        .unwrap();
                }
                ServerAction::RollForward => {
                    let block = BlockContent(get_fake_raw_block());
                    chainsync_server
                        .send_roll_forward(block, Tip(known_point.clone(), *tip_block_number))
                        .await
                        .unwrap();
                }
            }
        })
    }

    #[tokio::test]
    async fn get_next_chain_block_rolls_backward() {
        let socket_path =
            create_temp_dir("get_next_chain_block_rolls_backward").join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollBackward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollBackward { rollback_point } => {
                assert_eq!(rollback_point, get_fake_raw_point_backwards());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    #[tokio::test]
    async fn get_next_chain_block_rolls_forward() {
        let socket_path = create_temp_dir("get_next_chain_block_rolls_forward").join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::Yes,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollForward { parsed_block } => {
                assert_eq!(parsed_block, get_fake_scanned_block());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }

    #[tokio::test]
    async fn get_next_chain_block_has_no_agency() {
        let socket_path = create_temp_dir("get_next_chain_block_has_no_agency").join("node.socket");
        let known_point = get_fake_specific_point();
        let server = setup_server(
            socket_path.clone(),
            ServerAction::RollForward,
            HasAgency::No,
        )
        .await;
        let client = tokio::spawn(async move {
            let mut chain_reader = PallasChainReader::new(
                socket_path.as_path(),
                CardanoNetwork::TestNet(10),
                TestLogger::stdout(),
            );

            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            // Request the next block directly: the server replies "await", which leaves the
            // client without agency.
            let client = chain_reader.get_client().await.unwrap();
            client.chainsync().request_next().await.unwrap();

            // Sanity check: without agency, a raw intersection request is rejected.
            client
                .chainsync()
                .find_intersect(vec![known_point.clone()])
                .await
                .expect_err("chainsync find_intersect without agency should fail");

            // The reader must detect the missing agency and skip the intersection request...
            chain_reader
                .set_chain_point(&RawCardanoPoint::from(known_point.clone()))
                .await
                .unwrap();

            // ...then wait for the pending reply instead of sending a new request.
            chain_reader.get_next_chain_block().await.unwrap().unwrap()
        });

        let (_, client_res) = tokio::join!(server, client);
        let chain_block = client_res.expect("Client failed to get next chain block");
        match chain_block {
            ChainBlockNextAction::RollForward { parsed_block } => {
                assert_eq!(parsed_block, get_fake_scanned_block());
            }
            _ => panic!("Unexpected chain block action"),
        }
    }
}