mithril_client/file_downloader/retry.rs

1use std::{path::Path, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use mithril_common::{StdResult, entities::CompressionAlgorithm};
5
6use super::{DownloadEvent, FileDownloader, FileDownloaderUri};
7
/// Policy for retrying file downloads.
///
/// Tells [`RetryDownloader`] how many times a download may be tried and how long to wait
/// between two consecutive tries.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileDownloadRetryPolicy {
    /// Total number of download attempts, including the first one.
    ///
    /// `1` means the download is tried once and never retried. A value of `0` is treated
    /// like `1`: at least one attempt is always made.
    pub attempts: usize,
    /// Delay to wait after a failed attempt before starting the next one.
    ///
    /// No delay is applied after the last attempt.
    pub delay_between_attempts: Duration,
}

impl FileDownloadRetryPolicy {
    /// Create a policy that never retries: a single attempt, with no delay.
    pub fn never() -> Self {
        Self {
            attempts: 1,
            delay_between_attempts: Duration::from_secs(0),
        }
    }
}

impl Default for FileDownloadRetryPolicy {
    /// Create the default retry policy: up to 3 attempts, waiting 5 seconds after each
    /// failed attempt but the last.
    fn default() -> Self {
        Self {
            attempts: 3,
            delay_between_attempts: Duration::from_secs(5),
        }
    }
}
36
37/// RetryDownloader is a wrapper around FileDownloader that retries downloading a file if it fails.
38pub struct RetryDownloader {
39    /// File downloader to use.
40    file_downloader: Arc<dyn FileDownloader>,
41    /// Number of attempts to download a file.
42    retry_policy: FileDownloadRetryPolicy,
43}
44
45impl RetryDownloader {
46    /// Create a new RetryDownloader.
47    pub fn new(
48        file_downloader: Arc<dyn FileDownloader>,
49        retry_policy: FileDownloadRetryPolicy,
50    ) -> Self {
51        Self {
52            file_downloader,
53            retry_policy,
54        }
55    }
56}
57
58#[async_trait]
59impl FileDownloader for RetryDownloader {
60    async fn download_unpack(
61        &self,
62        location: &FileDownloaderUri,
63        file_size: u64,
64        target_dir: &Path,
65        compression_algorithm: Option<CompressionAlgorithm>,
66        download_event_type: DownloadEvent,
67    ) -> StdResult<()> {
68        let retry_policy = &self.retry_policy;
69        let mut nb_attempts = 0;
70        loop {
71            nb_attempts += 1;
72            match self
73                .file_downloader
74                .download_unpack(
75                    location,
76                    file_size,
77                    target_dir,
78                    compression_algorithm,
79                    download_event_type.clone(),
80                )
81                .await
82            {
83                Ok(result) => return Ok(result),
84                Err(e) if nb_attempts >= retry_policy.attempts => {
85                    return Err(e.context(format!(
86                        "Download of location {location:?} failed after {nb_attempts} attempts",
87                    )));
88                }
89                _ => tokio::time::sleep(retry_policy.delay_between_attempts).await,
90            }
91        }
92    }
93}
94
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Mutex,
            atomic::{AtomicUsize, Ordering},
        },
        time::Instant,
    };

    use mithril_common::entities::FileUri;

    use crate::file_downloader::MockFileDownloaderBuilder;

    use super::*;

    #[tokio::test]
    async fn download_return_the_result_of_download_without_retry() {
        let mock_file_downloader = MockFileDownloaderBuilder::default()
            .with_file_uri("http://whatever/00001.tar.gz")
            .with_compression(None)
            .with_success()
            .build();
        let retry_downloader = RetryDownloader::new(
            Arc::new(mock_file_downloader),
            FileDownloadRetryPolicy::never(),
        );

        retry_downloader
            .download_unpack(
                &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
                0,
                Path::new("."),
                None,
                DownloadEvent::Immutable {
                    immutable_file_number: 1,
                    download_id: "download_id".to_string(),
                },
            )
            .await
            .unwrap();
    }

    // The `never` policy is used here. The default policy does retry (3 attempts).
    #[tokio::test]
    async fn when_download_fails_do_not_retry_with_never_policy() {
        let mock_file_downloader = MockFileDownloaderBuilder::default()
            .with_file_uri("http://whatever/00001.tar.gz")
            .with_compression(None)
            .with_failure()
            .build();
        let retry_downloader = RetryDownloader::new(
            Arc::new(mock_file_downloader),
            FileDownloadRetryPolicy::never(),
        );

        retry_downloader
            .download_unpack(
                &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
                0,
                Path::new("."),
                None,
                DownloadEvent::Immutable {
                    immutable_file_number: 1,
                    download_id: "download_id".to_string(),
                },
            )
            .await
            .expect_err("An error should be returned when download fails");
    }

    #[tokio::test]
    async fn should_retry_if_fail() {
        let mock_file_downloader = MockFileDownloaderBuilder::default()
            .with_file_uri("http://whatever/00001.tar.gz")
            .with_compression(None)
            .with_failure()
            .with_times(2)
            .next_call()
            .with_file_uri("http://whatever/00001.tar.gz")
            .with_compression(None)
            .with_times(1)
            .with_success()
            .build();
        let retry_downloader = RetryDownloader::new(
            Arc::new(mock_file_downloader),
            FileDownloadRetryPolicy {
                attempts: 3,
                delay_between_attempts: Duration::from_millis(10),
            },
        );

        retry_downloader
            .download_unpack(
                &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
                0,
                Path::new("."),
                None,
                DownloadEvent::Ancillary {
                    download_id: "download_id".to_string(),
                },
            )
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn should_recall_a_failing_inner_downloader_up_to_the_limit() {
        let mock_file_downloader = MockFileDownloaderBuilder::default()
            .with_file_uri("http://whatever/00001.tar.gz")
            .with_compression(None)
            .with_failure()
            .with_times(3)
            .build();
        let retry_downloader = RetryDownloader::new(
            Arc::new(mock_file_downloader),
            FileDownloadRetryPolicy {
                attempts: 3,
                delay_between_attempts: Duration::from_millis(10),
            },
        );

        let error = retry_downloader
            .download_unpack(
                &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
                0,
                Path::new("."),
                None,
                DownloadEvent::Immutable {
                    immutable_file_number: 1,
                    download_id: "download_id".to_string(),
                },
            )
            .await
            .expect_err("An error should be returned when all download attempts fail");

        // The outermost context added by the retry downloader reports the attempt count.
        let error_message = error.to_string();
        assert!(
            error_message.contains("failed after 3 attempts"),
            "error should report the number of attempts, got: {error_message}"
        );
    }

    #[tokio::test]
    async fn should_delay_between_retries() {
        struct FileDownloaderAssertDelay {
            expected_delay_greater_than_or_equal: Duration,
            expected_delay_less_than: Duration,
            last_attempt_start_time: Mutex<Option<Instant>>,
            // Counts calls so the test cannot pass vacuously when no retry happens.
            nb_calls: AtomicUsize,
        }

        #[async_trait]
        impl FileDownloader for FileDownloaderAssertDelay {
            async fn download_unpack(
                &self,
                _location: &FileDownloaderUri,
                _file_size: u64,
                _target_dir: &Path,
                _compression_algorithm: Option<CompressionAlgorithm>,
                _download_event_type: DownloadEvent,
            ) -> StdResult<()> {
                self.nb_calls.fetch_add(1, Ordering::SeqCst);

                let mut last_attempt_start_time = self.last_attempt_start_time.lock().unwrap();
                if let Some(last_start_attempt) = *last_attempt_start_time {
                    let duration = last_start_attempt.elapsed();
                    assert!(
                        duration >= self.expected_delay_greater_than_or_equal,
                        "duration should be greater than or equal to {}ms but was {}ms",
                        self.expected_delay_greater_than_or_equal.as_millis(),
                        duration.as_millis()
                    );
                    assert!(
                        duration < self.expected_delay_less_than,
                        "duration should be less than {}ms but was {}ms",
                        self.expected_delay_less_than.as_millis(),
                        duration.as_millis()
                    );
                }
                *last_attempt_start_time = Some(Instant::now());

                Err(anyhow::anyhow!("Download failed"))
            }
        }

        let delay_ms = 50;
        let nb_attempts = 4;
        let mock_file_downloader = Arc::new(FileDownloaderAssertDelay {
            expected_delay_greater_than_or_equal: Duration::from_millis(delay_ms),
            expected_delay_less_than: Duration::from_millis(2 * delay_ms),
            last_attempt_start_time: Mutex::new(None),
            nb_calls: AtomicUsize::new(0),
        });

        let retry_downloader = RetryDownloader::new(
            mock_file_downloader.clone(),
            FileDownloadRetryPolicy {
                attempts: nb_attempts,
                delay_between_attempts: Duration::from_millis(delay_ms),
            },
        );

        retry_downloader
            .download_unpack(
                &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
                0,
                Path::new("."),
                None,
                DownloadEvent::Immutable {
                    immutable_file_number: 1,
                    download_id: "download_id".to_string(),
                },
            )
            .await
            .expect_err("An error should be returned when all download attempts fail");

        assert_eq!(
            nb_attempts,
            mock_file_downloader.nb_calls.load(Ordering::SeqCst),
            "the inner downloader should be called once per allowed attempt"
        );
    }
}