1use chrono::Utc;
2use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
3use mithril_client::MithrilResult;
4use slog::{warn, Logger};
5use std::{
6 fmt::Write,
7 ops::Deref,
8 sync::{Arc, RwLock},
9 time::{Duration, Instant},
10};
11
/// Kind of output used to report progress.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProgressOutputType {
    /// Progress is written as JSON lines on stderr; the `indicatif` draw target is hidden.
    JsonReporter,
    /// Progress is drawn on stdout using `indicatif`.
    Tty,
    /// Nothing is drawn: the `indicatif` draw target is hidden and steps are not printed.
    Hidden,
}
22
23impl From<ProgressOutputType> for ProgressDrawTarget {
24 fn from(value: ProgressOutputType) -> Self {
25 match value {
26 ProgressOutputType::JsonReporter => ProgressDrawTarget::hidden(),
27 ProgressOutputType::Tty => ProgressDrawTarget::stdout(),
28 ProgressOutputType::Hidden => ProgressDrawTarget::hidden(),
29 }
30 }
31}
32
/// Unit in which a progress bar counts its progress.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProgressBarKind {
    /// Progress is counted in bytes.
    Bytes,
    /// Progress is counted in files.
    Files,
}
39
/// Wrapper around an `indicatif` [`MultiProgress`] that can also report numbered steps.
pub struct ProgressPrinter {
    /// Wrapped multi progress bar, also reachable through `Deref`.
    multi_progress: MultiProgress,
    /// Selects how steps are reported (JSON on stderr, TTY, or not at all).
    output_type: ProgressOutputType,
    /// Total number of steps, printed next to each reported step number.
    number_of_steps: u16,
}
46
47impl ProgressPrinter {
48 pub fn new(output_type: ProgressOutputType, number_of_steps: u16) -> Self {
50 Self {
51 multi_progress: MultiProgress::with_draw_target(output_type.into()),
52 output_type,
53 number_of_steps,
54 }
55 }
56
57 pub fn report_step(&self, step_number: u16, text: &str) -> MithrilResult<()> {
59 match self.output_type {
60 ProgressOutputType::JsonReporter => eprintln!(
61 r#"{{"timestamp": "{timestamp}", "step_num": {step_number}, "total_steps": {number_of_steps}, "message": "{text}"}}"#,
62 timestamp = Utc::now().to_rfc3339(),
63 number_of_steps = self.number_of_steps,
64 ),
65 ProgressOutputType::Tty => self
66 .multi_progress
67 .println(format!("{step_number}/{} - {text}", self.number_of_steps))?,
68 ProgressOutputType::Hidden => (),
69 };
70
71 Ok(())
72 }
73}
74
75impl Deref for ProgressPrinter {
76 type Target = MultiProgress;
77
78 fn deref(&self) -> &Self::Target {
79 &self.multi_progress
80 }
81}
82
/// Formats the state of a progress bar as a one-line JSON object.
#[derive(Clone)]
pub struct ProgressBarJsonFormatter {
    /// Value written under the `label` key of the JSON output.
    label: String,
    /// Selects the `bytes` or `files` prefix of the `*_downloaded` / `*_total` keys.
    kind: ProgressBarKind,
}
89
90impl ProgressBarJsonFormatter {
91 pub fn new<T: Into<String>>(label: T, kind: ProgressBarKind) -> Self {
93 Self {
94 label: label.into(),
95 kind,
96 }
97 }
98
99 pub fn format(&self, progress_bar: &ProgressBar) -> String {
101 ProgressBarJsonFormatter::format_values(
102 &self.label,
103 self.kind,
104 Utc::now().to_rfc3339(),
105 progress_bar.position(),
106 progress_bar.length().unwrap_or(0),
107 progress_bar.eta(),
108 progress_bar.elapsed(),
109 )
110 }
111
112 fn format_values(
113 label: &str,
114 kind: ProgressBarKind,
115 timestamp: String,
116 amount_downloaded: u64,
117 amount_total: u64,
118 duration_left: Duration,
119 duration_elapsed: Duration,
120 ) -> String {
121 let amount_prefix = match kind {
122 ProgressBarKind::Bytes => "bytes",
123 ProgressBarKind::Files => "files",
124 };
125
126 format!(
127 r#"{{"label": "{}", "timestamp": "{}", "{}_downloaded": {}, "{}_total": {}, "seconds_left": {}.{:0>3}, "seconds_elapsed": {}.{:0>3}}}"#,
128 label,
129 timestamp,
130 amount_prefix,
131 amount_downloaded,
132 amount_prefix,
133 amount_total,
134 duration_left.as_secs(),
135 duration_left.subsec_millis(),
136 duration_elapsed.as_secs(),
137 duration_elapsed.subsec_millis(),
138 )
139 }
140}
141
/// Wrapper of an `indicatif` [`ProgressBar`] that additionally prints JSON progress
/// reports on stderr when its output type is `ProgressOutputType::JsonReporter`.
#[derive(Clone)]
pub struct DownloadProgressReporter {
    /// Wrapped progress bar.
    progress_bar: ProgressBar,
    /// Output type; JSON reports are only emitted for `ProgressOutputType::JsonReporter`.
    output_type: ProgressOutputType,
    /// Formatter producing the JSON report lines.
    json_reporter: ProgressBarJsonFormatter,
    /// Instant of the last emitted JSON report (`None` until the first one), used to
    /// throttle reports. Held behind an `Arc`, so clones share the same value.
    last_json_report_instant: Arc<RwLock<Option<Instant>>>,
    /// Logger used to warn when the last report instant can't be updated.
    logger: Logger,
}
151
/// Parameters used to build a [`DownloadProgressReporter`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DownloadProgressReporterParams {
    /// Label given to the JSON formatter and, optionally, shown in the TTY style template.
    pub label: String,
    /// Output type of the reporter.
    pub output_type: ProgressOutputType,
    /// Kind of amount tracked by the progress bar (bytes or files).
    pub progress_bar_kind: ProgressBarKind,
    /// Whether `label` is inserted in the progress bar style template.
    pub include_label_in_tty: bool,
}
159
160impl DownloadProgressReporterParams {
161 pub fn style(&self) -> ProgressStyle {
162 ProgressStyle::with_template(&self.style_template())
163 .unwrap()
164 .with_key(
165 "eta",
166 |state: &indicatif::ProgressState, w: &mut dyn Write| {
167 write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
168 },
169 )
170 .progress_chars("#>-")
171 }
172
173 fn style_template(&self) -> String {
174 let label = if self.include_label_in_tty {
175 &self.label
176 } else {
177 ""
178 };
179
180 match self.progress_bar_kind {
181 ProgressBarKind::Bytes => {
182 format!("{{spinner:.green}} {label} [{{elapsed_precise}}] [{{wide_bar:.cyan/blue}}] {{bytes}}/{{total_bytes}} ({{eta}})")
183 }
184 ProgressBarKind::Files => {
185 format!("{{spinner:.green}} {label} [{{elapsed_precise}}] [{{wide_bar:.cyan/blue}}] Files: {{human_pos}}/{{human_len}} ({{eta}})")
186 }
187 }
188 }
189}
190
191impl DownloadProgressReporter {
192 pub fn new(
194 progress_bar: ProgressBar,
195 params: DownloadProgressReporterParams,
196 logger: Logger,
197 ) -> Self {
198 progress_bar.set_style(params.style());
199
200 Self {
201 progress_bar,
202 output_type: params.output_type,
203 json_reporter: ProgressBarJsonFormatter::new(¶ms.label, params.progress_bar_kind),
204 last_json_report_instant: Arc::new(RwLock::new(None)),
205 logger,
206 }
207 }
208
209 #[cfg(test)]
210 pub fn kind(&self) -> ProgressBarKind {
212 self.json_reporter.kind
213 }
214
215 pub fn report(&self, actual_position: u64) {
217 self.progress_bar.set_position(actual_position);
218 self.report_json_progress();
219 }
220
221 pub fn inc(&self, delta: u64) {
223 self.progress_bar.inc(delta);
224 self.report_json_progress();
225 }
226
227 pub fn finish(&self, message: &str) {
229 self.progress_bar.finish_with_message(message.to_string());
230 }
231
232 pub fn finish_and_clear(&self) {
234 self.progress_bar.finish_and_clear();
235 }
236
237 fn get_remaining_time_since_last_json_report(&self) -> Option<Duration> {
238 match self.last_json_report_instant.read() {
239 Ok(instant) => (*instant).map(|instant| instant.elapsed()),
240 Err(_) => None,
241 }
242 }
243
244 fn report_json_progress(&self) {
245 if let ProgressOutputType::JsonReporter = self.output_type {
246 let should_report = match self.get_remaining_time_since_last_json_report() {
247 Some(remaining_time) => remaining_time > Duration::from_millis(333),
248 None => true,
249 };
250
251 if should_report {
252 eprintln!("{}", self.json_reporter.format(&self.progress_bar));
253
254 match self.last_json_report_instant.write() {
255 Ok(mut instant) => *instant = Some(Instant::now()),
256 Err(error) => {
257 warn!(self.logger, "failed to update last json report instant"; "error" => ?error)
258 }
259 };
260 }
261 };
262 }
263
264 pub(crate) fn inner_progress_bar(&self) -> &ProgressBar {
265 &self.progress_bar
266 }
267}
268
#[cfg(test)]
mod tests {
    use std::thread::sleep;

    use super::*;
    use indicatif::ProgressBar;
    use serde_json::Value;

    #[test]
    fn json_reporter_change_downloaded_and_total_key_prefix_based_on_progress_bar_kind() {
        fn run(kind: ProgressBarKind, expected_prefix: &str) {
            let json_string = ProgressBarJsonFormatter::format_values(
                "label",
                kind,
                "timestamp".to_string(),
                0,
                0,
                Duration::from_millis(1000),
                Duration::from_millis(2500),
            );

            assert!(
                json_string.contains(&format!(r#""{expected_prefix}_downloaded":"#)),
                "'{expected_prefix}_downloaded' key not found in json output: {json_string}",
            );
            assert!(
                json_string.contains(&format!(r#""{expected_prefix}_total":"#)),
                "'{expected_prefix}_total' key not found in json output: {json_string}",
            );
        }

        run(ProgressBarKind::Bytes, "bytes");
        run(ProgressBarKind::Files, "files");
    }

    #[test]
    fn json_report_include_label() {
        let json_string = ProgressBarJsonFormatter::format_values(
            "unique_label",
            ProgressBarKind::Bytes,
            "timestamp".to_string(),
            0,
            0,
            Duration::from_millis(7569),
            Duration::from_millis(5124),
        );

        assert!(
            json_string.contains(r#""label": "unique_label""#),
            "Label key and/or value not found in json output: {json_string}",
        );
    }

    #[test]
    fn check_seconds_formatting_in_json_report_with_more_than_100_milliseconds() {
        let json_string = ProgressBarJsonFormatter::format_values(
            "label",
            ProgressBarKind::Bytes,
            "timestamp".to_string(),
            0,
            0,
            Duration::from_millis(7569),
            Duration::from_millis(5124),
        );

        assert!(
            json_string.contains(r#""seconds_left": 7.569"#),
            "Not expected value in json output: {json_string}",
        );
        assert!(
            json_string.contains(r#""seconds_elapsed": 5.124"#),
            "Not expected value in json output: {json_string}",
        );
    }

    #[test]
    fn check_seconds_formatting_in_json_report_with_less_than_100_milliseconds() {
        let json_string = ProgressBarJsonFormatter::format_values(
            "label",
            ProgressBarKind::Bytes,
            "timestamp".to_string(),
            0,
            0,
            Duration::from_millis(7006),
            Duration::from_millis(5004),
        );

        assert!(
            json_string.contains(r#""seconds_left": 7.006"#),
            "Not expected value in json output: {json_string}",
        );
        assert!(
            json_string.contains(r#""seconds_elapsed": 5.004"#),
            "Not expected value in json output: {json_string}",
        );
    }

    #[test]
    fn check_seconds_formatting_in_json_report_with_milliseconds_ending_by_zeros() {
        let json_string = ProgressBarJsonFormatter::format_values(
            "label",
            ProgressBarKind::Bytes,
            "timestamp".to_string(),
            0,
            0,
            Duration::from_millis(7200),
            Duration::from_millis(5100),
        );

        assert!(
            json_string.contains(r#""seconds_left": 7.200"#),
            "Not expected value in json output: {json_string}",
        );
        assert!(
            json_string.contains(r#""seconds_elapsed": 5.100"#),
            "Not expected value in json output: {json_string}",
        );
    }

    #[test]
    fn check_seconds_left_and_elapsed_time_are_used_by_the_formatter() {
        // Nanoseconds are zero-padded to nine digits so the printed fractional part is
        // the real one (e.g. 5 ms is shown as `.005000000`, not `.5000000`).
        fn format_duration(duration: &Duration) -> String {
            format!("{}.{:09}", duration.as_secs(), duration.subsec_nanos())
        }
        // The formatter only keeps milliseconds: truncate the reference values likewise.
        fn round_at_ms(duration: Duration) -> Duration {
            Duration::from_millis(duration.as_millis() as u64)
        }

        let progress_bar = ProgressBar::new(4);
        // Let some time pass so the elapsed time is not zero.
        sleep(Duration::from_millis(15));
        // One step done out of four: the estimated time left should be about three
        // times the elapsed time.
        progress_bar.set_position(1);

        let duration_left_before = round_at_ms(progress_bar.eta());
        let duration_elapsed_before = round_at_ms(progress_bar.elapsed());

        let json_string =
            ProgressBarJsonFormatter::new("label", ProgressBarKind::Bytes).format(&progress_bar);

        let duration_left_after = round_at_ms(progress_bar.eta());
        let duration_elapsed_after = round_at_ms(progress_bar.elapsed());

        // Tolerance, in seconds, when comparing the time left with the elapsed time.
        let delta = 0.1;

        let json_value: Value = serde_json::from_str(&json_string).unwrap();
        let seconds_left = json_value["seconds_left"].as_f64().unwrap();
        let seconds_elapsed = json_value["seconds_elapsed"].as_f64().unwrap();

        assert!(
            seconds_elapsed * 3.0 - delta < seconds_left
                && seconds_left < seconds_elapsed * 3.0 + delta,
            "seconds_left should be close to 3*{seconds_elapsed} but it's {seconds_left}.",
        );

        let duration_left = Duration::from_secs_f64(seconds_left);
        assert!(
            duration_left_before <= duration_left && duration_left <= duration_left_after,
            "Duration left: {} should be between {} and {}",
            format_duration(&duration_left),
            format_duration(&duration_left_before),
            format_duration(&duration_left_after),
        );

        let duration_elapsed = Duration::from_secs_f64(seconds_elapsed);
        assert!(
            duration_elapsed_before <= duration_elapsed
                && duration_elapsed <= duration_elapsed_after,
            "Duration elapsed: {} should be between {} and {}",
            format_duration(&duration_elapsed),
            format_duration(&duration_elapsed_before),
            format_duration(&duration_elapsed_after),
        );
    }

    #[test]
    fn style_of_download_progress_reporter_when_include_label_in_tty_is_false() {
        let params = DownloadProgressReporterParams {
            label: "label".to_string(),
            output_type: ProgressOutputType::Tty,
            progress_bar_kind: ProgressBarKind::Bytes,
            include_label_in_tty: false,
        };

        let style_template = params.style_template();
        assert!(
            !style_template.contains("label"),
            "Label should not be included in the style template, got: '{style_template}'"
        );
    }

    #[test]
    fn style_of_download_progress_reporter_when_include_label_in_tty_is_true() {
        let params = DownloadProgressReporterParams {
            label: "label".to_string(),
            output_type: ProgressOutputType::Tty,
            progress_bar_kind: ProgressBarKind::Bytes,
            include_label_in_tty: true,
        };

        let style_template = params.style_template();
        assert!(
            style_template.contains("label"),
            "Label should be included in the style template, got: '{style_template}'"
        );
    }
}