mithril_client/file_downloader/retry.rs

1use std::{path::Path, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use mithril_common::{entities::CompressionAlgorithm, StdResult};
5
6use super::{DownloadEvent, FileDownloader, FileDownloaderUri};
7
/// Policy for retrying file downloads.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct FileDownloadRetryPolicy {
    /// Total number of attempts to download a file, the first attempt included
    /// (`1` means the download is never retried).
    pub attempts: usize,
    /// Delay to wait after a failed attempt before starting the next one.
    pub delay_between_attempts: Duration,
}
16
17impl FileDownloadRetryPolicy {
18    /// Create a policy that never retries.
19    pub fn never() -> Self {
20        Self {
21            attempts: 1,
22            delay_between_attempts: Duration::from_secs(0),
23        }
24    }
25}
26
27impl Default for FileDownloadRetryPolicy {
28    /// Create a default retry policy.
29    fn default() -> Self {
30        Self {
31            attempts: 3,
32            delay_between_attempts: Duration::from_secs(5),
33        }
34    }
35}
36
37/// RetryDownloader is a wrapper around FileDownloader that retries downloading a file if it fails.
38pub struct RetryDownloader {
39    /// File downloader to use.
40    file_downloader: Arc<dyn FileDownloader>,
41    /// Number of attempts to download a file.
42    retry_policy: FileDownloadRetryPolicy,
43}
44
45impl RetryDownloader {
46    /// Create a new RetryDownloader.
47    pub fn new(
48        file_downloader: Arc<dyn FileDownloader>,
49        retry_policy: FileDownloadRetryPolicy,
50    ) -> Self {
51        Self {
52            file_downloader,
53            retry_policy,
54        }
55    }
56}
57
58#[async_trait]
59impl FileDownloader for RetryDownloader {
60    async fn download_unpack(
61        &self,
62        location: &FileDownloaderUri,
63        file_size: u64,
64        target_dir: &Path,
65        compression_algorithm: Option<CompressionAlgorithm>,
66        download_event_type: DownloadEvent,
67    ) -> StdResult<()> {
68        let retry_policy = &self.retry_policy;
69        let mut nb_attempts = 0;
70        loop {
71            nb_attempts += 1;
72            match self
73                .file_downloader
74                .download_unpack(
75                    location,
76                    file_size,
77                    target_dir,
78                    compression_algorithm,
79                    download_event_type.clone(),
80                )
81                .await
82            {
83                Ok(result) => return Ok(result),
84                Err(e) if nb_attempts >= retry_policy.attempts => {
85                    return Err(anyhow::anyhow!(e).context(format!(
86                        "Download of location {location:?} failed after {nb_attempts} attempts",
87                    )));
88                }
89                _ => tokio::time::sleep(retry_policy.delay_between_attempts).await,
90            }
91        }
92    }
93}
94
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Mutex,
        },
        time::Instant,
    };

    use mithril_common::entities::FileUri;

    use crate::file_downloader::MockFileDownloaderBuilder;

    use super::*;

    /// Location downloaded by every test; the mocks are configured to expect exactly this URI.
    const FILE_URI: &str = "http://whatever/00001.tar.gz";

    /// Build the [`FileDownloaderUri`] matching [`FILE_URI`].
    fn file_downloader_uri() -> FileDownloaderUri {
        FileDownloaderUri::FileUri(FileUri(FILE_URI.to_string()))
    }

    /// Build the immutable download event used by most tests.
    fn immutable_download_event() -> DownloadEvent {
        DownloadEvent::Immutable {
            immutable_file_number: 1,
            download_id: "download_id".to_string(),
        }
    }

    #[tokio::test]
    async fn download_return_the_result_of_download_without_retry() {
        let mock_file_downloader = MockFileDownloaderBuilder::default()
            .with_file_uri(FILE_URI)
            .with_compression(None)
            .with_success()
            .build();
        let retry_downloader = RetryDownloader::new(
            Arc::new(mock_file_downloader),
            FileDownloadRetryPolicy::never(),
        );

        retry_downloader
            .download_unpack(
                &file_downloader_uri(),
                0,
                Path::new("."),
                None,
                immutable_download_event(),
            )
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn when_download_fails_do_not_retry_with_never_policy() {
        let mock_file_downloader = MockFileDownloaderBuilder::default()
            .with_file_uri(FILE_URI)
            .with_compression(None)
            .with_failure()
            .build();
        let retry_downloader = RetryDownloader::new(
            Arc::new(mock_file_downloader),
            FileDownloadRetryPolicy::never(),
        );

        retry_downloader
            .download_unpack(
                &file_downloader_uri(),
                0,
                Path::new("."),
                None,
                immutable_download_event(),
            )
            .await
            .expect_err("An error should be returned when download fails");
    }

    #[tokio::test]
    async fn should_retry_if_fail() {
        let mock_file_downloader = MockFileDownloaderBuilder::default()
            .with_file_uri(FILE_URI)
            .with_compression(None)
            .with_failure()
            .with_times(2)
            .next_call()
            .with_file_uri(FILE_URI)
            .with_compression(None)
            .with_times(1)
            .with_success()
            .build();
        let retry_downloader = RetryDownloader::new(
            Arc::new(mock_file_downloader),
            FileDownloadRetryPolicy {
                attempts: 3,
                delay_between_attempts: Duration::from_millis(10),
            },
        );

        retry_downloader
            .download_unpack(
                &file_downloader_uri(),
                0,
                Path::new("."),
                None,
                DownloadEvent::Ancillary {
                    download_id: "download_id".to_string(),
                },
            )
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn should_recall_a_failing_inner_downloader_up_to_the_limit() {
        let mock_file_downloader = MockFileDownloaderBuilder::default()
            .with_file_uri(FILE_URI)
            .with_compression(None)
            .with_failure()
            .with_times(3)
            .build();
        let retry_downloader = RetryDownloader::new(
            Arc::new(mock_file_downloader),
            FileDownloadRetryPolicy {
                attempts: 3,
                delay_between_attempts: Duration::from_millis(10),
            },
        );

        retry_downloader
            .download_unpack(
                &file_downloader_uri(),
                0,
                Path::new("."),
                None,
                immutable_download_event(),
            )
            .await
            .expect_err("An error should be returned when all download attempts fail");
    }

    #[tokio::test]
    async fn should_delay_between_retries() {
        /// Always-failing downloader that asserts the delay elapsed between two consecutive calls
        /// and counts how many times it was called.
        struct FileDownloaderAssertDelay {
            expected_delay_greater_than_or_equal: Duration,
            expected_delay_less_than: Duration,
            last_attempt_start_time: Mutex<Option<Instant>>,
            nb_calls: AtomicUsize,
        }

        #[async_trait]
        impl FileDownloader for FileDownloaderAssertDelay {
            async fn download_unpack(
                &self,
                _location: &FileDownloaderUri,
                _file_size: u64,
                _target_dir: &Path,
                _compression_algorithm: Option<CompressionAlgorithm>,
                _download_event_type: DownloadEvent,
            ) -> StdResult<()> {
                self.nb_calls.fetch_add(1, Ordering::SeqCst);

                let mut last_attempt_start_time = self.last_attempt_start_time.lock().unwrap();
                if let Some(last_start_attempt) = *last_attempt_start_time {
                    let duration = last_start_attempt.elapsed();
                    assert!(
                        duration >= self.expected_delay_greater_than_or_equal,
                        "duration should be greater than or equal to {}ms but was {}ms",
                        self.expected_delay_greater_than_or_equal.as_millis(),
                        duration.as_millis()
                    );
                    assert!(
                        duration < self.expected_delay_less_than,
                        "duration should be less than {}ms but was {}ms",
                        self.expected_delay_less_than.as_millis(),
                        duration.as_millis()
                    );
                }
                *last_attempt_start_time = Some(Instant::now());

                Err(anyhow::anyhow!("Download failed"))
            }
        }

        let delay_ms = 50;
        let nb_attempts = 4;
        let mock_file_downloader = Arc::new(FileDownloaderAssertDelay {
            expected_delay_greater_than_or_equal: Duration::from_millis(delay_ms),
            expected_delay_less_than: Duration::from_millis(2 * delay_ms),
            last_attempt_start_time: Mutex::new(None),
            nb_calls: AtomicUsize::new(0),
        });

        let retry_downloader = RetryDownloader::new(
            mock_file_downloader.clone(),
            FileDownloadRetryPolicy {
                attempts: nb_attempts,
                delay_between_attempts: Duration::from_millis(delay_ms),
            },
        );

        retry_downloader
            .download_unpack(
                &file_downloader_uri(),
                0,
                Path::new("."),
                None,
                immutable_download_event(),
            )
            .await
            .expect_err("An error should be returned when all download attempts fail");

        // Without this check, the delay assertions above would pass vacuously if the inner
        // downloader were called only once.
        assert_eq!(
            nb_attempts,
            mock_file_downloader.nb_calls.load(Ordering::SeqCst),
            "the inner downloader should be called once per attempt"
        );
    }
}