1use chrono::Utc;
2use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
3use mithril_client::MithrilResult;
4use slog::{Logger, warn};
5use std::{
6 fmt::Write,
7 ops::Deref,
8 sync::{Arc, RwLock},
9 time::{Duration, Instant},
10};
11
/// Selects where and how progress information is rendered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProgressOutputType {
    /// Progress bars are not drawn; progress is emitted as JSON lines on stderr instead.
    JsonReporter,
    /// Progress bars are drawn on stdout and step messages are printed above them.
    Tty,
    /// Neither progress bars nor step messages are emitted.
    Hidden,
}
22
23impl From<ProgressOutputType> for ProgressDrawTarget {
24 fn from(value: ProgressOutputType) -> Self {
25 match value {
26 ProgressOutputType::JsonReporter => ProgressDrawTarget::hidden(),
27 ProgressOutputType::Tty => ProgressDrawTarget::stdout(),
28 ProgressOutputType::Hidden => ProgressDrawTarget::hidden(),
29 }
30 }
31}
32
/// Unit in which a download progress bar counts its progress.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProgressBarKind {
    /// Progress is counted in bytes (JSON amount keys are prefixed with `bytes_`).
    Bytes,
    /// Progress is counted in files (JSON amount keys are prefixed with `files_`).
    Files,
}
39
40pub struct ProgressPrinter {
42 multi_progress: MultiProgress,
43 output_type: ProgressOutputType,
44 number_of_steps: u16,
45}
46
47impl ProgressPrinter {
48 pub fn new(output_type: ProgressOutputType, number_of_steps: u16) -> Self {
50 Self {
51 multi_progress: MultiProgress::with_draw_target(output_type.into()),
52 output_type,
53 number_of_steps,
54 }
55 }
56
57 pub fn report_step(&self, step_number: u16, text: &str) -> MithrilResult<()> {
59 match self.output_type {
60 ProgressOutputType::JsonReporter => eprintln!(
61 r#"{{"timestamp": "{timestamp}", "step_num": {step_number}, "total_steps": {number_of_steps}, "message": "{text}"}}"#,
62 timestamp = Utc::now().to_rfc3339(),
63 number_of_steps = self.number_of_steps,
64 ),
65 ProgressOutputType::Tty => self
66 .multi_progress
67 .println(format!("{step_number}/{} - {text}", self.number_of_steps))?,
68 ProgressOutputType::Hidden => (),
69 };
70
71 Ok(())
72 }
73}
74
75impl Deref for ProgressPrinter {
76 type Target = MultiProgress;
77
78 fn deref(&self) -> &Self::Target {
79 &self.multi_progress
80 }
81}
82
83#[derive(Clone)]
85pub struct ProgressBarJsonFormatter {
86 label: String,
87 kind: ProgressBarKind,
88}
89
90impl ProgressBarJsonFormatter {
91 pub fn new<T: Into<String>>(label: T, kind: ProgressBarKind) -> Self {
93 Self {
94 label: label.into(),
95 kind,
96 }
97 }
98
99 pub fn format(&self, progress_bar: &ProgressBar) -> String {
101 ProgressBarJsonFormatter::format_values(
102 &self.label,
103 self.kind,
104 Utc::now().to_rfc3339(),
105 progress_bar.position(),
106 progress_bar.length().unwrap_or(0),
107 progress_bar.eta(),
108 progress_bar.elapsed(),
109 )
110 }
111
112 fn format_values(
113 label: &str,
114 kind: ProgressBarKind,
115 timestamp: String,
116 amount_downloaded: u64,
117 amount_total: u64,
118 duration_left: Duration,
119 duration_elapsed: Duration,
120 ) -> String {
121 let amount_prefix = match kind {
122 ProgressBarKind::Bytes => "bytes",
123 ProgressBarKind::Files => "files",
124 };
125
126 format!(
127 r#"{{"label": "{}", "timestamp": "{}", "{}_downloaded": {}, "{}_total": {}, "seconds_left": {}.{:0>3}, "seconds_elapsed": {}.{:0>3}}}"#,
128 label,
129 timestamp,
130 amount_prefix,
131 amount_downloaded,
132 amount_prefix,
133 amount_total,
134 duration_left.as_secs(),
135 duration_left.subsec_millis(),
136 duration_elapsed.as_secs(),
137 duration_elapsed.subsec_millis(),
138 )
139 }
140}
141
142#[derive(Clone)]
144pub struct DownloadProgressReporter {
145 progress_bar: ProgressBar,
146 output_type: ProgressOutputType,
147 json_reporter: ProgressBarJsonFormatter,
148 last_json_report_instant: Arc<RwLock<Option<Instant>>>,
149 logger: Logger,
150}
151
152#[derive(Clone, Debug, PartialEq, Eq)]
153pub struct DownloadProgressReporterParams {
154 pub label: String,
155 pub output_type: ProgressOutputType,
156 pub progress_bar_kind: ProgressBarKind,
157 pub include_label_in_tty: bool,
158}
159
160impl DownloadProgressReporterParams {
161 pub fn style(&self) -> ProgressStyle {
162 ProgressStyle::with_template(&self.style_template())
163 .unwrap()
164 .with_key(
165 "eta",
166 |state: &indicatif::ProgressState, w: &mut dyn Write| {
167 write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
168 },
169 )
170 .progress_chars("#>-")
171 }
172
173 fn style_template(&self) -> String {
174 let label = if self.include_label_in_tty {
175 &self.label
176 } else {
177 ""
178 };
179
180 match self.progress_bar_kind {
181 ProgressBarKind::Bytes => {
182 format!(
183 "{{spinner:.green}} {label} [{{elapsed_precise}}] [{{wide_bar:.cyan/blue}}] {{bytes}}/{{total_bytes}} ({{eta}})"
184 )
185 }
186 ProgressBarKind::Files => {
187 format!(
188 "{{spinner:.green}} {label} [{{elapsed_precise}}] [{{wide_bar:.cyan/blue}}] Files: {{human_pos}}/{{human_len}} ({{eta}})"
189 )
190 }
191 }
192 }
193}
194
195impl DownloadProgressReporter {
196 pub fn new(
198 progress_bar: ProgressBar,
199 params: DownloadProgressReporterParams,
200 logger: Logger,
201 ) -> Self {
202 progress_bar.set_style(params.style());
203
204 Self {
205 progress_bar,
206 output_type: params.output_type,
207 json_reporter: ProgressBarJsonFormatter::new(¶ms.label, params.progress_bar_kind),
208 last_json_report_instant: Arc::new(RwLock::new(None)),
209 logger,
210 }
211 }
212
213 #[cfg(test)]
214 pub fn kind(&self) -> ProgressBarKind {
216 self.json_reporter.kind
217 }
218
219 pub fn report(&self, actual_position: u64) {
221 self.progress_bar.set_position(actual_position);
222 self.report_json_progress();
223 }
224
225 pub fn inc(&self, delta: u64) {
227 self.progress_bar.inc(delta);
228 self.report_json_progress();
229 }
230
231 pub fn finish(&self, message: &str) {
233 self.progress_bar.finish_with_message(message.to_string());
234 }
235
236 pub fn finish_and_clear(&self) {
238 self.progress_bar.finish_and_clear();
239 }
240
241 fn get_remaining_time_since_last_json_report(&self) -> Option<Duration> {
242 match self.last_json_report_instant.read() {
243 Ok(instant) => (*instant).map(|instant| instant.elapsed()),
244 Err(_) => None,
245 }
246 }
247
248 fn report_json_progress(&self) {
249 if let ProgressOutputType::JsonReporter = self.output_type {
250 let should_report = match self.get_remaining_time_since_last_json_report() {
251 Some(remaining_time) => remaining_time > Duration::from_millis(333),
252 None => true,
253 };
254
255 if should_report {
256 eprintln!("{}", self.json_reporter.format(&self.progress_bar));
257
258 match self.last_json_report_instant.write() {
259 Ok(mut instant) => *instant = Some(Instant::now()),
260 Err(error) => {
261 warn!(self.logger, "failed to update last json report instant"; "error" => ?error)
262 }
263 };
264 }
265 };
266 }
267
268 pub(crate) fn inner_progress_bar(&self) -> &ProgressBar {
269 &self.progress_bar
270 }
271}
272
// Unit tests for the JSON progress formatting and the TTY style template.
#[cfg(test)]
mod tests {
    use std::thread::sleep;

    use super::*;
    use indicatif::ProgressBar;
    use serde_json::Value;

    // The `*_downloaded` / `*_total` keys must be prefixed according to the progress bar kind.
    #[test]
    fn json_reporter_change_downloaded_and_total_key_prefix_based_on_progress_bar_kind() {
        fn run(kind: ProgressBarKind, expected_prefix: &str) {
            let json_string = ProgressBarJsonFormatter::format_values(
                "label",
                kind,
                "timestamp".to_string(),
                0,
                0,
                Duration::from_millis(1000),
                Duration::from_millis(2500),
            );

            assert!(
                json_string.contains(&format!(r#""{expected_prefix}_downloaded":"#)),
                "'{expected_prefix}_downloaded' key not found in json output: {json_string}",
            );
            assert!(
                json_string.contains(&format!(r#""{expected_prefix}_total":"#)),
                "'{expected_prefix}_total' key not found in json output: {json_string}",
            );
        }

        run(ProgressBarKind::Bytes, "bytes");
        run(ProgressBarKind::Files, "files");
    }

    // The label passed to the formatter must appear as the value of the "label" key.
    #[test]
    fn json_report_include_label() {
        let json_string = ProgressBarJsonFormatter::format_values(
            "unique_label",
            ProgressBarKind::Bytes,
            "timestamp".to_string(),
            0,
            0,
            Duration::from_millis(7569),
            Duration::from_millis(5124),
        );

        assert!(
            json_string.contains(r#""label": "unique_label""#),
            "Label key and/or value not found in json output: {json_string}",
        );
    }

    // Three-digit millisecond parts are written as-is after the seconds.
    #[test]
    fn check_seconds_formatting_in_json_report_with_more_than_100_milliseconds() {
        let json_string = ProgressBarJsonFormatter::format_values(
            "label",
            ProgressBarKind::Bytes,
            "timestamp".to_string(),
            0,
            0,
            Duration::from_millis(7569),
            Duration::from_millis(5124),
        );

        assert!(
            json_string.contains(r#""seconds_left": 7.569"#),
            "Not expected value in json output: {json_string}",
        );
        assert!(
            json_string.contains(r#""seconds_elapsed": 5.124"#),
            "Not expected value in json output: {json_string}",
        );
    }

    // Millisecond parts below 100 must be left-padded with zeros (6 ms -> ".006", not ".6").
    #[test]
    fn check_seconds_formatting_in_json_report_with_less_than_100_milliseconds() {
        let json_string = ProgressBarJsonFormatter::format_values(
            "label",
            ProgressBarKind::Bytes,
            "timestamp".to_string(),
            0,
            0,
            Duration::from_millis(7006),
            Duration::from_millis(5004),
        );

        assert!(
            json_string.contains(r#""seconds_left": 7.006"#),
            "Not expected value in json output: {json_string}"
        );
        assert!(
            json_string.contains(r#""seconds_elapsed": 5.004"#),
            "Not expected value in json output: {json_string}"
        );
    }

    // Trailing zeros of the millisecond part must be kept (200 ms -> ".200", not ".2").
    #[test]
    fn check_seconds_formatting_in_json_report_with_milliseconds_ending_by_zeros() {
        let json_string = ProgressBarJsonFormatter::format_values(
            "label",
            ProgressBarKind::Bytes,
            "timestamp".to_string(),
            0,
            0,
            Duration::from_millis(7200),
            Duration::from_millis(5100),
        );

        assert!(
            json_string.contains(r#""seconds_left": 7.200"#),
            "Not expected value in json output: {json_string}"
        );
        assert!(
            json_string.contains(r#""seconds_elapsed": 5.100"#),
            "Not expected value in json output: {json_string}"
        );
    }

    // `format` must read the ETA and elapsed time from the actual progress bar.
    // This test is timing-sensitive: keep the statement order as is.
    #[test]
    fn check_seconds_left_and_elapsed_time_are_used_by_the_formatter() {
        fn format_duration(duration: &Duration) -> String {
            format!("{}.{}", duration.as_secs(), duration.subsec_nanos())
        }
        // Truncates to whole milliseconds, the precision of the formatter's output.
        fn round_at_ms(duration: Duration) -> Duration {
            Duration::from_millis(duration.as_millis() as u64)
        }

        let progress_bar = ProgressBar::new(4);
        // Let some time elapse so the elapsed time (and the ETA derived from it) is non-zero.
        sleep(Duration::from_millis(15));
        // One unit out of four done: three units remain.
        progress_bar.set_position(1);

        let duration_left_before = round_at_ms(progress_bar.eta());
        let duration_elapsed_before = round_at_ms(progress_bar.elapsed());

        let json_string =
            ProgressBarJsonFormatter::new("label", ProgressBarKind::Bytes).format(&progress_bar);

        let duration_left_after = round_at_ms(progress_bar.eta());
        let duration_elapsed_after = round_at_ms(progress_bar.elapsed());

        let delta = 0.1;
        // `delta` is the tolerance, in seconds, of the "seconds_left ~= 3 * seconds_elapsed" check.

        let json_value: Value = serde_json::from_str(&json_string).unwrap();
        let seconds_left = json_value["seconds_left"].as_f64().unwrap();
        let seconds_elapsed = json_value["seconds_elapsed"].as_f64().unwrap();

        assert!(
            // With 1 of 4 units done, the 3 remaining units are expected to take ~3x the elapsed time.
            seconds_elapsed * 3.0 - delta < seconds_left
                && seconds_left < seconds_elapsed * 3.0 + delta,
            "seconds_left should be close to 3*{} but it's {}.",
            &seconds_elapsed,
            &seconds_left
        );

        // The formatter sampled the bar between the "before" and "after" readings, so its
        // values must lie within those bounds.
        let duration_left = Duration::from_secs_f64(seconds_left);
        assert!(
            duration_left_before <= duration_left && duration_left <= duration_left_after,
            "Duration left: {} should be between {} and {}",
            format_duration(&duration_left),
            format_duration(&duration_left_before),
            format_duration(&duration_left_after),
        );

        let duration_elapsed = Duration::from_secs_f64(seconds_elapsed);
        assert!(
            duration_elapsed_before <= duration_elapsed
                && duration_elapsed <= duration_elapsed_after,
            "Duration elapsed: {} should be between {} and {}",
            format_duration(&duration_elapsed),
            format_duration(&duration_elapsed_before),
            format_duration(&duration_elapsed_after),
        );
    }

    // With `include_label_in_tty == false`, the label must not appear in the TTY template.
    #[test]
    fn style_of_download_progress_reporter_when_include_label_in_tty_is_false() {
        let params = DownloadProgressReporterParams {
            label: "label".to_string(),
            output_type: ProgressOutputType::Tty,
            progress_bar_kind: ProgressBarKind::Bytes,
            include_label_in_tty: false,
        };

        let style_template = params.style_template();
        assert!(
            !style_template.contains("label"),
            "Label should not be included in the style template, got: '{style_template}'"
        );
    }

    // With `include_label_in_tty == true`, the label must appear in the TTY template.
    #[test]
    fn style_of_download_progress_reporter_when_include_label_in_tty_is_true() {
        let params = DownloadProgressReporterParams {
            label: "label".to_string(),
            output_type: ProgressOutputType::Tty,
            progress_bar_kind: ProgressBarKind::Bytes,
            include_label_in_tty: true,
        };

        let style_template = params.style_template();
        assert!(
            style_template.contains("label"),
            "Label should be included in the style template, got: '{style_template}'"
        );
    }
}