1use std::sync::Arc;
2
3use async_trait::async_trait;
4use slog::{Logger, debug, trace};
5use tokio::sync::Mutex;
6
7use mithril_common::StdResult;
8use mithril_common::entities::BlockNumber;
9use mithril_common::logging::LoggerExtensions;
10
11use crate::chain_reader::ChainBlockReader;
12use crate::chain_scanner::{BlockStreamer, ChainScannedBlocks};
13use crate::entities::{ChainBlockNextAction, RawCardanoPoint};
14
15#[derive(Debug, Clone, PartialEq)]
17enum BlockStreamerNextAction {
18 ChainBlockNextAction(ChainBlockNextAction),
20 SkipToNextAction,
22}
23
24pub struct ChainReaderBlockStreamer {
26 chain_reader: Arc<Mutex<dyn ChainBlockReader>>,
27 from: RawCardanoPoint,
28 until: BlockNumber,
29 max_roll_forwards_per_poll: usize,
30 last_polled_point: Option<RawCardanoPoint>,
31 logger: Logger,
32}
33
34#[async_trait]
35impl BlockStreamer for ChainReaderBlockStreamer {
36 async fn poll_next(&mut self) -> StdResult<Option<ChainScannedBlocks>> {
37 debug!(self.logger, ">> poll_next");
38
39 let chain_scanned_blocks: ChainScannedBlocks;
40 let mut roll_forwards = vec![];
41 loop {
42 let block_streamer_next_action = self.get_next_chain_block_action().await?;
43 match block_streamer_next_action {
44 Some(BlockStreamerNextAction::ChainBlockNextAction(
45 ChainBlockNextAction::RollForward { parsed_block },
46 )) => {
47 self.last_polled_point = Some(RawCardanoPoint::from(&parsed_block));
48 let parsed_block_number = parsed_block.block_number;
49 roll_forwards.push(parsed_block);
50 if roll_forwards.len() >= self.max_roll_forwards_per_poll
51 || parsed_block_number >= self.until
52 {
53 return Ok(Some(ChainScannedBlocks::RollForwards(roll_forwards)));
54 }
55 }
56 Some(BlockStreamerNextAction::ChainBlockNextAction(
57 ChainBlockNextAction::RollBackward { rollback_point },
58 )) => {
59 self.last_polled_point = Some(rollback_point.clone());
60 let rollback_slot_number = rollback_point.slot_number;
61 let index_rollback = roll_forwards
62 .iter()
63 .position(|block| block.slot_number == rollback_slot_number);
64 match index_rollback {
65 Some(index_rollback) => {
66 debug!(
67 self.logger,
68 "ChainScannedBlocks handled a buffer RollBackward({rollback_slot_number:?})"
69 );
70 roll_forwards.truncate(index_rollback + 1);
71 }
72 None => {
73 debug!(
74 self.logger,
75 "ChainScannedBlocks triggered a full RollBackward({rollback_slot_number:?})"
76 );
77 chain_scanned_blocks =
78 ChainScannedBlocks::RollBackward(rollback_slot_number);
79 return Ok(Some(chain_scanned_blocks));
80 }
81 }
82 }
83 Some(BlockStreamerNextAction::SkipToNextAction) => {
84 continue;
85 }
86 None => {
87 return if roll_forwards.is_empty() {
88 Ok(None)
89 } else {
90 chain_scanned_blocks = ChainScannedBlocks::RollForwards(roll_forwards);
91 Ok(Some(chain_scanned_blocks))
92 };
93 }
94 }
95 }
96 }
97
98 fn last_polled_point(&self) -> Option<RawCardanoPoint> {
99 self.last_polled_point.clone()
100 }
101}
102
103impl ChainReaderBlockStreamer {
104 pub async fn try_new(
106 chain_reader: Arc<Mutex<dyn ChainBlockReader>>,
107 from: Option<RawCardanoPoint>,
108 until: BlockNumber,
109 max_roll_forwards_per_poll: usize,
110 logger: Logger,
111 ) -> StdResult<Self> {
112 let from = from.unwrap_or(RawCardanoPoint::origin());
113 {
114 let mut chain_reader_inner = chain_reader.try_lock()?;
115 chain_reader_inner.set_chain_point(&from).await?;
116 }
117 Ok(Self {
118 chain_reader,
119 from,
120 until,
121 max_roll_forwards_per_poll,
122 last_polled_point: None,
123 logger: logger.new_with_component_name::<Self>(),
124 })
125 }
126
127 async fn get_next_chain_block_action(&self) -> StdResult<Option<BlockStreamerNextAction>> {
128 let mut chain_reader = self.chain_reader.try_lock()?;
129 match chain_reader.get_next_chain_block().await? {
130 Some(ChainBlockNextAction::RollForward { parsed_block }) => {
131 if parsed_block.block_number > self.until {
132 trace!(
133 self.logger,
134 "Received a RollForward above threshold block number ({})",
135 parsed_block.block_number
136 );
137 Ok(None)
138 } else {
139 trace!(
140 self.logger,
141 "Received a RollForward below threshold block number ({})",
142 parsed_block.block_number
143 );
144 Ok(Some(BlockStreamerNextAction::ChainBlockNextAction(
145 ChainBlockNextAction::RollForward { parsed_block },
146 )))
147 }
148 }
149 Some(ChainBlockNextAction::RollBackward { rollback_point }) => {
150 let rollback_slot_number = rollback_point.slot_number;
151 trace!(
152 self.logger,
153 "Received a RollBackward({rollback_slot_number:?})"
154 );
155 let block_streamer_next_action = if rollback_slot_number == self.from.slot_number {
156 BlockStreamerNextAction::SkipToNextAction
157 } else {
158 BlockStreamerNextAction::ChainBlockNextAction(
159 ChainBlockNextAction::RollBackward { rollback_point },
160 )
161 };
162 Ok(Some(block_streamer_next_action))
163 }
164 None => {
165 trace!(self.logger, "Received nothing");
166 Ok(None)
167 }
168 }
169 }
170}
171
#[cfg(test)]
mod tests {
    use mithril_common::entities::SlotNumber;

    use crate::entities::ScannedBlock;
    use crate::test::TestLogger;
    use crate::test::double::FakeChainReader;

    use super::*;

    const MAX_ROLL_FORWARDS_PER_POLL: usize = 100;

    /// Builds a scanned block that carries no transaction.
    fn scanned_block(hash: &str, block_number: BlockNumber, slot_number: SlotNumber) -> ScannedBlock {
        ScannedBlock::new(hash, block_number, slot_number, Vec::<&str>::new())
    }

    /// Builds a roll forward to a scanned block that carries no transaction.
    fn roll_forward(
        hash: &str,
        block_number: BlockNumber,
        slot_number: SlotNumber,
    ) -> ChainBlockNextAction {
        ChainBlockNextAction::RollForward {
            parsed_block: scanned_block(hash, block_number, slot_number),
        }
    }

    /// Builds a roll backward to the given point.
    fn roll_backward(slot_number: SlotNumber, hash: &str) -> ChainBlockNextAction {
        ChainBlockNextAction::RollBackward {
            rollback_point: RawCardanoPoint::new(slot_number, hash),
        }
    }

    /// Builds a shared fake chain reader serving `actions`.
    fn fake_reader(actions: Vec<ChainBlockNextAction>) -> Arc<Mutex<FakeChainReader>> {
        Arc::new(Mutex::new(FakeChainReader::new(actions)))
    }

    /// Builds a streamer over `chain_reader`, panicking if the construction fails.
    async fn build_streamer(
        chain_reader: Arc<Mutex<FakeChainReader>>,
        from: Option<RawCardanoPoint>,
        until: BlockNumber,
        max_roll_forwards_per_poll: usize,
    ) -> ChainReaderBlockStreamer {
        ChainReaderBlockStreamer::try_new(
            chain_reader,
            from,
            until,
            max_roll_forwards_per_poll,
            TestLogger::stdout(),
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_parse_expected_nothing_strictly_above_block_number_threshold() {
        let until_block_number = BlockNumber(10);
        let chain_reader = fake_reader(vec![
            roll_forward("hash-1", until_block_number, SlotNumber(100)),
            roll_forward("hash-2", until_block_number, SlotNumber(100)),
        ]);

        let mut block_streamer = build_streamer(
            chain_reader.clone(),
            None,
            until_block_number - 1,
            MAX_ROLL_FORWARDS_PER_POLL,
        )
        .await;
        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(None, scanned_blocks);
        assert_eq!(None, block_streamer.last_polled_point());

        let mut block_streamer = build_streamer(
            chain_reader,
            None,
            until_block_number,
            MAX_ROLL_FORWARDS_PER_POLL,
        )
        .await;
        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![scanned_block(
                "hash-2",
                until_block_number,
                SlotNumber(100),
            )])),
            scanned_blocks
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(100), "hash-2"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_multiple_rollforwards_up_to_block_number_threshold() {
        let chain_reader = fake_reader(vec![
            roll_forward("hash-1", BlockNumber(1), SlotNumber(10)),
            roll_forward("hash-2", BlockNumber(2), SlotNumber(20)),
            roll_forward("hash-3", BlockNumber(3), SlotNumber(30)),
        ]);
        let mut block_streamer = build_streamer(
            chain_reader.clone(),
            None,
            BlockNumber(2),
            MAX_ROLL_FORWARDS_PER_POLL,
        )
        .await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![
                scanned_block("hash-1", BlockNumber(1), SlotNumber(10)),
                scanned_block("hash-2", BlockNumber(2), SlotNumber(20)),
            ])),
            scanned_blocks,
        );

        let remaining_next_actions = chain_reader.lock().await.get_total_remaining_next_actions();
        assert_eq!(1, remaining_next_actions);

        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(20), "hash-2"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_all_rollforwards_below_threshold_when_above_highest_block_number()
    {
        let chain_reader = fake_reader(vec![
            roll_forward("hash-1", BlockNumber(1), SlotNumber(10)),
            roll_forward("hash-2", BlockNumber(2), SlotNumber(20)),
        ]);
        let mut block_streamer = build_streamer(
            chain_reader.clone(),
            None,
            BlockNumber(100),
            MAX_ROLL_FORWARDS_PER_POLL,
        )
        .await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![
                scanned_block("hash-1", BlockNumber(1), SlotNumber(10)),
                scanned_block("hash-2", BlockNumber(2), SlotNumber(20)),
            ])),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(20), "hash-2"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_maximum_rollforwards_retrieved_per_poll() {
        let chain_reader = fake_reader(vec![
            roll_forward("hash-1", BlockNumber(1), SlotNumber(10)),
            roll_forward("hash-2", BlockNumber(2), SlotNumber(20)),
            roll_forward("hash-3", BlockNumber(3), SlotNumber(30)),
        ]);
        // A batch size of two makes the three blocks span two polls.
        let mut block_streamer = build_streamer(chain_reader, None, BlockNumber(100), 2).await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![
                scanned_block("hash-1", BlockNumber(1), SlotNumber(10)),
                scanned_block("hash-2", BlockNumber(2), SlotNumber(20)),
            ])),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(20), "hash-2"))
        );

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![scanned_block(
                "hash-3",
                BlockNumber(3),
                SlotNumber(30),
            )])),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(30), "hash-3"))
        );

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(None, scanned_blocks);
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(30), "hash-3"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_nothing_when_rollbackward_on_same_point() {
        let chain_reader = fake_reader(vec![roll_backward(SlotNumber(100), "hash-123")]);
        let mut block_streamer = build_streamer(
            chain_reader,
            Some(RawCardanoPoint::new(SlotNumber(100), "hash-123")),
            BlockNumber(1),
            MAX_ROLL_FORWARDS_PER_POLL,
        )
        .await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(None, scanned_blocks);
        assert_eq!(block_streamer.last_polled_point(), None);
    }

    #[tokio::test]
    async fn test_parse_expected_rollbackward_when_on_different_point_and_no_previous_rollforward()
    {
        let chain_reader = fake_reader(vec![roll_backward(SlotNumber(100), "hash-10")]);
        let mut block_streamer = build_streamer(
            chain_reader,
            None,
            BlockNumber(1),
            MAX_ROLL_FORWARDS_PER_POLL,
        )
        .await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollBackward(SlotNumber(100))),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(100), "hash-10"))
        );

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(None, scanned_blocks);
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(100), "hash-10"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_rollforward_when_rollbackward_on_different_point_and_have_previous_rollforwards()
    {
        let chain_reader = fake_reader(vec![
            roll_forward("hash-8", BlockNumber(80), SlotNumber(8)),
            roll_forward("hash-9", BlockNumber(90), SlotNumber(9)),
            roll_forward("hash-10", BlockNumber(100), SlotNumber(10)),
            roll_backward(SlotNumber(9), "hash-9"),
        ]);
        let mut block_streamer = build_streamer(
            chain_reader,
            None,
            BlockNumber(1000),
            MAX_ROLL_FORWARDS_PER_POLL,
        )
        .await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollForwards(vec![
                scanned_block("hash-8", BlockNumber(80), SlotNumber(8)),
                scanned_block("hash-9", BlockNumber(90), SlotNumber(9)),
            ])),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(9), "hash-9"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_backward_when_rollbackward_on_different_point_and_does_not_have_previous_rollforwards()
    {
        let chain_reader = fake_reader(vec![
            roll_forward("hash-8", BlockNumber(80), SlotNumber(8)),
            roll_forward("hash-9", BlockNumber(90), SlotNumber(9)),
            roll_backward(SlotNumber(3), "hash-3"),
        ]);
        let mut block_streamer = build_streamer(
            chain_reader,
            None,
            BlockNumber(1000),
            MAX_ROLL_FORWARDS_PER_POLL,
        )
        .await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(
            Some(ChainScannedBlocks::RollBackward(SlotNumber(3))),
            scanned_blocks,
        );
        assert_eq!(
            block_streamer.last_polled_point(),
            Some(RawCardanoPoint::new(SlotNumber(3), "hash-3"))
        );
    }

    #[tokio::test]
    async fn test_parse_expected_nothing() {
        let chain_reader = fake_reader(vec![]);
        let mut block_streamer = build_streamer(
            chain_reader,
            None,
            BlockNumber(1),
            MAX_ROLL_FORWARDS_PER_POLL,
        )
        .await;

        let scanned_blocks = block_streamer.poll_next().await.expect("poll_next failed");
        assert_eq!(scanned_blocks, None);
    }

    #[tokio::test]
    async fn test_last_polled_point_is_none_if_nothing_was_polled() {
        let chain_reader = fake_reader(vec![]);
        let block_streamer = build_streamer(
            chain_reader,
            None,
            BlockNumber(1),
            MAX_ROLL_FORWARDS_PER_POLL,
        )
        .await;

        assert_eq!(block_streamer.last_polled_point(), None);
    }
}