mithril_client/file_downloader/
retry.rs1use std::{path::Path, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use mithril_common::{StdResult, entities::CompressionAlgorithm};
5
6use super::{DownloadEvent, FileDownloader, FileDownloaderUri};
7
/// Policy describing how many times a file download is attempted and how long to wait
/// between two consecutive attempts.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct FileDownloadRetryPolicy {
    /// Total number of download attempts, the first try included.
    pub attempts: usize,
    /// Pause observed after a failed attempt before the next attempt starts.
    pub delay_between_attempts: Duration,
}
16
17impl FileDownloadRetryPolicy {
18 pub fn never() -> Self {
20 Self {
21 attempts: 1,
22 delay_between_attempts: Duration::from_secs(0),
23 }
24 }
25}
26
27impl Default for FileDownloadRetryPolicy {
28 fn default() -> Self {
30 Self {
31 attempts: 3,
32 delay_between_attempts: Duration::from_secs(5),
33 }
34 }
35}
36
37pub struct RetryDownloader {
39 file_downloader: Arc<dyn FileDownloader>,
41 retry_policy: FileDownloadRetryPolicy,
43}
44
45impl RetryDownloader {
46 pub fn new(
48 file_downloader: Arc<dyn FileDownloader>,
49 retry_policy: FileDownloadRetryPolicy,
50 ) -> Self {
51 Self {
52 file_downloader,
53 retry_policy,
54 }
55 }
56}
57
58#[async_trait]
59impl FileDownloader for RetryDownloader {
60 async fn download_unpack(
61 &self,
62 location: &FileDownloaderUri,
63 file_size: u64,
64 target_dir: &Path,
65 compression_algorithm: Option<CompressionAlgorithm>,
66 download_event_type: DownloadEvent,
67 ) -> StdResult<()> {
68 let retry_policy = &self.retry_policy;
69 let mut nb_attempts = 0;
70 loop {
71 nb_attempts += 1;
72 match self
73 .file_downloader
74 .download_unpack(
75 location,
76 file_size,
77 target_dir,
78 compression_algorithm,
79 download_event_type.clone(),
80 )
81 .await
82 {
83 Ok(result) => return Ok(result),
84 Err(e) if nb_attempts >= retry_policy.attempts => {
85 return Err(e.context(format!(
86 "Download of location {location:?} failed after {nb_attempts} attempts",
87 )));
88 }
89 _ => tokio::time::sleep(retry_policy.delay_between_attempts).await,
90 }
91 }
92 }
93}
94
#[cfg(test)]
mod tests {
    use std::{sync::Mutex, time::Instant};

    use mithril_common::entities::FileUri;

    use crate::file_downloader::MockFileDownloaderBuilder;

    use super::*;

    /// URI expected by the mocks and requested by every test.
    const ARCHIVE_URI: &str = "http://whatever/00001.tar.gz";

    /// Download event for immutable file number 1.
    fn immutable_event() -> DownloadEvent {
        DownloadEvent::Immutable {
            immutable_file_number: 1,
            download_id: "download_id".to_string(),
        }
    }

    /// Retry policy with a short delay so that retrying tests stay fast.
    fn quick_policy(attempts: usize) -> FileDownloadRetryPolicy {
        FileDownloadRetryPolicy {
            attempts,
            delay_between_attempts: Duration::from_millis(10),
        }
    }

    /// Run a download of `ARCHIVE_URI` through `downloader`, reporting it as `event`.
    async fn run_download(downloader: &RetryDownloader, event: DownloadEvent) -> StdResult<()> {
        let location = FileDownloaderUri::FileUri(FileUri(ARCHIVE_URI.to_string()));
        downloader
            .download_unpack(&location, 0, Path::new("."), None, event)
            .await
    }

    #[tokio::test]
    async fn download_return_the_result_of_download_without_retry() {
        let inner = MockFileDownloaderBuilder::default()
            .with_file_uri(ARCHIVE_URI)
            .with_compression(None)
            .with_success()
            .build();
        let downloader = RetryDownloader::new(Arc::new(inner), FileDownloadRetryPolicy::never());

        run_download(&downloader, immutable_event()).await.unwrap();
    }

    #[tokio::test]
    async fn when_download_fails_do_not_retry_by_default() {
        let inner = MockFileDownloaderBuilder::default()
            .with_file_uri(ARCHIVE_URI)
            .with_compression(None)
            .with_failure()
            .build();
        let downloader = RetryDownloader::new(Arc::new(inner), FileDownloadRetryPolicy::never());

        run_download(&downloader, immutable_event())
            .await
            .expect_err("An error should be returned when download fails");
    }

    #[tokio::test]
    async fn should_retry_if_fail() {
        // Two failing calls are expected first, then a single successful one.
        let inner = MockFileDownloaderBuilder::default()
            .with_file_uri(ARCHIVE_URI)
            .with_compression(None)
            .with_failure()
            .with_times(2)
            .next_call()
            .with_file_uri(ARCHIVE_URI)
            .with_compression(None)
            .with_times(1)
            .with_success()
            .build();
        let downloader = RetryDownloader::new(Arc::new(inner), quick_policy(3));
        let ancillary_event = DownloadEvent::Ancillary {
            download_id: "download_id".to_string(),
        };

        run_download(&downloader, ancillary_event).await.unwrap();
    }

    #[tokio::test]
    async fn should_recall_a_failing_inner_downloader_up_to_the_limit() {
        let inner = MockFileDownloaderBuilder::default()
            .with_file_uri(ARCHIVE_URI)
            .with_compression(None)
            .with_failure()
            .with_times(3)
            .build();
        let downloader = RetryDownloader::new(Arc::new(inner), quick_policy(3));

        run_download(&downloader, immutable_event())
            .await
            .expect_err("An error should be returned when all download attempts fail");
    }

    #[tokio::test]
    async fn should_delay_between_retries() {
        /// Always-failing downloader asserting that the time elapsed since the start of the
        /// previous call lies within `[min_delay, max_delay)`.
        struct FileDownloaderAssertDelay {
            min_delay: Duration,
            max_delay: Duration,
            last_attempt_start_time: Mutex<Option<Instant>>,
        }

        #[async_trait]
        impl FileDownloader for FileDownloaderAssertDelay {
            async fn download_unpack(
                &self,
                _location: &FileDownloaderUri,
                _file_size: u64,
                _target_dir: &Path,
                _compression_algorithm: Option<CompressionAlgorithm>,
                _download_event_type: DownloadEvent,
            ) -> StdResult<()> {
                let mut previous_start = self.last_attempt_start_time.lock().unwrap();
                // Nothing to compare against on the very first call.
                if let Some(duration) = previous_start.as_ref().map(Instant::elapsed) {
                    assert!(
                        duration >= self.min_delay,
                        "duration should be greater than or equal to {}ms but was {}ms",
                        self.min_delay.as_millis(),
                        duration.as_millis()
                    );
                    assert!(
                        duration < self.max_delay,
                        "duration should be less than {}ms but was {}ms",
                        self.max_delay.as_millis(),
                        duration.as_millis()
                    );
                }
                *previous_start = Some(Instant::now());

                Err(anyhow::anyhow!("Download failed"))
            }
        }

        let delay_ms = 50;
        let asserting_downloader = Arc::new(FileDownloaderAssertDelay {
            min_delay: Duration::from_millis(delay_ms),
            max_delay: Duration::from_millis(2 * delay_ms),
            last_attempt_start_time: Mutex::new(None),
        });
        let policy = FileDownloadRetryPolicy {
            attempts: 4,
            delay_between_attempts: Duration::from_millis(delay_ms),
        };
        let downloader = RetryDownloader::new(asserting_downloader, policy);

        run_download(&downloader, immutable_event())
            .await
            .expect_err("An error should be returned when all download attempts fail");
    }
}