1use std::sync::Arc;
2
3use async_trait::async_trait;
4use slog::{Logger, debug, trace};
5use tokio::sync::Mutex;
6
7use mithril_common::StdResult;
8use mithril_common::entities::BlockNumber;
9use mithril_common::logging::LoggerExtensions;
10
11use crate::chain_reader::ChainBlockReader;
12use crate::chain_scanner::{BlockStreamer, ChainScannedBlocks};
13use crate::entities::{ChainBlockNextAction, RawCardanoPoint};
14
15#[derive(Debug, Clone, PartialEq)]
17enum BlockStreamerNextAction {
18 ChainBlockNextAction(ChainBlockNextAction),
20 SkipToNextAction,
22}
23
24pub struct ChainReaderBlockStreamer {
26 chain_reader: Arc<Mutex<dyn ChainBlockReader>>,
27 from: RawCardanoPoint,
28 until: BlockNumber,
29 max_roll_forwards_per_poll: usize,
30 last_polled_point: Option<RawCardanoPoint>,
31 logger: Logger,
32}
33
34#[async_trait]
35impl BlockStreamer for ChainReaderBlockStreamer {
36 async fn poll_next(&mut self) -> StdResult<Option<ChainScannedBlocks>> {
37 debug!(self.logger, ">> poll_next"; "last_polled_point" => ?self.last_polled_point);
38
39 let chain_scanned_blocks: ChainScannedBlocks;
40 let mut roll_forwards = vec![];
41 loop {
42 let block_streamer_next_action = self.get_next_chain_block_action().await?;
43 match block_streamer_next_action {
44 Some(BlockStreamerNextAction::ChainBlockNextAction(
45 ChainBlockNextAction::RollForward { parsed_block },
46 )) => {
47 self.last_polled_point = Some(RawCardanoPoint::from(&parsed_block));
48 let parsed_block_number = parsed_block.block_number;
49 roll_forwards.push(parsed_block);
50 if roll_forwards.len() >= self.max_roll_forwards_per_poll
51 || parsed_block_number >= self.until
52 {
53 return Ok(Some(ChainScannedBlocks::RollForwards(roll_forwards)));
54 }
55 }
56 Some(BlockStreamerNextAction::ChainBlockNextAction(
57 ChainBlockNextAction::RollBackward { rollback_point },
58 )) => {
59 self.last_polled_point = Some(rollback_point.clone());
60 let rollback_slot_number = rollback_point.slot_number;
61 let index_rollback = roll_forwards
62 .iter()
63 .position(|block| block.slot_number == rollback_slot_number);
64 match index_rollback {
65 Some(index_rollback) => {
66 debug!(
67 self.logger,
68 "ChainScannedBlocks handled a buffer RollBackward({rollback_slot_number:?})"
69 );
70 roll_forwards.truncate(index_rollback + 1);
71 }
72 None => {
73 debug!(
74 self.logger,
75 "ChainScannedBlocks triggered a full RollBackward({rollback_slot_number:?})"
76 );
77 chain_scanned_blocks =
78 ChainScannedBlocks::RollBackward(rollback_slot_number);
79 return Ok(Some(chain_scanned_blocks));
80 }
81 }
82 }
83 Some(BlockStreamerNextAction::SkipToNextAction) => {
84 continue;
85 }
86 None => {
87 return if roll_forwards.is_empty() {
88 Ok(None)
89 } else {
90 chain_scanned_blocks = ChainScannedBlocks::RollForwards(roll_forwards);
91 Ok(Some(chain_scanned_blocks))
92 };
93 }
94 }
95 }
96 }
97
98 fn last_polled_point(&self) -> Option<RawCardanoPoint> {
99 self.last_polled_point.clone()
100 }
101}
102
103impl ChainReaderBlockStreamer {
104 pub async fn try_new(
106 chain_reader: Arc<Mutex<dyn ChainBlockReader>>,
107 from: Option<RawCardanoPoint>,
108 until: BlockNumber,
109 max_roll_forwards_per_poll: usize,
110 logger: Logger,
111 ) -> StdResult<Self> {
112 let from = from.unwrap_or(RawCardanoPoint::origin());
113 {
114 let mut chain_reader_inner = chain_reader.try_lock()?;
115 chain_reader_inner.set_chain_point(&from).await?;
116 }
117 debug!(
118 logger, "starting ChainReaderBlockStreamer";
119 "from" => ?from, "until" => ?until, "max_roll_forwards_per_poll" => max_roll_forwards_per_poll
120 );
121
122 Ok(Self {
123 chain_reader,
124 from,
125 until,
126 max_roll_forwards_per_poll,
127 last_polled_point: None,
128 logger: logger.new_with_component_name::<Self>(),
129 })
130 }
131
132 async fn get_next_chain_block_action(&self) -> StdResult<Option<BlockStreamerNextAction>> {
133 let mut chain_reader = self.chain_reader.try_lock()?;
134 match chain_reader.get_next_chain_block().await? {
135 Some(ChainBlockNextAction::RollForward { parsed_block }) => {
136 if parsed_block.block_number > self.until {
137 trace!(
138 self.logger,
139 "Received a RollForward above threshold block number ({})",
140 parsed_block.block_number
141 );
142 Ok(None)
143 } else {
144 trace!(
145 self.logger,
146 "Received a RollForward below threshold block number ({})",
147 parsed_block.block_number
148 );
149 Ok(Some(BlockStreamerNextAction::ChainBlockNextAction(
150 ChainBlockNextAction::RollForward { parsed_block },
151 )))
152 }
153 }
154 Some(ChainBlockNextAction::RollBackward { rollback_point }) => {
155 let rollback_slot_number = rollback_point.slot_number;
156 trace!(
157 self.logger,
158 "Received a RollBackward({rollback_slot_number:?})"
159 );
160 let block_streamer_next_action = if rollback_slot_number == self.from.slot_number {
161 BlockStreamerNextAction::SkipToNextAction
162 } else {
163 BlockStreamerNextAction::ChainBlockNextAction(
164 ChainBlockNextAction::RollBackward { rollback_point },
165 )
166 };
167 Ok(Some(block_streamer_next_action))
168 }
169 None => {
170 trace!(self.logger, "Received nothing");
171 Ok(None)
172 }
173 }
174 }
175}
176
#[cfg(test)]
mod tests {
    use mithril_common::entities::SlotNumber;

    use crate::entities::ScannedBlock;
    use crate::test::TestLogger;
    use crate::test::double::FakeChainReader;

    use super::*;

    const MAX_ROLL_FORWARDS_PER_POLL: usize = 100;

    /// Scanned block without any transaction.
    fn block(hash: &str, block_number: BlockNumber, slot_number: SlotNumber) -> ScannedBlock {
        ScannedBlock::new(hash, block_number, slot_number, Vec::<&str>::new())
    }

    /// Roll forward action to a block without any transaction.
    fn roll_forward(
        hash: &str,
        block_number: BlockNumber,
        slot_number: SlotNumber,
    ) -> ChainBlockNextAction {
        ChainBlockNextAction::RollForward {
            parsed_block: block(hash, block_number, slot_number),
        }
    }

    /// Roll backward action to the given point.
    fn roll_backward(slot_number: SlotNumber, hash: &str) -> ChainBlockNextAction {
        ChainBlockNextAction::RollBackward {
            rollback_point: RawCardanoPoint::new(slot_number, hash),
        }
    }

    /// Fake chain reader replaying the given actions.
    fn fake_chain_reader(actions: Vec<ChainBlockNextAction>) -> Arc<Mutex<FakeChainReader>> {
        Arc::new(Mutex::new(FakeChainReader::new(actions)))
    }

    /// Streamer over `chain_reader` using the default maximum of roll forwards per poll.
    async fn streamer(
        chain_reader: Arc<Mutex<FakeChainReader>>,
        from: Option<RawCardanoPoint>,
        until: BlockNumber,
    ) -> ChainReaderBlockStreamer {
        ChainReaderBlockStreamer::try_new(
            chain_reader,
            from,
            until,
            MAX_ROLL_FORWARDS_PER_POLL,
            TestLogger::stdout(),
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_parse_expected_nothing_strictly_above_block_number_threshold() {
        let until_block_number = BlockNumber(10);
        let chain_reader = fake_chain_reader(vec![
            roll_forward("hash-1", until_block_number, SlotNumber(100)),
            roll_forward("hash-2", until_block_number, SlotNumber(100)),
        ]);

        let mut block_streamer =
            streamer(chain_reader.clone(), None, until_block_number - 1).await;
        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(None, scanned_blocks);
        assert_eq!(None, block_streamer.last_polled_point());

        let mut block_streamer = streamer(chain_reader, None, until_block_number).await;
        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![block(
                "hash-2",
                until_block_number,
                SlotNumber(100)
            )])),
            scanned_blocks
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(100), "hash-2"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_multiple_rollforwards_up_to_block_number_threshold() {
        let chain_reader = fake_chain_reader(vec![
            roll_forward("hash-1", BlockNumber(1), SlotNumber(10)),
            roll_forward("hash-2", BlockNumber(2), SlotNumber(20)),
            roll_forward("hash-3", BlockNumber(3), SlotNumber(30)),
        ]);
        let mut block_streamer = streamer(chain_reader.clone(), None, BlockNumber(2)).await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![
                block("hash-1", BlockNumber(1), SlotNumber(10)),
                block("hash-2", BlockNumber(2), SlotNumber(20)),
            ])),
            scanned_blocks,
        );

        let remaining_next_actions = chain_reader.lock().await.get_total_remaining_next_actions();
        assert_eq!(1, remaining_next_actions);

        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(20), "hash-2"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_all_rollforwards_below_threshold_when_above_highest_block_number()
    {
        let chain_reader = fake_chain_reader(vec![
            roll_forward("hash-1", BlockNumber(1), SlotNumber(10)),
            roll_forward("hash-2", BlockNumber(2), SlotNumber(20)),
        ]);
        let mut block_streamer = streamer(chain_reader.clone(), None, BlockNumber(100)).await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![
                block("hash-1", BlockNumber(1), SlotNumber(10)),
                block("hash-2", BlockNumber(2), SlotNumber(20)),
            ])),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(20), "hash-2"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_maximum_rollforwards_retrieved_per_poll() {
        let chain_reader = fake_chain_reader(vec![
            roll_forward("hash-1", BlockNumber(1), SlotNumber(10)),
            roll_forward("hash-2", BlockNumber(2), SlotNumber(20)),
            roll_forward("hash-3", BlockNumber(3), SlotNumber(30)),
        ]);
        let mut block_streamer = streamer(chain_reader, None, BlockNumber(100)).await;
        block_streamer.max_roll_forwards_per_poll = 2;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![
                block("hash-1", BlockNumber(1), SlotNumber(10)),
                block("hash-2", BlockNumber(2), SlotNumber(20)),
            ])),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(20), "hash-2"))
        );

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![block(
                "hash-3",
                BlockNumber(3),
                SlotNumber(30)
            )])),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(30), "hash-3"))
        );

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(None, scanned_blocks);
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(30), "hash-3"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_nothing_when_rollbackward_on_same_point() {
        let chain_reader = fake_chain_reader(vec![roll_backward(SlotNumber(100), "hash-123")]);
        let mut block_streamer = streamer(
            chain_reader,
            Some(RawCardanoPoint::new(SlotNumber(100), "hash-123")),
            BlockNumber(1),
        )
        .await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(None, scanned_blocks);
        assert_eq!(block_streamer.last_polled_point(), None);
    }

    #[tokio::test]
    async fn test_parse_expected_rollbackward_when_on_different_point_and_no_previous_rollforward()
    {
        let chain_reader = fake_chain_reader(vec![roll_backward(SlotNumber(100), "hash-10")]);
        let mut block_streamer = streamer(chain_reader, None, BlockNumber(1)).await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollBackward(SlotNumber(100))),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(100), "hash-10"))
        );

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(None, scanned_blocks);
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(100), "hash-10"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_rollforward_when_rollbackward_on_different_point_and_have_previous_rollforwards()
    {
        let chain_reader = fake_chain_reader(vec![
            roll_forward("hash-8", BlockNumber(80), SlotNumber(8)),
            roll_forward("hash-9", BlockNumber(90), SlotNumber(9)),
            roll_forward("hash-10", BlockNumber(100), SlotNumber(10)),
            roll_backward(SlotNumber(9), "hash-9"),
        ]);
        let mut block_streamer = streamer(chain_reader, None, BlockNumber(1000)).await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![
                block("hash-8", BlockNumber(80), SlotNumber(8)),
                block("hash-9", BlockNumber(90), SlotNumber(9)),
            ])),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(9), "hash-9"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_backward_when_rollbackward_on_different_point_and_does_not_have_previous_rollforwards()
    {
        let chain_reader = fake_chain_reader(vec![
            roll_forward("hash-8", BlockNumber(80), SlotNumber(8)),
            roll_forward("hash-9", BlockNumber(90), SlotNumber(9)),
            roll_backward(SlotNumber(3), "hash-3"),
        ]);
        let mut block_streamer = streamer(chain_reader, None, BlockNumber(1000)).await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollBackward(SlotNumber(3))),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(3), "hash-3"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_nothing() {
        let chain_reader = fake_chain_reader(vec![]);
        let mut block_streamer = streamer(chain_reader, None, BlockNumber(1)).await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(scanned_blocks, None);
    }

    #[tokio::test]
    async fn test_last_polled_point_is_none_if_nothing_was_polled() {
        let chain_reader = fake_chain_reader(vec![]);
        let block_streamer = streamer(chain_reader, None, BlockNumber(1)).await;

        assert_eq!(block_streamer.last_polled_point(), None);
    }
}