mithril_common/cardano_block_scanner/chain_reader_block_streamer.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use slog::{debug, trace, Logger};
5use tokio::sync::Mutex;
6
7use crate::cardano_block_scanner::{BlockStreamer, ChainScannedBlocks, RawCardanoPoint};
8use crate::chain_reader::{ChainBlockNextAction, ChainBlockReader};
9use crate::entities::BlockNumber;
10use crate::logging::LoggerExtensions;
11use crate::StdResult;
12
13/// The action that indicates what to do next with the streamer
14#[derive(Debug, Clone, PartialEq)]
15enum BlockStreamerNextAction {
16    /// Use a [ChainBlockNextAction]
17    ChainBlockNextAction(ChainBlockNextAction),
18    /// Skip to the next action
19    SkipToNextAction,
20}
21
22/// [Block streamer][BlockStreamer] that streams blocks with a [Chain block reader][ChainBlockReader]
23pub struct ChainReaderBlockStreamer {
24    chain_reader: Arc<Mutex<dyn ChainBlockReader>>,
25    from: RawCardanoPoint,
26    until: BlockNumber,
27    max_roll_forwards_per_poll: usize,
28    last_polled_point: Option<RawCardanoPoint>,
29    logger: Logger,
30}
31
32#[async_trait]
33impl BlockStreamer for ChainReaderBlockStreamer {
34    async fn poll_next(&mut self) -> StdResult<Option<ChainScannedBlocks>> {
35        debug!(self.logger, ">> poll_next");
36
37        let chain_scanned_blocks: ChainScannedBlocks;
38        let mut roll_forwards = vec![];
39        loop {
40            let block_streamer_next_action = self.get_next_chain_block_action().await?;
41            match block_streamer_next_action {
42                Some(BlockStreamerNextAction::ChainBlockNextAction(
43                    ChainBlockNextAction::RollForward { parsed_block },
44                )) => {
45                    self.last_polled_point = Some(RawCardanoPoint::from(&parsed_block));
46                    let parsed_block_number = parsed_block.block_number;
47                    roll_forwards.push(parsed_block);
48                    if roll_forwards.len() >= self.max_roll_forwards_per_poll
49                        || parsed_block_number >= self.until
50                    {
51                        return Ok(Some(ChainScannedBlocks::RollForwards(roll_forwards)));
52                    }
53                }
54                Some(BlockStreamerNextAction::ChainBlockNextAction(
55                    ChainBlockNextAction::RollBackward { rollback_point },
56                )) => {
57                    self.last_polled_point = Some(rollback_point.clone());
58                    let rollback_slot_number = rollback_point.slot_number;
59                    let index_rollback = roll_forwards
60                        .iter()
61                        .position(|block| block.slot_number == rollback_slot_number);
62                    match index_rollback {
63                        Some(index_rollback) => {
64                            debug!(
65                                self.logger,
66                                "ChainScannedBlocks handled a buffer RollBackward({rollback_slot_number:?})"
67                            );
68                            roll_forwards.truncate(index_rollback + 1);
69                        }
70                        None => {
71                            debug!(
72                                self.logger,
73                                "ChainScannedBlocks triggered a full RollBackward({rollback_slot_number:?})"
74                            );
75                            chain_scanned_blocks =
76                                ChainScannedBlocks::RollBackward(rollback_slot_number);
77                            return Ok(Some(chain_scanned_blocks));
78                        }
79                    }
80                }
81                Some(BlockStreamerNextAction::SkipToNextAction) => {
82                    continue;
83                }
84                None => {
85                    return if roll_forwards.is_empty() {
86                        Ok(None)
87                    } else {
88                        chain_scanned_blocks = ChainScannedBlocks::RollForwards(roll_forwards);
89                        Ok(Some(chain_scanned_blocks))
90                    }
91                }
92            }
93        }
94    }
95
96    fn last_polled_point(&self) -> Option<RawCardanoPoint> {
97        self.last_polled_point.clone()
98    }
99}
100
101impl ChainReaderBlockStreamer {
102    /// Factory
103    pub async fn try_new(
104        chain_reader: Arc<Mutex<dyn ChainBlockReader>>,
105        from: Option<RawCardanoPoint>,
106        until: BlockNumber,
107        max_roll_forwards_per_poll: usize,
108        logger: Logger,
109    ) -> StdResult<Self> {
110        let from = from.unwrap_or(RawCardanoPoint::origin());
111        {
112            let mut chain_reader_inner = chain_reader.try_lock()?;
113            chain_reader_inner.set_chain_point(&from).await?;
114        }
115        Ok(Self {
116            chain_reader,
117            from,
118            until,
119            max_roll_forwards_per_poll,
120            last_polled_point: None,
121            logger: logger.new_with_component_name::<Self>(),
122        })
123    }
124
125    async fn get_next_chain_block_action(&self) -> StdResult<Option<BlockStreamerNextAction>> {
126        let mut chain_reader = self.chain_reader.try_lock()?;
127        match chain_reader.get_next_chain_block().await? {
128            Some(ChainBlockNextAction::RollForward { parsed_block }) => {
129                if parsed_block.block_number > self.until {
130                    trace!(
131                        self.logger,
132                        "Received a RollForward above threshold block number ({})",
133                        parsed_block.block_number
134                    );
135                    Ok(None)
136                } else {
137                    trace!(
138                        self.logger,
139                        "Received a RollForward below threshold block number ({})",
140                        parsed_block.block_number
141                    );
142                    Ok(Some(BlockStreamerNextAction::ChainBlockNextAction(
143                        ChainBlockNextAction::RollForward { parsed_block },
144                    )))
145                }
146            }
147            Some(ChainBlockNextAction::RollBackward { rollback_point }) => {
148                let rollback_slot_number = rollback_point.slot_number;
149                trace!(
150                    self.logger,
151                    "Received a RollBackward({rollback_slot_number:?})"
152                );
153                let block_streamer_next_action = if rollback_slot_number == self.from.slot_number {
154                    BlockStreamerNextAction::SkipToNextAction
155                } else {
156                    BlockStreamerNextAction::ChainBlockNextAction(
157                        ChainBlockNextAction::RollBackward { rollback_point },
158                    )
159                };
160                Ok(Some(block_streamer_next_action))
161            }
162            None => {
163                trace!(self.logger, "Received nothing");
164                Ok(None)
165            }
166        }
167    }
168}
169
#[cfg(test)]
mod tests {
    use crate::cardano_block_scanner::ScannedBlock;
    use crate::chain_reader::FakeChainReader;
    use crate::entities::SlotNumber;
    use crate::test_utils::TestLogger;

    use super::*;

    /// The maximum number of roll forwards during a poll
    const MAX_ROLL_FORWARDS_PER_POLL: usize = 100;

    /// Builds a [ScannedBlock] without transactions.
    fn scanned_block(hash: &'static str, number: BlockNumber, slot: SlotNumber) -> ScannedBlock {
        ScannedBlock::new(hash, number, slot, Vec::<&str>::new())
    }

    /// Builds a roll forward to a block without transactions.
    fn roll_forward(
        hash: &'static str,
        number: BlockNumber,
        slot: SlotNumber,
    ) -> ChainBlockNextAction {
        ChainBlockNextAction::RollForward {
            parsed_block: scanned_block(hash, number, slot),
        }
    }

    /// Builds a roll backward to the given point.
    fn roll_backward(slot: SlotNumber, hash: &'static str) -> ChainBlockNextAction {
        ChainBlockNextAction::RollBackward {
            rollback_point: RawCardanoPoint::new(slot, hash),
        }
    }

    /// Wraps the given actions in a shareable [FakeChainReader].
    fn fake_chain_reader(actions: Vec<ChainBlockNextAction>) -> Arc<Mutex<FakeChainReader>> {
        Arc::new(Mutex::new(FakeChainReader::new(actions)))
    }

    /// Builds a streamer with the default poll size and a stdout logger.
    async fn build_streamer(
        chain_reader: Arc<Mutex<dyn ChainBlockReader>>,
        from: Option<RawCardanoPoint>,
        until: BlockNumber,
    ) -> ChainReaderBlockStreamer {
        ChainReaderBlockStreamer::try_new(
            chain_reader,
            from,
            until,
            MAX_ROLL_FORWARDS_PER_POLL,
            TestLogger::stdout(),
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_parse_expected_nothing_strictly_above_block_number_threshold() {
        let until_block_number = BlockNumber(10);
        let chain_reader = fake_chain_reader(vec![
            roll_forward("hash-1", until_block_number, SlotNumber(100)),
            roll_forward("hash-2", until_block_number, SlotNumber(100)),
        ]);

        let mut block_streamer =
            build_streamer(chain_reader.clone(), None, until_block_number - 1).await;
        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(None, scanned_blocks);
        assert_eq!(None, block_streamer.last_polled_point());

        let mut block_streamer = build_streamer(chain_reader, None, until_block_number).await;
        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![scanned_block(
                "hash-2",
                until_block_number,
                SlotNumber(100),
            )])),
            scanned_blocks
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(100), "hash-2"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_multiple_rollforwards_up_to_block_number_threshold() {
        let chain_reader = fake_chain_reader(vec![
            roll_forward("hash-1", BlockNumber(1), SlotNumber(10)),
            roll_forward("hash-2", BlockNumber(2), SlotNumber(20)),
            roll_forward("hash-3", BlockNumber(3), SlotNumber(30)),
        ]);
        let mut block_streamer = build_streamer(chain_reader.clone(), None, BlockNumber(2)).await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![
                scanned_block("hash-1", BlockNumber(1), SlotNumber(10)),
                scanned_block("hash-2", BlockNumber(2), SlotNumber(20)),
            ])),
            scanned_blocks,
        );

        let chain_reader_total_remaining_next_actions =
            chain_reader.lock().await.get_total_remaining_next_actions();
        assert_eq!(1, chain_reader_total_remaining_next_actions);

        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(20), "hash-2"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_all_rollforwards_below_threshold_when_above_highest_block_number()
    {
        let chain_reader = fake_chain_reader(vec![
            roll_forward("hash-1", BlockNumber(1), SlotNumber(10)),
            roll_forward("hash-2", BlockNumber(2), SlotNumber(20)),
        ]);
        let mut block_streamer = build_streamer(chain_reader, None, BlockNumber(100)).await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![
                scanned_block("hash-1", BlockNumber(1), SlotNumber(10)),
                scanned_block("hash-2", BlockNumber(2), SlotNumber(20)),
            ])),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(20), "hash-2"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_maximum_rollforwards_retrieved_per_poll() {
        let chain_reader = fake_chain_reader(vec![
            roll_forward("hash-1", BlockNumber(1), SlotNumber(10)),
            roll_forward("hash-2", BlockNumber(2), SlotNumber(20)),
            roll_forward("hash-3", BlockNumber(3), SlotNumber(30)),
        ]);
        let mut block_streamer = build_streamer(chain_reader, None, BlockNumber(100)).await;
        block_streamer.max_roll_forwards_per_poll = 2;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![
                scanned_block("hash-1", BlockNumber(1), SlotNumber(10)),
                scanned_block("hash-2", BlockNumber(2), SlotNumber(20)),
            ])),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(20), "hash-2"))
        );

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![scanned_block(
                "hash-3",
                BlockNumber(3),
                SlotNumber(30),
            )])),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(30), "hash-3"))
        );

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(None, scanned_blocks);
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(30), "hash-3"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_nothing_when_rollbackward_on_same_point() {
        let chain_reader = fake_chain_reader(vec![roll_backward(SlotNumber(100), "hash-123")]);
        let mut block_streamer = build_streamer(
            chain_reader,
            Some(RawCardanoPoint::new(SlotNumber(100), "hash-123")),
            BlockNumber(1),
        )
        .await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(None, scanned_blocks);
        assert_eq!(block_streamer.last_polled_point(), None);
    }

    #[tokio::test]
    async fn test_parse_expected_rollbackward_when_on_different_point_and_no_previous_rollforward()
    {
        let chain_reader = fake_chain_reader(vec![roll_backward(SlotNumber(100), "hash-10")]);
        let mut block_streamer = build_streamer(chain_reader, None, BlockNumber(1)).await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollBackward(SlotNumber(100))),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(100), "hash-10"))
        );

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(None, scanned_blocks);
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(100), "hash-10"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_rollforward_when_rollbackward_on_different_point_and_have_previous_rollforwards(
    ) {
        let chain_reader = fake_chain_reader(vec![
            roll_forward("hash-8", BlockNumber(80), SlotNumber(8)),
            roll_forward("hash-9", BlockNumber(90), SlotNumber(9)),
            roll_forward("hash-10", BlockNumber(100), SlotNumber(10)),
            roll_backward(SlotNumber(9), "hash-9"),
        ]);
        let mut block_streamer = build_streamer(chain_reader, None, BlockNumber(1000)).await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![
                scanned_block("hash-8", BlockNumber(80), SlotNumber(8)),
                scanned_block("hash-9", BlockNumber(90), SlotNumber(9)),
            ])),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(9), "hash-9"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_backward_when_rollbackward_on_different_point_and_does_not_have_previous_rollforwards(
    ) {
        let chain_reader = fake_chain_reader(vec![
            roll_forward("hash-8", BlockNumber(80), SlotNumber(8)),
            roll_forward("hash-9", BlockNumber(90), SlotNumber(9)),
            roll_backward(SlotNumber(3), "hash-3"),
        ]);
        let mut block_streamer = build_streamer(chain_reader, None, BlockNumber(1000)).await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollBackward(SlotNumber(3))),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(3), "hash-3"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_nothing() {
        let chain_reader = fake_chain_reader(vec![]);
        let mut block_streamer = build_streamer(chain_reader, None, BlockNumber(1)).await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(scanned_blocks, None);
    }

    #[tokio::test]
    async fn test_last_polled_point_is_none_if_nothing_was_polled() {
        let chain_reader = fake_chain_reader(vec![]);
        let block_streamer = build_streamer(chain_reader, None, BlockNumber(1)).await;

        assert_eq!(block_streamer.last_polled_point(), None);
    }
}