// mithril_client/file_downloader/http.rs

1use std::{
2    io::{BufReader, Read, Write},
3    path::Path,
4};
5
6use anyhow::{anyhow, Context};
7use async_trait::async_trait;
8use flate2::read::GzDecoder;
9use flume::{Receiver, Sender};
10use futures::StreamExt;
11use reqwest::{Response, StatusCode, Url};
12use slog::{debug, Logger};
13use tar::Archive;
14use tokio::fs::File;
15use tokio::io::AsyncReadExt;
16
17use mithril_common::{logging::LoggerExtensions, StdResult};
18
19use crate::common::CompressionAlgorithm;
20use crate::feedback::FeedbackSender;
21use crate::utils::StreamReader;
22
23use super::{interface::DownloadEvent, FileDownloader, FileDownloaderUri};
24
25/// A file downloader that only handles download through HTTP.
26pub struct HttpFileDownloader {
27    http_client: reqwest::Client,
28    feedback_sender: FeedbackSender,
29    logger: Logger,
30}
31
32impl HttpFileDownloader {
33    /// Constructs a new `HttpFileDownloader`.
34    pub fn new(feedback_sender: FeedbackSender, logger: Logger) -> StdResult<Self> {
35        let http_client = reqwest::ClientBuilder::new()
36            .build()
37            .with_context(|| "Building http client for HttpFileDownloader failed")?;
38
39        Ok(Self {
40            http_client,
41            feedback_sender,
42            logger: logger.new_with_component_name::<Self>(),
43        })
44    }
45
46    async fn get(&self, location: &str) -> StdResult<Response> {
47        debug!(self.logger, "GET Snapshot location='{location}'.");
48        let request_builder = self.http_client.get(location);
49        let response = request_builder.send().await.with_context(|| {
50            format!("Cannot perform a GET for the snapshot (location='{location}')")
51        })?;
52
53        match response.status() {
54            StatusCode::OK => Ok(response),
55            StatusCode::NOT_FOUND => Err(anyhow!("Location='{location} not found")),
56            status_code => Err(anyhow!("Unhandled error {status_code}")),
57        }
58    }
59
60    fn file_scheme_to_local_path(file_url: &str) -> Option<String> {
61        Url::parse(file_url)
62            .ok()
63            .filter(|url| url.scheme() == "file")
64            .and_then(|url| url.to_file_path().ok())
65            .map(|path| path.to_string_lossy().into_owned())
66    }
67
68    /// Stream the `location` directly from the local filesystem
69    async fn download_local_file(
70        &self,
71        local_path: &str,
72        sender: &Sender<Vec<u8>>,
73        download_event_type: DownloadEvent,
74        file_size: u64,
75    ) -> StdResult<()> {
76        let mut downloaded_bytes: u64 = 0;
77        let mut file = File::open(local_path).await?;
78        let size = match file.metadata().await {
79            Ok(metadata) => metadata.len(),
80            Err(_) => file_size,
81        };
82
83        self.feedback_sender
84            .send_event(download_event_type.build_download_started_event(size))
85            .await;
86
87        loop {
88            // We can either allocate here each time, or clone a shared buffer into sender.
89            // A larger read buffer is faster, less context switches:
90            let mut buffer = vec![0; 16 * 1024 * 1024];
91            let bytes_read = file.read(&mut buffer).await?;
92            if bytes_read == 0 {
93                break;
94            }
95            buffer.truncate(bytes_read);
96            sender.send_async(buffer).await.with_context(|| {
97                format!(
98                    "Local file read: could not write {} bytes to stream.",
99                    bytes_read
100                )
101            })?;
102            downloaded_bytes += bytes_read as u64;
103            let event = download_event_type.build_download_progress_event(downloaded_bytes, size);
104            self.feedback_sender.send_event(event).await;
105        }
106
107        self.feedback_sender
108            .send_event(download_event_type.build_download_completed_event())
109            .await;
110
111        Ok(())
112    }
113
114    /// Stream the `location` remotely
115    async fn download_remote_file(
116        &self,
117        location: &str,
118        sender: &Sender<Vec<u8>>,
119        download_event_type: DownloadEvent,
120        file_size: u64,
121    ) -> StdResult<()> {
122        let mut downloaded_bytes: u64 = 0;
123        let response = self.get(location).await?;
124        let size = response.content_length().unwrap_or(file_size);
125        let mut remote_stream = response.bytes_stream();
126
127        self.feedback_sender
128            .send_event(download_event_type.build_download_started_event(size))
129            .await;
130
131        while let Some(item) = remote_stream.next().await {
132            let chunk = item.with_context(|| "Download: Could not read from byte stream")?;
133            sender.send_async(chunk.to_vec()).await.with_context(|| {
134                format!("Download: could not write {} bytes to stream.", chunk.len())
135            })?;
136            downloaded_bytes += chunk.len() as u64;
137            let event = download_event_type.build_download_progress_event(downloaded_bytes, size);
138            self.feedback_sender.send_event(event).await;
139        }
140
141        self.feedback_sender
142            .send_event(download_event_type.build_download_completed_event())
143            .await;
144
145        Ok(())
146    }
147
148    fn unpack_file(
149        stream: Receiver<Vec<u8>>,
150        compression_algorithm: Option<CompressionAlgorithm>,
151        unpack_dir: &Path,
152        download_id: String,
153    ) -> StdResult<()> {
154        let input = StreamReader::new(stream);
155        match compression_algorithm {
156            Some(CompressionAlgorithm::Gzip) => {
157                let gzip_decoder = GzDecoder::new(input);
158                let mut file_archive = Archive::new(gzip_decoder);
159                file_archive.unpack(unpack_dir).with_context(|| {
160                    format!(
161                        "Could not unpack with 'Gzip' from streamed data to directory '{}'",
162                        unpack_dir.display()
163                    )
164                })?;
165            }
166            Some(CompressionAlgorithm::Zstandard) => {
167                let zstandard_decoder = zstd::Decoder::new(input)
168                    .with_context(|| "Unpack failed: Create Zstandard decoder error")?;
169                let mut file_archive = Archive::new(zstandard_decoder);
170                file_archive.unpack(unpack_dir).with_context(|| {
171                    format!(
172                        "Could not unpack with 'Zstd' from streamed data to directory '{}'",
173                        unpack_dir.display()
174                    )
175                })?;
176            }
177            None => {
178                let file_path = unpack_dir.join(download_id);
179                if file_path.exists() {
180                    std::fs::remove_file(file_path.clone())?;
181                }
182                let mut file = std::fs::File::create(file_path)?;
183                let input_buffered = BufReader::new(input);
184                for byte in input_buffered.bytes() {
185                    file.write_all(&[byte?])?;
186                }
187                file.flush()?;
188            }
189        };
190
191        Ok(())
192    }
193}
194
195#[async_trait]
196impl FileDownloader for HttpFileDownloader {
197    async fn download_unpack(
198        &self,
199        location: &FileDownloaderUri,
200        file_size: u64,
201        target_dir: &Path,
202        compression_algorithm: Option<CompressionAlgorithm>,
203        download_event_type: DownloadEvent,
204    ) -> StdResult<()> {
205        if !target_dir.is_dir() {
206            Err(
207                anyhow!("target path is not a directory or does not exist: `{target_dir:?}`")
208                    .context("Download-Unpack: prerequisite error"),
209            )?;
210        }
211
212        let (sender, receiver) = flume::bounded(32);
213        let dest_dir = target_dir.to_path_buf();
214        let download_id = download_event_type.download_id().to_owned();
215        let unpack_thread = tokio::task::spawn_blocking(move || -> StdResult<()> {
216            Self::unpack_file(receiver, compression_algorithm, &dest_dir, download_id)
217        });
218        if let Some(local_path) = Self::file_scheme_to_local_path(location.as_str()) {
219            self.download_local_file(&local_path, &sender, download_event_type, file_size)
220                .await?;
221        } else {
222            self.download_remote_file(location.as_str(), &sender, download_event_type, file_size)
223                .await?;
224        }
225        drop(sender);
226        unpack_thread
227            .await
228            .with_context(|| {
229                format!(
230                    "Unpack: panic while unpacking to dir '{}'",
231                    target_dir.display()
232                )
233            })?
234            .with_context(|| {
235                format!("Unpack: could not unpack to dir '{}'", target_dir.display())
236            })?;
237
238        Ok(())
239    }
240}
241
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use httpmock::MockServer;

    use mithril_common::{entities::FileUri, test_utils::TempDir};

    use crate::{
        feedback::{MithrilEvent, MithrilEventCardanoDatabase, StackFeedbackReceiver},
        test_utils::TestLogger,
    };

    use super::*;

    #[cfg(not(target_family = "windows"))]
    fn local_file_uri(path: &Path) -> FileDownloaderUri {
        FileDownloaderUri::FileUri(FileUri(format!(
            "file://{}",
            path.canonicalize().unwrap().to_string_lossy()
        )))
    }

    #[cfg(target_family = "windows")]
    fn local_file_uri(path: &Path) -> FileDownloaderUri {
        // We need to transform `\\?\C:\data\Temp\mithril_test\snapshot.txt` to
        // `file:///C:/data/Temp/mithril_test/snapshot.txt`:
        // `\\?\C:\...` -> `//?/C:/...` -> `//C:/...`, then prefixed with `file:/`.
        FileDownloaderUri::FileUri(FileUri(format!(
            "file:/{}",
            path.canonicalize()
                .unwrap()
                .to_string_lossy()
                .replace("\\", "/")
                .replace("?/", ""),
        )))
    }

    #[tokio::test]
    async fn test_download_http_file_send_feedback() {
        let target_dir = TempDir::create(
            "client-http-downloader",
            "test_download_http_file_send_feedback",
        );
        let content = "Hello, world!";
        let size = content.len() as u64;
        let server = MockServer::start();
        server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/snapshot.tar");
            then.status(200)
                .body(content)
                .header(reqwest::header::CONTENT_LENGTH.as_str(), size.to_string());
        });
        let feedback_receiver = Arc::new(StackFeedbackReceiver::new());
        let http_file_downloader = HttpFileDownloader::new(
            FeedbackSender::new(&[feedback_receiver.clone()]),
            TestLogger::stdout(),
        )
        .unwrap();
        let download_id = "id".to_string();

        http_file_downloader
            .download_unpack(
                &FileDownloaderUri::FileUri(FileUri(server.url("/snapshot.tar"))),
                0,
                &target_dir,
                None,
                DownloadEvent::Digest {
                    download_id: download_id.clone(),
                },
            )
            .await
            .unwrap();

        let expected_events = vec![
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadStarted {
                download_id: download_id.clone(),
                size,
            }),
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadProgress {
                download_id: download_id.clone(),
                downloaded_bytes: size,
                size,
            }),
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadCompleted {
                download_id: download_id.clone(),
            }),
        ];
        assert_eq!(expected_events, feedback_receiver.stacked_events());

        // Without compression the payload is written verbatim in a file named after the id.
        let downloaded_content = std::fs::read_to_string(target_dir.join(&download_id)).unwrap();
        assert_eq!(content, downloaded_content);
    }

    #[tokio::test]
    async fn test_download_local_file_send_feedback() {
        let target_dir = TempDir::create(
            "client-http-downloader",
            "test_download_local_file_send_feedback",
        );
        let content = "Hello, world!";
        let size = content.len() as u64;

        let source_file_path = target_dir.join("snapshot.txt");
        let mut file = std::fs::File::create(&source_file_path).unwrap();
        file.write_all(content.as_bytes()).unwrap();

        let feedback_receiver = Arc::new(StackFeedbackReceiver::new());
        let http_file_downloader = HttpFileDownloader::new(
            FeedbackSender::new(&[feedback_receiver.clone()]),
            TestLogger::stdout(),
        )
        .unwrap();
        let download_id = "id".to_string();

        http_file_downloader
            .download_unpack(
                &local_file_uri(&source_file_path),
                0,
                &target_dir,
                None,
                DownloadEvent::Digest {
                    download_id: download_id.clone(),
                },
            )
            .await
            .unwrap();

        let expected_events = vec![
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadStarted {
                download_id: download_id.clone(),
                size,
            }),
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadProgress {
                download_id: download_id.clone(),
                downloaded_bytes: size,
                size,
            }),
            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadCompleted {
                download_id: download_id.clone(),
            }),
        ];
        assert_eq!(expected_events, feedback_receiver.stacked_events());

        // Without compression the payload is written verbatim in a file named after the id.
        let downloaded_content = std::fs::read_to_string(target_dir.join(&download_id)).unwrap();
        assert_eq!(content, downloaded_content);
    }
}