mithril_client/file_downloader/
http.rs1use std::{
2 io::{BufReader, Read, Write},
3 path::Path,
4};
5
6use anyhow::{anyhow, Context};
7use async_trait::async_trait;
8use flate2::read::GzDecoder;
9use flume::{Receiver, Sender};
10use futures::StreamExt;
11use reqwest::{Response, StatusCode, Url};
12use slog::{debug, Logger};
13use tar::Archive;
14use tokio::fs::File;
15use tokio::io::AsyncReadExt;
16
17use mithril_common::{logging::LoggerExtensions, StdResult};
18
19use crate::common::CompressionAlgorithm;
20use crate::feedback::FeedbackSender;
21use crate::utils::StreamReader;
22
23use super::{interface::DownloadEvent, FileDownloader, FileDownloaderUri};
24
25pub struct HttpFileDownloader {
27 http_client: reqwest::Client,
28 feedback_sender: FeedbackSender,
29 logger: Logger,
30}
31
32impl HttpFileDownloader {
33 pub fn new(feedback_sender: FeedbackSender, logger: Logger) -> StdResult<Self> {
35 let http_client = reqwest::ClientBuilder::new()
36 .build()
37 .with_context(|| "Building http client for HttpFileDownloader failed")?;
38
39 Ok(Self {
40 http_client,
41 feedback_sender,
42 logger: logger.new_with_component_name::<Self>(),
43 })
44 }
45
46 async fn get(&self, location: &str) -> StdResult<Response> {
47 debug!(self.logger, "GET Snapshot location='{location}'.");
48 let request_builder = self.http_client.get(location);
49 let response = request_builder.send().await.with_context(|| {
50 format!("Cannot perform a GET for the snapshot (location='{location}')")
51 })?;
52
53 match response.status() {
54 StatusCode::OK => Ok(response),
55 StatusCode::NOT_FOUND => Err(anyhow!("Location='{location} not found")),
56 status_code => Err(anyhow!("Unhandled error {status_code}")),
57 }
58 }
59
60 fn file_scheme_to_local_path(file_url: &str) -> Option<String> {
61 Url::parse(file_url)
62 .ok()
63 .filter(|url| url.scheme() == "file")
64 .and_then(|url| url.to_file_path().ok())
65 .map(|path| path.to_string_lossy().into_owned())
66 }
67
68 async fn download_local_file(
70 &self,
71 local_path: &str,
72 sender: &Sender<Vec<u8>>,
73 download_event_type: DownloadEvent,
74 file_size: u64,
75 ) -> StdResult<()> {
76 let mut downloaded_bytes: u64 = 0;
77 let mut file = File::open(local_path).await?;
78 let size = match file.metadata().await {
79 Ok(metadata) => metadata.len(),
80 Err(_) => file_size,
81 };
82
83 self.feedback_sender
84 .send_event(download_event_type.build_download_started_event(size))
85 .await;
86
87 loop {
88 let mut buffer = vec![0; 16 * 1024 * 1024];
91 let bytes_read = file.read(&mut buffer).await?;
92 if bytes_read == 0 {
93 break;
94 }
95 buffer.truncate(bytes_read);
96 sender.send_async(buffer).await.with_context(|| {
97 format!(
98 "Local file read: could not write {} bytes to stream.",
99 bytes_read
100 )
101 })?;
102 downloaded_bytes += bytes_read as u64;
103 let event = download_event_type.build_download_progress_event(downloaded_bytes, size);
104 self.feedback_sender.send_event(event).await;
105 }
106
107 self.feedback_sender
108 .send_event(download_event_type.build_download_completed_event())
109 .await;
110
111 Ok(())
112 }
113
114 async fn download_remote_file(
116 &self,
117 location: &str,
118 sender: &Sender<Vec<u8>>,
119 download_event_type: DownloadEvent,
120 file_size: u64,
121 ) -> StdResult<()> {
122 let mut downloaded_bytes: u64 = 0;
123 let response = self.get(location).await?;
124 let size = response.content_length().unwrap_or(file_size);
125 let mut remote_stream = response.bytes_stream();
126
127 self.feedback_sender
128 .send_event(download_event_type.build_download_started_event(size))
129 .await;
130
131 while let Some(item) = remote_stream.next().await {
132 let chunk = item.with_context(|| "Download: Could not read from byte stream")?;
133 sender.send_async(chunk.to_vec()).await.with_context(|| {
134 format!("Download: could not write {} bytes to stream.", chunk.len())
135 })?;
136 downloaded_bytes += chunk.len() as u64;
137 let event = download_event_type.build_download_progress_event(downloaded_bytes, size);
138 self.feedback_sender.send_event(event).await;
139 }
140
141 self.feedback_sender
142 .send_event(download_event_type.build_download_completed_event())
143 .await;
144
145 Ok(())
146 }
147
148 fn unpack_file(
149 stream: Receiver<Vec<u8>>,
150 compression_algorithm: Option<CompressionAlgorithm>,
151 unpack_dir: &Path,
152 download_id: String,
153 ) -> StdResult<()> {
154 let input = StreamReader::new(stream);
155 match compression_algorithm {
156 Some(CompressionAlgorithm::Gzip) => {
157 let gzip_decoder = GzDecoder::new(input);
158 let mut file_archive = Archive::new(gzip_decoder);
159 file_archive.unpack(unpack_dir).with_context(|| {
160 format!(
161 "Could not unpack with 'Gzip' from streamed data to directory '{}'",
162 unpack_dir.display()
163 )
164 })?;
165 }
166 Some(CompressionAlgorithm::Zstandard) => {
167 let zstandard_decoder = zstd::Decoder::new(input)
168 .with_context(|| "Unpack failed: Create Zstandard decoder error")?;
169 let mut file_archive = Archive::new(zstandard_decoder);
170 file_archive.unpack(unpack_dir).with_context(|| {
171 format!(
172 "Could not unpack with 'Zstd' from streamed data to directory '{}'",
173 unpack_dir.display()
174 )
175 })?;
176 }
177 None => {
178 let file_path = unpack_dir.join(download_id);
179 if file_path.exists() {
180 std::fs::remove_file(file_path.clone())?;
181 }
182 let mut file = std::fs::File::create(file_path)?;
183 let input_buffered = BufReader::new(input);
184 for byte in input_buffered.bytes() {
185 file.write_all(&[byte?])?;
186 }
187 file.flush()?;
188 }
189 };
190
191 Ok(())
192 }
193}
194
195#[async_trait]
196impl FileDownloader for HttpFileDownloader {
197 async fn download_unpack(
198 &self,
199 location: &FileDownloaderUri,
200 file_size: u64,
201 target_dir: &Path,
202 compression_algorithm: Option<CompressionAlgorithm>,
203 download_event_type: DownloadEvent,
204 ) -> StdResult<()> {
205 if !target_dir.is_dir() {
206 Err(
207 anyhow!("target path is not a directory or does not exist: `{target_dir:?}`")
208 .context("Download-Unpack: prerequisite error"),
209 )?;
210 }
211
212 let (sender, receiver) = flume::bounded(32);
213 let dest_dir = target_dir.to_path_buf();
214 let download_id = download_event_type.download_id().to_owned();
215 let unpack_thread = tokio::task::spawn_blocking(move || -> StdResult<()> {
216 Self::unpack_file(receiver, compression_algorithm, &dest_dir, download_id)
217 });
218 if let Some(local_path) = Self::file_scheme_to_local_path(location.as_str()) {
219 self.download_local_file(&local_path, &sender, download_event_type, file_size)
220 .await?;
221 } else {
222 self.download_remote_file(location.as_str(), &sender, download_event_type, file_size)
223 .await?;
224 }
225 drop(sender);
226 unpack_thread
227 .await
228 .with_context(|| {
229 format!(
230 "Unpack: panic while unpacking to dir '{}'",
231 target_dir.display()
232 )
233 })?
234 .with_context(|| {
235 format!("Unpack: could not unpack to dir '{}'", target_dir.display())
236 })?;
237
238 Ok(())
239 }
240}
241
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use httpmock::MockServer;

    use mithril_common::{entities::FileUri, test_utils::TempDir};

    use crate::{
        feedback::{MithrilEvent, MithrilEventCardanoDatabase, StackFeedbackReceiver},
        test_utils::TestLogger,
    };

    use super::*;

    /// Build a `file://` URI pointing to `path` (Unix flavor).
    #[cfg(not(target_family = "windows"))]
    fn local_file_uri(path: &Path) -> FileDownloaderUri {
        FileDownloaderUri::FileUri(FileUri(format!(
            "file://{}",
            path.canonicalize().unwrap().to_string_lossy()
        )))
    }

    /// Build a `file://` URI pointing to `path` (Windows flavor).
    ///
    /// The canonicalized path is a verbatim `\\?\C:\...` path: backslashes are turned into
    /// slashes and the `?/` verbatim marker is stripped to obtain `file:///C:/...`.
    #[cfg(target_family = "windows")]
    fn local_file_uri(path: &Path) -> FileDownloaderUri {
        FileDownloaderUri::FileUri(FileUri(format!(
            "file:/{}",
            path.canonicalize()
                .unwrap()
                .to_string_lossy()
                .replace("\\", "/")
                .replace("?/", ""),
        )))
    }

    #[tokio::test]
    async fn test_download_http_file_send_feedback() {
        let target_dir = TempDir::create(
            "client-http-downloader",
            "test_download_http_file_send_feedback",
        );
        let content = "Hello, world!";
        let size = content.len() as u64;
        let server = MockServer::start();
        server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/snapshot.tar");
            then.status(200)
                .body(content)
                .header(reqwest::header::CONTENT_LENGTH.as_str(), size.to_string());
        });
        let feedback_receiver = Arc::new(StackFeedbackReceiver::new());
        let http_file_downloader = HttpFileDownloader::new(
            FeedbackSender::new(&[feedback_receiver.clone()]),
            TestLogger::stdout(),
        )
        .unwrap();
        let download_id = "id".to_string();

        http_file_downloader
            .download_unpack(
                &FileDownloaderUri::FileUri(FileUri(server.url("/snapshot.tar"))),
                0,
                &target_dir,
                None,
                DownloadEvent::Digest {
                    download_id: download_id.clone(),
                },
            )
            .await
            .unwrap();

        let expected_events = vec![
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadStarted {
                download_id: download_id.clone(),
                size,
            }),
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadProgress {
                download_id: download_id.clone(),
                downloaded_bytes: size,
                size,
            }),
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadCompleted {
                download_id: download_id.clone(),
            }),
        ];
        assert_eq!(expected_events, feedback_receiver.stacked_events());
        // Without compression, the payload must be written verbatim as `<target_dir>/<id>`.
        assert_eq!(
            content,
            std::fs::read_to_string(target_dir.join(&download_id)).unwrap()
        );
    }

    #[tokio::test]
    async fn test_download_local_file_send_feedback() {
        let target_dir = TempDir::create(
            "client-http-downloader",
            "test_download_local_file_send_feedback",
        );
        let content = "Hello, world!";
        let size = content.len() as u64;

        let source_file_path = target_dir.join("snapshot.txt");
        let mut file = std::fs::File::create(&source_file_path).unwrap();
        file.write_all(content.as_bytes()).unwrap();

        let feedback_receiver = Arc::new(StackFeedbackReceiver::new());
        let http_file_downloader = HttpFileDownloader::new(
            FeedbackSender::new(&[feedback_receiver.clone()]),
            TestLogger::stdout(),
        )
        .unwrap();
        let download_id = "id".to_string();

        http_file_downloader
            .download_unpack(
                &local_file_uri(&source_file_path),
                0,
                &target_dir,
                None,
                DownloadEvent::Digest {
                    download_id: download_id.clone(),
                },
            )
            .await
            .unwrap();

        let expected_events = vec![
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadStarted {
                download_id: download_id.clone(),
                size,
            }),
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadProgress {
                download_id: download_id.clone(),
                downloaded_bytes: size,
                size,
            }),
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadCompleted {
                download_id: download_id.clone(),
            }),
        ];
        assert_eq!(expected_events, feedback_receiver.stacked_events());
        // Without compression, the payload must be written verbatim as `<target_dir>/<id>`.
        assert_eq!(
            content,
            std::fs::read_to_string(target_dir.join(&download_id)).unwrap()
        );
    }

    #[tokio::test]
    async fn test_download_http_file_not_found_returns_error() {
        let target_dir = TempDir::create(
            "client-http-downloader",
            "test_download_http_file_not_found_returns_error",
        );
        let server = MockServer::start();
        server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/missing.tar");
            then.status(404);
        });
        let feedback_receiver = Arc::new(StackFeedbackReceiver::new());
        let http_file_downloader = HttpFileDownloader::new(
            FeedbackSender::new(&[feedback_receiver.clone()]),
            TestLogger::stdout(),
        )
        .unwrap();

        let error = http_file_downloader
            .download_unpack(
                &FileDownloaderUri::FileUri(FileUri(server.url("/missing.tar"))),
                0,
                &target_dir,
                None,
                DownloadEvent::Digest {
                    download_id: "id".to_string(),
                },
            )
            .await
            .expect_err("download_unpack should fail when the server answers 404");

        assert!(
            format!("{error:?}").contains("not found"),
            "unexpected error: {error:?}"
        );
        // The request failed before any byte was received: no download event is emitted.
        assert!(feedback_receiver.stacked_events().is_empty());
    }
}