mithril_client/file_downloader/
http.rs1use std::{
2 io::{BufReader, Read, Write},
3 path::Path,
4};
5
6use anyhow::{Context, anyhow};
7use async_trait::async_trait;
8use flate2::read::GzDecoder;
9use flume::{Receiver, Sender};
10use futures::StreamExt;
11use reqwest::{Response, StatusCode, Url};
12use slog::{Logger, debug};
13use tar::Archive;
14use tokio::fs::File;
15use tokio::io::AsyncReadExt;
16
17use mithril_common::{StdResult, logging::LoggerExtensions};
18
19use crate::common::CompressionAlgorithm;
20use crate::feedback::FeedbackSender;
21use crate::utils::StreamReader;
22
23use super::{FileDownloader, FileDownloaderUri, interface::DownloadEvent};
24
25pub struct HttpFileDownloader {
27 http_client: reqwest::Client,
28 feedback_sender: FeedbackSender,
29 logger: Logger,
30}
31
32impl HttpFileDownloader {
33 pub fn new(feedback_sender: FeedbackSender, logger: Logger) -> StdResult<Self> {
35 let http_client = reqwest::ClientBuilder::new()
36 .build()
37 .with_context(|| "Building http client for HttpFileDownloader failed")?;
38
39 Ok(Self {
40 http_client,
41 feedback_sender,
42 logger: logger.new_with_component_name::<Self>(),
43 })
44 }
45
46 async fn get(&self, location: &str) -> StdResult<Response> {
47 debug!(self.logger, "GET Snapshot location='{location}'.");
48 let request_builder = self.http_client.get(location);
49 let response = request_builder.send().await.with_context(|| {
50 format!("Cannot perform a GET for the snapshot (location='{location}')")
51 })?;
52
53 match response.status() {
54 StatusCode::OK => Ok(response),
55 StatusCode::NOT_FOUND => Err(anyhow!("Location='{location} not found")),
56 status_code => Err(anyhow!("Unhandled error {status_code}")),
57 }
58 }
59
60 fn file_scheme_to_local_path(file_url: &str) -> Option<String> {
61 Url::parse(file_url)
62 .ok()
63 .filter(|url| url.scheme() == "file")
64 .and_then(|url| url.to_file_path().ok())
65 .map(|path| path.to_string_lossy().into_owned())
66 }
67
68 async fn download_local_file(
70 &self,
71 local_path: &str,
72 sender: &Sender<Vec<u8>>,
73 download_event_type: DownloadEvent,
74 file_size: u64,
75 ) -> StdResult<()> {
76 let mut downloaded_bytes: u64 = 0;
77 let mut file = File::open(local_path).await?;
78 let size = match file.metadata().await {
79 Ok(metadata) => metadata.len(),
80 Err(_) => file_size,
81 };
82
83 self.feedback_sender
84 .send_event(download_event_type.build_download_started_event(size))
85 .await;
86
87 loop {
88 let mut buffer = vec![0; 16 * 1024 * 1024];
91 let bytes_read = file.read(&mut buffer).await?;
92 if bytes_read == 0 {
93 break;
94 }
95 buffer.truncate(bytes_read);
96 sender.send_async(buffer).await.with_context(|| {
97 format!("Local file read: could not write {bytes_read} bytes to stream.")
98 })?;
99 downloaded_bytes += bytes_read as u64;
100 let event = download_event_type.build_download_progress_event(downloaded_bytes, size);
101 self.feedback_sender.send_event(event).await;
102 }
103
104 self.feedback_sender
105 .send_event(download_event_type.build_download_completed_event())
106 .await;
107
108 Ok(())
109 }
110
111 async fn download_remote_file(
113 &self,
114 location: &str,
115 sender: &Sender<Vec<u8>>,
116 download_event_type: DownloadEvent,
117 file_size: u64,
118 ) -> StdResult<()> {
119 let mut downloaded_bytes: u64 = 0;
120 let response = self.get(location).await?;
121 let size = response.content_length().unwrap_or(file_size);
122 let mut remote_stream = response.bytes_stream();
123
124 self.feedback_sender
125 .send_event(download_event_type.build_download_started_event(size))
126 .await;
127
128 while let Some(item) = remote_stream.next().await {
129 let chunk = item.with_context(|| "Download: Could not read from byte stream")?;
130 sender.send_async(chunk.to_vec()).await.with_context(|| {
131 format!("Download: could not write {} bytes to stream.", chunk.len())
132 })?;
133 downloaded_bytes += chunk.len() as u64;
134 let event = download_event_type.build_download_progress_event(downloaded_bytes, size);
135 self.feedback_sender.send_event(event).await;
136 }
137
138 self.feedback_sender
139 .send_event(download_event_type.build_download_completed_event())
140 .await;
141
142 Ok(())
143 }
144
145 fn unpack_file(
146 stream: Receiver<Vec<u8>>,
147 compression_algorithm: Option<CompressionAlgorithm>,
148 unpack_dir: &Path,
149 download_id: String,
150 ) -> StdResult<()> {
151 let input = StreamReader::new(stream);
152 match compression_algorithm {
153 Some(CompressionAlgorithm::Gzip) => {
154 let gzip_decoder = GzDecoder::new(input);
155 let mut file_archive = Archive::new(gzip_decoder);
156 file_archive.unpack(unpack_dir).with_context(|| {
157 format!(
158 "Could not unpack with 'Gzip' from streamed data to directory '{}'",
159 unpack_dir.display()
160 )
161 })?;
162 }
163 Some(CompressionAlgorithm::Zstandard) => {
164 let zstandard_decoder = zstd::Decoder::new(input)
165 .with_context(|| "Unpack failed: Create Zstandard decoder error")?;
166 let mut file_archive = Archive::new(zstandard_decoder);
167 file_archive.unpack(unpack_dir).with_context(|| {
168 format!(
169 "Could not unpack with 'Zstd' from streamed data to directory '{}'",
170 unpack_dir.display()
171 )
172 })?;
173 }
174 None => {
175 let file_path = unpack_dir.join(download_id);
176 if file_path.exists() {
177 std::fs::remove_file(file_path.clone())?;
178 }
179 let mut file = std::fs::File::create(file_path)?;
180 let input_buffered = BufReader::new(input);
181 for byte in input_buffered.bytes() {
182 file.write_all(&[byte?])?;
183 }
184 file.flush()?;
185 }
186 };
187
188 Ok(())
189 }
190}
191
192#[async_trait]
193impl FileDownloader for HttpFileDownloader {
194 async fn download_unpack(
195 &self,
196 location: &FileDownloaderUri,
197 file_size: u64,
198 target_dir: &Path,
199 compression_algorithm: Option<CompressionAlgorithm>,
200 download_event_type: DownloadEvent,
201 ) -> StdResult<()> {
202 if !target_dir.is_dir() {
203 Err(
204 anyhow!("target path is not a directory or does not exist: `{target_dir:?}`")
205 .context("Download-Unpack: prerequisite error"),
206 )?;
207 }
208
209 let (sender, receiver) = flume::bounded(32);
210 let dest_dir = target_dir.to_path_buf();
211 let download_id = download_event_type.download_id().to_owned();
212 let unpack_thread = tokio::task::spawn_blocking(move || -> StdResult<()> {
213 Self::unpack_file(receiver, compression_algorithm, &dest_dir, download_id)
214 });
215 if let Some(local_path) = Self::file_scheme_to_local_path(location.as_str()) {
216 self.download_local_file(&local_path, &sender, download_event_type, file_size)
217 .await?;
218 } else {
219 self.download_remote_file(location.as_str(), &sender, download_event_type, file_size)
220 .await?;
221 }
222 drop(sender);
223 unpack_thread
224 .await
225 .with_context(|| {
226 format!(
227 "Unpack: panic while unpacking to dir '{}'",
228 target_dir.display()
229 )
230 })?
231 .with_context(|| {
232 format!("Unpack: could not unpack to dir '{}'", target_dir.display())
233 })?;
234
235 Ok(())
236 }
237}
238
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use httpmock::MockServer;

    use mithril_common::{entities::FileUri, test::TempDir};

    use crate::{
        feedback::{
            FeedbackReceiver, MithrilEvent, MithrilEventCardanoDatabase, StackFeedbackReceiver,
        },
        test_utils::TestLogger,
    };

    use super::*;

    /// Builds a `file://` URI pointing to `path`.
    #[cfg(not(target_family = "windows"))]
    fn local_file_uri(path: &Path) -> FileDownloaderUri {
        FileDownloaderUri::FileUri(FileUri(format!(
            "file://{}",
            path.canonicalize().unwrap().to_string_lossy()
        )))
    }

    /// Builds a `file:/` URI pointing to `path`.
    #[cfg(target_family = "windows")]
    fn local_file_uri(path: &Path) -> FileDownloaderUri {
        FileDownloaderUri::FileUri(FileUri(format!(
            // Backslashes are turned into forward slashes, and any "?/"
            // sequence left in the canonicalized path is dropped.
            "file:/{}",
            path.canonicalize()
                .unwrap()
                .to_string_lossy()
                .replace("\\", "/")
                .replace("?/", ""),
        )))
    }

    /// Creates a downloader wired to a receiver that stacks every event it gets.
    fn downloader_with_stacked_feedback() -> (HttpFileDownloader, Arc<StackFeedbackReceiver>) {
        let stack_receiver = Arc::new(StackFeedbackReceiver::new());
        let receiver = stack_receiver.clone() as Arc<dyn FeedbackReceiver>;
        let downloader =
            HttpFileDownloader::new(FeedbackSender::new(&[receiver]), TestLogger::stdout())
                .unwrap();

        (downloader, stack_receiver)
    }

    /// Events expected for a digest downloaded in a single chunk of `size` bytes.
    fn expected_single_chunk_digest_events(download_id: &str, size: u64) -> Vec<MithrilEvent> {
        vec![
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadStarted {
                download_id: download_id.to_string(),
                size,
            }),
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadProgress {
                download_id: download_id.to_string(),
                downloaded_bytes: size,
                size,
            }),
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadCompleted {
                download_id: download_id.to_string(),
            }),
        ]
    }

    #[tokio::test]
    async fn test_download_http_file_send_feedback() {
        let target_dir = TempDir::create(
            "client-http-downloader",
            "test_download_http_file_send_feedback",
        );
        let content = "Hello, world!";
        let size = content.len() as u64;
        let server = MockServer::start();
        server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/snapshot.tar");
            then.status(200)
                .body(content)
                .header(reqwest::header::CONTENT_LENGTH.as_str(), size.to_string());
        });
        let (http_file_downloader, feedback_receiver) = downloader_with_stacked_feedback();
        let download_id = "id".to_string();
        let location = FileDownloaderUri::FileUri(FileUri(server.url("/snapshot.tar")));

        http_file_downloader
            .download_unpack(
                &location,
                0,
                &target_dir,
                None,
                DownloadEvent::Digest {
                    download_id: download_id.clone(),
                },
            )
            .await
            .unwrap();

        assert_eq!(
            expected_single_chunk_digest_events(&download_id, size),
            feedback_receiver.stacked_events()
        );
    }

    #[tokio::test]
    async fn test_download_local_file_send_feedback() {
        let target_dir = TempDir::create(
            "client-http-downloader",
            "test_download_local_file_send_feedback",
        );
        let content = "Hello, world!";
        let size = content.len() as u64;

        let source_file_path = target_dir.join("snapshot.txt");
        let mut source_file = std::fs::File::create(&source_file_path).unwrap();
        source_file.write_all(content.as_bytes()).unwrap();

        let (http_file_downloader, feedback_receiver) = downloader_with_stacked_feedback();
        let download_id = "id".to_string();
        let location = local_file_uri(&source_file_path);

        http_file_downloader
            .download_unpack(
                &location,
                0,
                &target_dir,
                None,
                DownloadEvent::Digest {
                    download_id: download_id.clone(),
                },
            )
            .await
            .unwrap();

        assert_eq!(
            expected_single_chunk_digest_events(&download_id, size),
            feedback_receiver.stacked_events()
        );
    }
}