mithril_client/file_downloader/
retry.rs1use std::{path::Path, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use mithril_common::{entities::CompressionAlgorithm, StdResult};
5
6use super::{DownloadEvent, FileDownloader, FileDownloaderUri};
7
/// Policy describing how many times a file download is attempted and how long to wait
/// between two consecutive attempts.
///
/// Both fields are plain value types, so the policy derives `Eq` in addition to
/// `PartialEq`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct FileDownloadRetryPolicy {
    /// Maximum number of download attempts, the first one included.
    pub attempts: usize,
    /// Time to wait after a failed attempt before starting the next one.
    pub delay_between_attempts: Duration,
}
16
17impl FileDownloadRetryPolicy {
18 pub fn never() -> Self {
20 Self {
21 attempts: 1,
22 delay_between_attempts: Duration::from_secs(0),
23 }
24 }
25}
26
27impl Default for FileDownloadRetryPolicy {
28 fn default() -> Self {
30 Self {
31 attempts: 3,
32 delay_between_attempts: Duration::from_secs(5),
33 }
34 }
35}
36
37pub struct RetryDownloader {
39 file_downloader: Arc<dyn FileDownloader>,
41 retry_policy: FileDownloadRetryPolicy,
43}
44
45impl RetryDownloader {
46 pub fn new(
48 file_downloader: Arc<dyn FileDownloader>,
49 retry_policy: FileDownloadRetryPolicy,
50 ) -> Self {
51 Self {
52 file_downloader,
53 retry_policy,
54 }
55 }
56}
57
58#[async_trait]
59impl FileDownloader for RetryDownloader {
60 async fn download_unpack(
61 &self,
62 location: &FileDownloaderUri,
63 file_size: u64,
64 target_dir: &Path,
65 compression_algorithm: Option<CompressionAlgorithm>,
66 download_event_type: DownloadEvent,
67 ) -> StdResult<()> {
68 let retry_policy = &self.retry_policy;
69 let mut nb_attempts = 0;
70 loop {
71 nb_attempts += 1;
72 match self
73 .file_downloader
74 .download_unpack(
75 location,
76 file_size,
77 target_dir,
78 compression_algorithm,
79 download_event_type.clone(),
80 )
81 .await
82 {
83 Ok(result) => return Ok(result),
84 Err(_) if nb_attempts >= retry_policy.attempts => {
85 return Err(anyhow::anyhow!(
86 "Download of location {:?} failed after {} attempts",
87 location,
88 nb_attempts
89 ));
90 }
91 _ => tokio::time::sleep(retry_policy.delay_between_attempts).await,
92 }
93 }
94 }
95}
96
#[cfg(test)]
mod tests {

    use std::{sync::Mutex, time::Instant};

    use mithril_common::entities::FileUri;

    use crate::file_downloader::MockFileDownloaderBuilder;

    use super::*;

    /// URI configured on the mocks and requested by every test.
    const FILE_URI: &str = "http://whatever/00001.tar.gz";

    /// Event used by the tests that download an immutable file.
    fn immutable_event() -> DownloadEvent {
        DownloadEvent::Immutable {
            immutable_file_number: 1,
            download_id: "download_id".to_string(),
        }
    }

    /// Runs a download of [`FILE_URI`] into the current directory, without compression,
    /// reporting the given event.
    async fn download(downloader: &RetryDownloader, event: DownloadEvent) -> StdResult<()> {
        let location = FileDownloaderUri::FileUri(FileUri(FILE_URI.to_string()));

        downloader
            .download_unpack(&location, 0, Path::new("."), None, event)
            .await
    }

    #[tokio::test]
    async fn download_return_the_result_of_download_without_retry() {
        let inner_downloader = MockFileDownloaderBuilder::default()
            .with_file_uri(FILE_URI)
            .with_compression(None)
            .with_success()
            .build();
        let policy = FileDownloadRetryPolicy::never();
        let retry_downloader = RetryDownloader::new(Arc::new(inner_downloader), policy);

        download(&retry_downloader, immutable_event())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn when_download_fails_do_not_retry_by_default() {
        let inner_downloader = MockFileDownloaderBuilder::default()
            .with_file_uri(FILE_URI)
            .with_compression(None)
            .with_failure()
            .build();
        let policy = FileDownloadRetryPolicy::never();
        let retry_downloader = RetryDownloader::new(Arc::new(inner_downloader), policy);

        download(&retry_downloader, immutable_event())
            .await
            .expect_err("An error should be returned when download fails");
    }

    #[tokio::test]
    async fn should_retry_if_fail() {
        // Two failing calls followed by a successful one.
        let inner_downloader = MockFileDownloaderBuilder::default()
            .with_file_uri(FILE_URI)
            .with_compression(None)
            .with_failure()
            .with_times(2)
            .next_call()
            .with_file_uri(FILE_URI)
            .with_compression(None)
            .with_times(1)
            .with_success()
            .build();
        let policy = FileDownloadRetryPolicy {
            attempts: 3,
            delay_between_attempts: Duration::from_millis(10),
        };
        let retry_downloader = RetryDownloader::new(Arc::new(inner_downloader), policy);

        let ancillary_event = DownloadEvent::Ancillary {
            download_id: "download_id".to_string(),
        };
        download(&retry_downloader, ancillary_event)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn should_recall_a_failing_inner_downloader_up_to_the_limit() {
        let inner_downloader = MockFileDownloaderBuilder::default()
            .with_file_uri(FILE_URI)
            .with_compression(None)
            .with_failure()
            .with_times(3)
            .build();
        let policy = FileDownloadRetryPolicy {
            attempts: 3,
            delay_between_attempts: Duration::from_millis(10),
        };
        let retry_downloader = RetryDownloader::new(Arc::new(inner_downloader), policy);

        download(&retry_downloader, immutable_event())
            .await
            .expect_err("An error should be returned when all download attempts fail");
    }

    #[tokio::test]
    async fn should_delay_between_retries() {
        /// Always-failing downloader asserting that the time elapsed since its previous
        /// call lies within the expected bounds.
        struct FileDownloaderAssertDelay {
            expected_delay_greater_than_or_equal: Duration,
            expected_delay_less_than: Duration,
            last_attempt_start_time: Mutex<Option<Instant>>,
        }

        #[async_trait]
        impl FileDownloader for FileDownloaderAssertDelay {
            async fn download_unpack(
                &self,
                _location: &FileDownloaderUri,
                _file_size: u64,
                _target_dir: &Path,
                _compression_algorithm: Option<CompressionAlgorithm>,
                _download_event_type: DownloadEvent,
            ) -> StdResult<()> {
                let mut previous_start = self.last_attempt_start_time.lock().unwrap();
                // The very first call has nothing to compare against.
                if let Some(started_at) = *previous_start {
                    let elapsed = started_at.elapsed();
                    let min_delay = self.expected_delay_greater_than_or_equal;
                    let max_delay = self.expected_delay_less_than;
                    assert!(
                        elapsed >= min_delay,
                        "duration should be greater than or equal to {}ms but was {}ms",
                        min_delay.as_millis(),
                        elapsed.as_millis()
                    );
                    assert!(
                        elapsed < max_delay,
                        "duration should be less than {}ms but was {}ms",
                        max_delay.as_millis(),
                        elapsed.as_millis()
                    );
                }
                *previous_start = Some(Instant::now());

                Err(anyhow::anyhow!("Download failed"))
            }
        }

        let delay_ms = 50;
        let delay = Duration::from_millis(delay_ms);
        let mock_file_downloader = Arc::new(FileDownloaderAssertDelay {
            expected_delay_greater_than_or_equal: delay,
            expected_delay_less_than: Duration::from_millis(2 * delay_ms),
            last_attempt_start_time: Mutex::new(None),
        });

        let policy = FileDownloadRetryPolicy {
            attempts: 4,
            delay_between_attempts: delay,
        };
        let retry_downloader = RetryDownloader::new(mock_file_downloader.clone(), policy);

        download(&retry_downloader, immutable_event())
            .await
            .expect_err("An error should be returned when all download attempts fail");
    }
}