mithril_client/file_downloader/
retry.rs1use std::{path::Path, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use mithril_common::{entities::CompressionAlgorithm, StdResult};
5
6use super::{DownloadEvent, FileDownloader, FileDownloaderUri};
7
/// Policy describing how a failed file download is retried.
///
/// It is consumed by `RetryDownloader`, which keeps calling its inner downloader until one
/// attempt succeeds or the number of attempts reaches `attempts`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct FileDownloadRetryPolicy {
    /// Maximum number of download attempts, the first try included.
    ///
    /// `RetryDownloader` always performs at least one attempt, so `0` behaves like `1`.
    pub attempts: usize,

    /// Time to wait after a failed attempt before starting the next one.
    pub delay_between_attempts: Duration,
}
16
17impl FileDownloadRetryPolicy {
18 pub fn never() -> Self {
20 Self {
21 attempts: 1,
22 delay_between_attempts: Duration::from_secs(0),
23 }
24 }
25}
26
27impl Default for FileDownloadRetryPolicy {
28 fn default() -> Self {
30 Self {
31 attempts: 3,
32 delay_between_attempts: Duration::from_secs(5),
33 }
34 }
35}
36
37pub struct RetryDownloader {
39 file_downloader: Arc<dyn FileDownloader>,
41 retry_policy: FileDownloadRetryPolicy,
43}
44
45impl RetryDownloader {
46 pub fn new(
48 file_downloader: Arc<dyn FileDownloader>,
49 retry_policy: FileDownloadRetryPolicy,
50 ) -> Self {
51 Self {
52 file_downloader,
53 retry_policy,
54 }
55 }
56}
57
58#[async_trait]
59impl FileDownloader for RetryDownloader {
60 async fn download_unpack(
61 &self,
62 location: &FileDownloaderUri,
63 file_size: u64,
64 target_dir: &Path,
65 compression_algorithm: Option<CompressionAlgorithm>,
66 download_event_type: DownloadEvent,
67 ) -> StdResult<()> {
68 let retry_policy = &self.retry_policy;
69 let mut nb_attempts = 0;
70 loop {
71 nb_attempts += 1;
72 match self
73 .file_downloader
74 .download_unpack(
75 location,
76 file_size,
77 target_dir,
78 compression_algorithm,
79 download_event_type.clone(),
80 )
81 .await
82 {
83 Ok(result) => return Ok(result),
84 Err(e) if nb_attempts >= retry_policy.attempts => {
85 return Err(anyhow::anyhow!(e).context(format!(
86 "Download of location {location:?} failed after {nb_attempts} attempts",
87 )));
88 }
89 _ => tokio::time::sleep(retry_policy.delay_between_attempts).await,
90 }
91 }
92 }
93}
94
95#[cfg(test)]
96mod tests {
97
98 use std::{sync::Mutex, time::Instant};
99
100 use mithril_common::entities::FileUri;
101
102 use crate::file_downloader::MockFileDownloaderBuilder;
103
104 use super::*;
105
106 #[tokio::test]
107 async fn download_return_the_result_of_download_without_retry() {
108 let mock_file_downloader = MockFileDownloaderBuilder::default()
109 .with_file_uri("http://whatever/00001.tar.gz")
110 .with_compression(None)
111 .with_success()
112 .build();
113 let retry_downloader = RetryDownloader::new(
114 Arc::new(mock_file_downloader),
115 FileDownloadRetryPolicy::never(),
116 );
117
118 retry_downloader
119 .download_unpack(
120 &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
121 0,
122 Path::new("."),
123 None,
124 DownloadEvent::Immutable {
125 immutable_file_number: 1,
126 download_id: "download_id".to_string(),
127 },
128 )
129 .await
130 .unwrap();
131 }
132
133 #[tokio::test]
134 async fn when_download_fails_do_not_retry_by_default() {
135 let mock_file_downloader = MockFileDownloaderBuilder::default()
136 .with_file_uri("http://whatever/00001.tar.gz")
137 .with_compression(None)
138 .with_failure()
139 .build();
140 let retry_downloader = RetryDownloader::new(
141 Arc::new(mock_file_downloader),
142 FileDownloadRetryPolicy::never(),
143 );
144
145 retry_downloader
146 .download_unpack(
147 &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
148 0,
149 Path::new("."),
150 None,
151 DownloadEvent::Immutable {
152 immutable_file_number: 1,
153 download_id: "download_id".to_string(),
154 },
155 )
156 .await
157 .expect_err("An error should be returned when download fails");
158 }
159
160 #[tokio::test]
161 async fn should_retry_if_fail() {
162 let mock_file_downloader = MockFileDownloaderBuilder::default()
163 .with_file_uri("http://whatever/00001.tar.gz")
164 .with_compression(None)
165 .with_failure()
166 .with_times(2)
167 .next_call()
168 .with_file_uri("http://whatever/00001.tar.gz")
169 .with_compression(None)
170 .with_times(1)
171 .with_success()
172 .build();
173 let retry_downloader = RetryDownloader::new(
174 Arc::new(mock_file_downloader),
175 FileDownloadRetryPolicy {
176 attempts: 3,
177 delay_between_attempts: Duration::from_millis(10),
178 },
179 );
180
181 retry_downloader
182 .download_unpack(
183 &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
184 0,
185 Path::new("."),
186 None,
187 DownloadEvent::Ancillary {
188 download_id: "download_id".to_string(),
189 },
190 )
191 .await
192 .unwrap();
193 }
194
195 #[tokio::test]
196 async fn should_recall_a_failing_inner_downloader_up_to_the_limit() {
197 let mock_file_downloader = MockFileDownloaderBuilder::default()
198 .with_file_uri("http://whatever/00001.tar.gz")
199 .with_compression(None)
200 .with_failure()
201 .with_times(3)
202 .build();
203 let retry_downloader = RetryDownloader::new(
204 Arc::new(mock_file_downloader),
205 FileDownloadRetryPolicy {
206 attempts: 3,
207 delay_between_attempts: Duration::from_millis(10),
208 },
209 );
210
211 retry_downloader
212 .download_unpack(
213 &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
214 0,
215 Path::new("."),
216 None,
217 DownloadEvent::Immutable {
218 immutable_file_number: 1,
219 download_id: "download_id".to_string(),
220 },
221 )
222 .await
223 .expect_err("An error should be returned when all download attempts fail");
224 }
225
226 #[tokio::test]
227 async fn should_delay_between_retries() {
228 struct FileDownloaderAssertDelay {
229 expected_delay_greater_than_or_equal: Duration,
230 expected_delay_less_than: Duration,
231 last_attempt_start_time: Mutex<Option<Instant>>,
232 }
233
234 #[async_trait]
235 impl FileDownloader for FileDownloaderAssertDelay {
236 async fn download_unpack(
237 &self,
238 _location: &FileDownloaderUri,
239 _file_size: u64,
240 _target_dir: &Path,
241 _compression_algorithm: Option<CompressionAlgorithm>,
242 _download_event_type: DownloadEvent,
243 ) -> StdResult<()> {
244 let mut last_attempt_start_time = self.last_attempt_start_time.lock().unwrap();
245 if let Some(last_start_attempt) = *last_attempt_start_time {
246 let duration = last_start_attempt.elapsed();
247 assert!(
248 duration >= self.expected_delay_greater_than_or_equal,
249 "duration should be greater than or equal to {}ms but was {}ms",
250 self.expected_delay_greater_than_or_equal.as_millis(),
251 duration.as_millis()
252 );
253 assert!(
254 duration < self.expected_delay_less_than,
255 "duration should be less than {}ms but was {}ms",
256 self.expected_delay_less_than.as_millis(),
257 duration.as_millis()
258 );
259 }
260 *last_attempt_start_time = Some(Instant::now());
261
262 Err(anyhow::anyhow!("Download failed"))
263 }
264 }
265
266 let delay_ms = 50;
267 let mock_file_downloader = Arc::new(FileDownloaderAssertDelay {
268 expected_delay_greater_than_or_equal: Duration::from_millis(delay_ms),
269 expected_delay_less_than: Duration::from_millis(2 * delay_ms),
270 last_attempt_start_time: Mutex::new(None),
271 });
272
273 let retry_downloader = RetryDownloader::new(
274 mock_file_downloader.clone(),
275 FileDownloadRetryPolicy {
276 attempts: 4,
277 delay_between_attempts: Duration::from_millis(delay_ms),
278 },
279 );
280
281 retry_downloader
282 .download_unpack(
283 &FileDownloaderUri::FileUri(FileUri("http://whatever/00001.tar.gz".to_string())),
284 0,
285 Path::new("."),
286 None,
287 DownloadEvent::Immutable {
288 immutable_file_number: 1,
289 download_id: "download_id".to_string(),
290 },
291 )
292 .await
293 .expect_err("An error should be returned when all download attempts fail");
294 }
295}