mithril_client/file_downloader/
retry.rs

1use std::{path::Path, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use mithril_common::{entities::CompressionAlgorithm, StdResult};
5
6use super::{DownloadEvent, FileDownloader, FileDownloaderUri};
7
/// Policy for retrying file downloads.
///
/// `attempts` counts every try, the initial one included: a value of `1` therefore means
/// "try once and never retry".
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct FileDownloadRetryPolicy {
    /// Total number of attempts to download a file (the first attempt included).
    pub attempts: usize,
    /// Delay to wait between two consecutive attempts.
    pub delay_between_attempts: Duration,
}

impl FileDownloadRetryPolicy {
    /// Create a policy that never retries: a single attempt and no delay.
    pub fn never() -> Self {
        Self {
            attempts: 1,
            delay_between_attempts: Duration::ZERO,
        }
    }
}

impl Default for FileDownloadRetryPolicy {
    /// Create the default retry policy: three attempts separated by a five seconds delay.
    fn default() -> Self {
        Self {
            attempts: 3,
            delay_between_attempts: Duration::from_secs(5),
        }
    }
}
36
/// RetryDownloader is a wrapper around FileDownloader that retries downloading a file if it fails.
///
/// The number of attempts and the delay between them are taken from its
/// [FileDownloadRetryPolicy].
pub struct RetryDownloader {
    /// Inner file downloader that performs each download attempt.
    file_downloader: Arc<dyn FileDownloader>,
    /// Retry policy (number of attempts and delay between attempts) applied to downloads.
    retry_policy: FileDownloadRetryPolicy,
}
44
45impl RetryDownloader {
46    /// Create a new RetryDownloader.
47    pub fn new(
48        file_downloader: Arc<dyn FileDownloader>,
49        retry_policy: FileDownloadRetryPolicy,
50    ) -> Self {
51        Self {
52            file_downloader,
53            retry_policy,
54        }
55    }
56}
57
58#[async_trait]
59impl FileDownloader for RetryDownloader {
60    async fn download_unpack(
61        &self,
62        location: &FileDownloaderUri,
63        file_size: u64,
64        target_dir: &Path,
65        compression_algorithm: Option<CompressionAlgorithm>,
66        download_event_type: DownloadEvent,
67    ) -> StdResult<()> {
68        let retry_policy = &self.retry_policy;
69        let mut nb_attempts = 0;
70        loop {
71            nb_attempts += 1;
72            match self
73                .file_downloader
74                .download_unpack(
75                    location,
76                    file_size,
77                    target_dir,
78                    compression_algorithm,
79                    download_event_type.clone(),
80                )
81                .await
82            {
83                Ok(result) => return Ok(result),
84                Err(_) if nb_attempts >= retry_policy.attempts => {
85                    return Err(anyhow::anyhow!(
86                        "Download of location {:?} failed after {} attempts",
87                        location,
88                        nb_attempts
89                    ));
90                }
91                _ => tokio::time::sleep(retry_policy.delay_between_attempts).await,
92            }
93        }
94    }
95}
96
97#[cfg(test)]
98mod tests {
99
100    use std::{sync::Mutex, time::Instant};
101
102    use mithril_common::entities::FileUri;
103
104    use crate::file_downloader::MockFileDownloaderBuilder;
105
106    use super::*;
107
108    #[tokio::test]
109    async fn download_return_the_result_of_download_without_retry() {
110        let mock_file_downloader = MockFileDownloaderBuilder::default()
111            .with_file_uri("http://whatever/00001.tar.gz")
112            .with_compression(None)
113            .with_success()
114            .build();
115        let retry_downloader = RetryDownloader::new(
116            Arc::new(mock_file_downloader),
117            FileDownloadRetryPolicy::never(),
118        );
119
120        retry_downloader
121            .download_unpack(
122                &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
123                0,
124                Path::new("."),
125                None,
126                DownloadEvent::Immutable {
127                    immutable_file_number: 1,
128                    download_id: "download_id".to_string(),
129                },
130            )
131            .await
132            .unwrap();
133    }
134
135    #[tokio::test]
136    async fn when_download_fails_do_not_retry_by_default() {
137        let mock_file_downloader = MockFileDownloaderBuilder::default()
138            .with_file_uri("http://whatever/00001.tar.gz")
139            .with_compression(None)
140            .with_failure()
141            .build();
142        let retry_downloader = RetryDownloader::new(
143            Arc::new(mock_file_downloader),
144            FileDownloadRetryPolicy::never(),
145        );
146
147        retry_downloader
148            .download_unpack(
149                &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
150                0,
151                Path::new("."),
152                None,
153                DownloadEvent::Immutable {
154                    immutable_file_number: 1,
155                    download_id: "download_id".to_string(),
156                },
157            )
158            .await
159            .expect_err("An error should be returned when download fails");
160    }
161
162    #[tokio::test]
163    async fn should_retry_if_fail() {
164        let mock_file_downloader = MockFileDownloaderBuilder::default()
165            .with_file_uri("http://whatever/00001.tar.gz")
166            .with_compression(None)
167            .with_failure()
168            .with_times(2)
169            .next_call()
170            .with_file_uri("http://whatever/00001.tar.gz")
171            .with_compression(None)
172            .with_times(1)
173            .with_success()
174            .build();
175        let retry_downloader = RetryDownloader::new(
176            Arc::new(mock_file_downloader),
177            FileDownloadRetryPolicy {
178                attempts: 3,
179                delay_between_attempts: Duration::from_millis(10),
180            },
181        );
182
183        retry_downloader
184            .download_unpack(
185                &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
186                0,
187                Path::new("."),
188                None,
189                DownloadEvent::Ancillary {
190                    download_id: "download_id".to_string(),
191                },
192            )
193            .await
194            .unwrap();
195    }
196
197    #[tokio::test]
198    async fn should_recall_a_failing_inner_downloader_up_to_the_limit() {
199        let mock_file_downloader = MockFileDownloaderBuilder::default()
200            .with_file_uri("http://whatever/00001.tar.gz")
201            .with_compression(None)
202            .with_failure()
203            .with_times(3)
204            .build();
205        let retry_downloader = RetryDownloader::new(
206            Arc::new(mock_file_downloader),
207            FileDownloadRetryPolicy {
208                attempts: 3,
209                delay_between_attempts: Duration::from_millis(10),
210            },
211        );
212
213        retry_downloader
214            .download_unpack(
215                &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
216                0,
217                Path::new("."),
218                None,
219                DownloadEvent::Immutable {
220                    immutable_file_number: 1,
221                    download_id: "download_id".to_string(),
222                },
223            )
224            .await
225            .expect_err("An error should be returned when all download attempts fail");
226    }
227
228    #[tokio::test]
229    async fn should_delay_between_retries() {
230        struct FileDownloaderAssertDelay {
231            expected_delay_greater_than_or_equal: Duration,
232            expected_delay_less_than: Duration,
233            last_attempt_start_time: Mutex<Option<Instant>>,
234        }
235
236        #[async_trait]
237        impl FileDownloader for FileDownloaderAssertDelay {
238            async fn download_unpack(
239                &self,
240                _location: &FileDownloaderUri,
241                _file_size: u64,
242                _target_dir: &Path,
243                _compression_algorithm: Option<CompressionAlgorithm>,
244                _download_event_type: DownloadEvent,
245            ) -> StdResult<()> {
246                let mut last_attempt_start_time = self.last_attempt_start_time.lock().unwrap();
247                if let Some(last_start_attempt) = *last_attempt_start_time {
248                    let duration = last_start_attempt.elapsed();
249                    assert!(
250                        duration >= self.expected_delay_greater_than_or_equal,
251                        "duration should be greater than or equal to {}ms but was {}ms",
252                        self.expected_delay_greater_than_or_equal.as_millis(),
253                        duration.as_millis()
254                    );
255                    assert!(
256                        duration < self.expected_delay_less_than,
257                        "duration should be less than {}ms but was {}ms",
258                        self.expected_delay_less_than.as_millis(),
259                        duration.as_millis()
260                    );
261                }
262                *last_attempt_start_time = Some(Instant::now());
263
264                Err(anyhow::anyhow!("Download failed"))
265            }
266        }
267
268        let delay_ms = 50;
269        let mock_file_downloader = Arc::new(FileDownloaderAssertDelay {
270            expected_delay_greater_than_or_equal: Duration::from_millis(delay_ms),
271            expected_delay_less_than: Duration::from_millis(2 * delay_ms),
272            last_attempt_start_time: Mutex::new(None),
273        });
274
275        let retry_downloader = RetryDownloader::new(
276            mock_file_downloader.clone(),
277            FileDownloadRetryPolicy {
278                attempts: 4,
279                delay_between_attempts: Duration::from_millis(delay_ms),
280            },
281        );
282
283        retry_downloader
284            .download_unpack(
285                &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
286                0,
287                Path::new("."),
288                None,
289                DownloadEvent::Immutable {
290                    immutable_file_number: 1,
291                    download_id: "download_id".to_string(),
292                },
293            )
294            .await
295            .expect_err("An error should be returned when all download attempts fail");
296    }
297}