mithril_client_cli/utils/
multi_download_progress_reporter.rs

1use std::collections::HashMap;
2
3use indicatif::{MultiProgress, ProgressBar};
4use slog::Logger;
5use tokio::sync::RwLock;
6
7use super::{
8    DownloadProgressReporter, DownloadProgressReporterParams, ProgressBarKind, ProgressOutputType,
9};
10
/// A progress reporter that can handle multiple downloads at once.
///
/// It shows a global progress bar for all downloads and individual progress bars for each download.
pub struct MultiDownloadProgressReporter {
    /// Output type given at construction; reused for every child bar reporter.
    output_type: ProgressOutputType,
    /// Container to which the main progress bar and all child progress bars are added.
    parent_container: MultiProgress,
    /// Reporter of the main ("Files") progress bar.
    main_reporter: DownloadProgressReporter,
    /// Reporters of the child bars, keyed by the name they were added with.
    dl_reporters: RwLock<HashMap<String, DownloadProgressReporter>>,
    /// Logger cloned into each reporter created by this struct.
    logger: Logger,
}
21
22impl MultiDownloadProgressReporter {
23    /// Initialize a new `MultiDownloadProgressReporter`.
24    pub fn new(total_files: u64, output_type: ProgressOutputType, logger: Logger) -> Self {
25        let parent_container = MultiProgress::new();
26        parent_container.set_draw_target(output_type.into());
27        let main_pb = parent_container.add(ProgressBar::new(total_files));
28        let main_reporter = DownloadProgressReporter::new(
29            main_pb,
30            DownloadProgressReporterParams {
31                label: "Files".to_string(),
32                output_type,
33                progress_bar_kind: ProgressBarKind::Files,
34                include_label_in_tty: false,
35            },
36            logger.clone(),
37        );
38
39        Self {
40            output_type,
41            parent_container,
42            main_reporter,
43            dl_reporters: RwLock::new(HashMap::new()),
44            logger,
45        }
46    }
47
48    #[cfg(test)]
49    /// Get the position of the main progress bar.
50    pub fn position(&self) -> u64 {
51        self.main_reporter.inner_progress_bar().position()
52    }
53
54    #[cfg(test)]
55    /// Get the total number of downloads.
56    pub fn total_downloads(&self) -> u64 {
57        self.main_reporter
58            .inner_progress_bar()
59            .length()
60            .unwrap_or(0)
61    }
62
63    #[cfg(test)]
64    /// Get the number of active downloads.
65    pub async fn number_of_active_downloads(&self) -> usize {
66        self.dl_reporters.read().await.len()
67    }
68
69    /// Bump the main progress bar by one.
70    pub fn bump_main_bar_progress(&self) {
71        self.main_reporter.inc(1);
72    }
73
74    /// Add a new child bar to the progress reporter.
75    pub async fn add_child_bar<T: Into<String>>(&self, name: T, kind: ProgressBarKind, total: u64) {
76        let name = name.into();
77        let dl_progress_bar = self.parent_container.add(ProgressBar::new(total));
78        let dl_reporter = DownloadProgressReporter::new(
79            dl_progress_bar,
80            DownloadProgressReporterParams {
81                label: name.to_owned(),
82                output_type: self.output_type,
83                progress_bar_kind: kind,
84                include_label_in_tty: true,
85            },
86            self.logger.clone(),
87        );
88
89        let mut reporters = self.dl_reporters.write().await;
90        reporters.insert(name, dl_reporter);
91    }
92
93    /// Report progress of a child bar, updating the progress bar to the given actual_position.
94    pub async fn progress_child_bar<T: AsRef<str>>(&self, name: T, actual_position: u64) {
95        if let Some(child_reporter) = self.get_child_bar(name.as_ref()).await {
96            child_reporter.report(actual_position);
97        }
98    }
99
100    /// Finish a child bar, removing it from the progress reporter an bumping the main progress bar.
101    pub async fn finish_child_bar<T: Into<String>>(&self, name: T) {
102        let name = name.into();
103        if let Some(child_reporter) = self.get_child_bar(&name).await {
104            child_reporter.finish_and_clear();
105            self.parent_container
106                .remove(child_reporter.inner_progress_bar());
107
108            let mut reporters = self.dl_reporters.write().await;
109            reporters.remove(&name);
110
111            self.bump_main_bar_progress();
112        }
113    }
114
115    /// Finish all child bars and the main progress bar and prints a message.
116    pub async fn finish_all(&self, message: &str) {
117        let mut reporters = self.dl_reporters.write().await;
118        for (_name, reporter) in reporters.iter() {
119            reporter.finish_and_clear();
120            self.parent_container.remove(reporter.inner_progress_bar());
121        }
122        reporters.clear();
123
124        self.main_reporter.finish(message);
125    }
126
127    async fn get_child_bar(&self, name: &str) -> Option<DownloadProgressReporter> {
128        let cdl_reporters = self.dl_reporters.read().await;
129        cdl_reporters.get(name).cloned()
130    }
131}
132
#[cfg(test)]
mod tests {
    use slog::o;

    use super::*;

    /// Build a reporter with the given number of files, a hidden output and a logger
    /// that discards everything.
    fn hidden_reporter(total_files: u64) -> MultiDownloadProgressReporter {
        MultiDownloadProgressReporter::new(
            total_files,
            ProgressOutputType::Hidden,
            slog::Logger::root(slog::Discard, o!()),
        )
    }

    #[test]
    fn main_progress_bar_is_of_file_kind() {
        let reporter = hidden_reporter(1);

        assert_eq!(reporter.main_reporter.kind(), ProgressBarKind::Files);
    }

    #[tokio::test]
    async fn adding_new_child_bar() {
        let reporter = hidden_reporter(1);

        reporter
            .add_child_bar("name", ProgressBarKind::Bytes, 1000)
            .await;

        let child = reporter
            .get_child_bar("name")
            .await
            .expect("child bar 'name' should have been registered");
        assert_eq!(child.kind(), ProgressBarKind::Bytes);
    }

    #[tokio::test]
    async fn finishing_child_bar() {
        let reporter = hidden_reporter(1);

        reporter
            .add_child_bar("name", ProgressBarKind::Bytes, 1000)
            .await;
        assert_eq!(reporter.position(), 0);

        reporter.finish_child_bar("name").await;

        assert_eq!(reporter.position(), 1);
        assert!(reporter.get_child_bar("name").await.is_none());
    }

    #[tokio::test]
    async fn finishing_child_bar_that_does_not_exist() {
        let reporter = hidden_reporter(1);
        assert!(reporter.get_child_bar("name").await.is_none());

        reporter.finish_child_bar("name").await;

        // Nothing was finished, so the main bar must not have moved.
        assert_eq!(reporter.position(), 0);
    }

    #[tokio::test]
    async fn finishing_all_remove_all_child_bars() {
        let total_files = 132;
        let reporter = hidden_reporter(total_files);

        reporter
            .add_child_bar("first", ProgressBarKind::Bytes, 10)
            .await;
        reporter
            .add_child_bar("second", ProgressBarKind::Bytes, 20)
            .await;
        assert_eq!(reporter.number_of_active_downloads().await, 2);

        reporter.finish_all("message").await;

        assert_eq!(reporter.number_of_active_downloads().await, 0);
        assert_eq!(reporter.position(), total_files);
        assert!(reporter.main_reporter.inner_progress_bar().is_finished());
    }

    #[tokio::test]
    async fn progress_child_bar_to_the_given_bytes() {
        let reporter = hidden_reporter(4);

        reporter
            .add_child_bar("updated", ProgressBarKind::Bytes, 10)
            .await;
        reporter
            .add_child_bar("other", ProgressBarKind::Bytes, 20)
            .await;

        let updated = reporter.get_child_bar("updated").await.unwrap();
        let other = reporter.get_child_bar("other").await.unwrap();
        assert_eq!(updated.inner_progress_bar().position(), 0);

        // Each report must move the bar to the exact position it was given.
        for expected_position in [5, 9] {
            reporter
                .progress_child_bar("updated", expected_position)
                .await;
            assert_eq!(updated.inner_progress_bar().position(), expected_position);
        }

        assert_eq!(
            other.inner_progress_bar().position(),
            0,
            "Other progress bar should not be affected by updating the 'updated' progress bar"
        );
    }

    #[tokio::test]
    async fn progress_child_bar_that_does_not_exist_do_nothing() {
        let reporter = hidden_reporter(2);

        reporter.progress_child_bar("not_exist", 5).await;

        assert!(reporter.get_child_bar("not_exist").await.is_none());
    }

    #[test]
    fn bump_main_bar_progress_increase_its_value_by_one() {
        let reporter = hidden_reporter(2);
        assert_eq!(0, reporter.position());

        reporter.bump_main_bar_progress();

        assert_eq!(1, reporter.position());
    }
}