// mithril_client/file_downloader/http.rs

1use std::{
2    io::{self, BufReader, Write},
3    path::Path,
4};
5
6use anyhow::{Context, anyhow};
7use async_trait::async_trait;
8use flate2::read::GzDecoder;
9use flume::{Receiver, Sender};
10use futures::StreamExt;
11use reqwest::{Response, StatusCode, Url};
12use slog::{Logger, debug};
13use tar::Archive;
14use tokio::fs::File;
15use tokio::io::AsyncReadExt;
16
17use mithril_common::{StdResult, logging::LoggerExtensions};
18
19use crate::common::CompressionAlgorithm;
20use crate::feedback::FeedbackSender;
21use crate::utils::StreamReader;
22
23use super::{FileDownloader, FileDownloaderUri, interface::DownloadEvent};
24
25/// A file downloader that only handles download through HTTP.
26pub struct HttpFileDownloader {
27    http_client: reqwest::Client,
28    feedback_sender: FeedbackSender,
29    logger: Logger,
30}
31
32impl HttpFileDownloader {
33    /// Constructs a new `HttpFileDownloader`.
34    pub fn new(feedback_sender: FeedbackSender, logger: Logger) -> StdResult<Self> {
35        let http_client = reqwest::ClientBuilder::new()
36            .build()
37            .with_context(|| "Building http client for HttpFileDownloader failed")?;
38
39        Ok(Self {
40            http_client,
41            feedback_sender,
42            logger: logger.new_with_component_name::<Self>(),
43        })
44    }
45
46    async fn get(&self, location: &str) -> StdResult<Response> {
47        debug!(self.logger, "GET Snapshot location='{location}'.");
48        let request_builder = self.http_client.get(location);
49        let response = request_builder.send().await.with_context(|| {
50            format!("Cannot perform a GET for the snapshot (location='{location}')")
51        })?;
52
53        match response.status() {
54            StatusCode::OK => Ok(response),
55            StatusCode::NOT_FOUND => Err(anyhow!("Location='{location} not found")),
56            status_code => Err(anyhow!("Unhandled error {status_code}")),
57        }
58    }
59
60    fn file_scheme_to_local_path(file_url: &str) -> Option<String> {
61        Url::parse(file_url)
62            .ok()
63            .filter(|url| url.scheme() == "file")
64            .and_then(|url| url.to_file_path().ok())
65            .map(|path| path.to_string_lossy().into_owned())
66    }
67
68    /// Stream the `location` directly from the local filesystem
69    async fn download_local_file(
70        &self,
71        local_path: &str,
72        sender: &Sender<Vec<u8>>,
73        download_event_type: DownloadEvent,
74        file_size: u64,
75    ) -> StdResult<()> {
76        let mut downloaded_bytes: u64 = 0;
77        let mut file = File::open(local_path).await?;
78        let size = match file.metadata().await {
79            Ok(metadata) => metadata.len(),
80            Err(_) => file_size,
81        };
82
83        self.feedback_sender
84            .send_event(download_event_type.build_download_started_event(size))
85            .await;
86
87        loop {
88            // We can either allocate here each time, or clone a shared buffer into sender.
89            // A larger read buffer is faster, less context switches:
90            let mut buffer = vec![0; 16 * 1024 * 1024];
91            let bytes_read = file.read(&mut buffer).await?;
92            if bytes_read == 0 {
93                break;
94            }
95            buffer.truncate(bytes_read);
96            sender.send_async(buffer).await.with_context(|| {
97                format!("Local file read: could not write {bytes_read} bytes to stream.")
98            })?;
99            downloaded_bytes += bytes_read as u64;
100            let event = download_event_type.build_download_progress_event(downloaded_bytes, size);
101            self.feedback_sender.send_event(event).await;
102        }
103
104        self.feedback_sender
105            .send_event(download_event_type.build_download_completed_event())
106            .await;
107
108        Ok(())
109    }
110
111    /// Stream the `location` remotely
112    async fn download_remote_file(
113        &self,
114        location: &str,
115        sender: &Sender<Vec<u8>>,
116        download_event_type: DownloadEvent,
117        file_size: u64,
118    ) -> StdResult<()> {
119        let mut downloaded_bytes: u64 = 0;
120        let response = self.get(location).await?;
121        let size = response.content_length().unwrap_or(file_size);
122        let mut remote_stream = response.bytes_stream();
123
124        self.feedback_sender
125            .send_event(download_event_type.build_download_started_event(size))
126            .await;
127
128        while let Some(item) = remote_stream.next().await {
129            let chunk = item.with_context(|| "Download: Could not read from byte stream")?;
130            sender.send_async(chunk.to_vec()).await.with_context(|| {
131                format!("Download: could not write {} bytes to stream.", chunk.len())
132            })?;
133            downloaded_bytes += chunk.len() as u64;
134            let event = download_event_type.build_download_progress_event(downloaded_bytes, size);
135            self.feedback_sender.send_event(event).await;
136        }
137
138        self.feedback_sender
139            .send_event(download_event_type.build_download_completed_event())
140            .await;
141
142        Ok(())
143    }
144
145    fn unpack_file(
146        stream: Receiver<Vec<u8>>,
147        compression_algorithm: Option<CompressionAlgorithm>,
148        unpack_dir: &Path,
149        download_id: String,
150    ) -> StdResult<()> {
151        let input = StreamReader::new(stream);
152        match compression_algorithm {
153            Some(CompressionAlgorithm::Gzip) => {
154                let gzip_decoder = GzDecoder::new(input);
155                let mut file_archive = Archive::new(gzip_decoder);
156                file_archive.unpack(unpack_dir).with_context(|| {
157                    format!(
158                        "Could not unpack with 'Gzip' from streamed data to directory '{}'",
159                        unpack_dir.display()
160                    )
161                })?;
162            }
163            Some(CompressionAlgorithm::Zstandard) => {
164                let zstandard_decoder = zstd::Decoder::new(input)
165                    .with_context(|| "Unpack failed: Create Zstandard decoder error")?;
166                let mut file_archive = Archive::new(zstandard_decoder);
167                file_archive.unpack(unpack_dir).with_context(|| {
168                    format!(
169                        "Could not unpack with 'Zstd' from streamed data to directory '{}'",
170                        unpack_dir.display()
171                    )
172                })?;
173            }
174            None => {
175                let file_path = unpack_dir.join(download_id);
176                if file_path.exists() {
177                    std::fs::remove_file(file_path.clone())?;
178                }
179                let mut file = std::fs::File::create(file_path)?;
180                let mut input_buffered = BufReader::new(input);
181                io::copy(&mut input_buffered, &mut file)?;
182                file.flush()?;
183            }
184        };
185
186        Ok(())
187    }
188}
189
190#[async_trait]
191impl FileDownloader for HttpFileDownloader {
192    async fn download_unpack(
193        &self,
194        location: &FileDownloaderUri,
195        file_size: u64,
196        target_dir: &Path,
197        compression_algorithm: Option<CompressionAlgorithm>,
198        download_event_type: DownloadEvent,
199    ) -> StdResult<()> {
200        if !target_dir.is_dir() {
201            Err(
202                anyhow!("target path is not a directory or does not exist: `{target_dir:?}`")
203                    .context("Download-Unpack: prerequisite error"),
204            )?;
205        }
206
207        let (sender, receiver) = flume::bounded(32);
208        let dest_dir = target_dir.to_path_buf();
209        let download_id = download_event_type.download_id().to_owned();
210        let unpack_thread = tokio::task::spawn_blocking(move || -> StdResult<()> {
211            Self::unpack_file(receiver, compression_algorithm, &dest_dir, download_id)
212        });
213        if let Some(local_path) = Self::file_scheme_to_local_path(location.as_str()) {
214            self.download_local_file(&local_path, &sender, download_event_type, file_size)
215                .await?;
216        } else {
217            self.download_remote_file(location.as_str(), &sender, download_event_type, file_size)
218                .await?;
219        }
220        drop(sender);
221        unpack_thread
222            .await
223            .with_context(|| {
224                format!(
225                    "Unpack: panic while unpacking to dir '{}'",
226                    target_dir.display()
227                )
228            })?
229            .with_context(|| {
230                format!("Unpack: could not unpack to dir '{}'", target_dir.display())
231            })?;
232
233        Ok(())
234    }
235}
236
237#[cfg(test)]
238mod tests {
239    use std::sync::Arc;
240
241    use httpmock::MockServer;
242
243    use mithril_common::{entities::FileUri, test::TempDir};
244
245    use crate::{
246        feedback::{
247            FeedbackReceiver, MithrilEvent, MithrilEventCardanoDatabase, StackFeedbackReceiver,
248        },
249        test_utils::TestLogger,
250    };
251
252    use super::*;
253
254    #[cfg(not(target_family = "windows"))]
255    fn local_file_uri(path: &Path) -> FileDownloaderUri {
256        FileDownloaderUri::FileUri(FileUri(format!(
257            "file://{}",
258            path.canonicalize().unwrap().to_string_lossy()
259        )))
260    }
261
262    #[cfg(target_family = "windows")]
263    fn local_file_uri(path: &Path) -> FileDownloaderUri {
264        // We need to transform `\\?\C:\data\Temp\mithril_test\snapshot.txt` to `file://C:/data/Temp/mithril_test/snapshot.txt`
265        FileDownloaderUri::FileUri(FileUri(format!(
266            "file:/{}",
267            path.canonicalize()
268                .unwrap()
269                .to_string_lossy()
270                .replace("\\", "/")
271                .replace("?/", ""),
272        )))
273    }
274
275    #[tokio::test]
276    async fn test_download_http_file_send_feedback() {
277        let target_dir = TempDir::create(
278            "client-http-downloader",
279            "test_download_http_file_send_feedback",
280        );
281        let content = "Hello, world!";
282        let size = content.len() as u64;
283        let server = MockServer::start();
284        server.mock(|when, then| {
285            when.method(httpmock::Method::GET).path("/snapshot.tar");
286            then.status(200)
287                .body(content)
288                .header(reqwest::header::CONTENT_LENGTH.as_str(), size.to_string());
289        });
290        let feedback_receiver = Arc::new(StackFeedbackReceiver::new());
291        let feedback_receiver_clone = feedback_receiver.clone() as Arc<dyn FeedbackReceiver>;
292        let http_file_downloader = HttpFileDownloader::new(
293            FeedbackSender::new(&[feedback_receiver_clone]),
294            TestLogger::stdout(),
295        )
296        .unwrap();
297        let download_id = "id".to_string();
298
299        http_file_downloader
300            .download_unpack(
301                &FileDownloaderUri::FileUri(FileUri(server.url("/snapshot.tar"))),
302                0,
303                &target_dir,
304                None,
305                DownloadEvent::Digest {
306                    download_id: download_id.clone(),
307                },
308            )
309            .await
310            .unwrap();
311
312        let expected_events = vec![
313            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadStarted {
314                download_id: download_id.clone(),
315                size,
316            }),
317            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadProgress {
318                download_id: download_id.clone(),
319                downloaded_bytes: size,
320                size,
321            }),
322            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadCompleted {
323                download_id: download_id.clone(),
324            }),
325        ];
326        assert_eq!(expected_events, feedback_receiver.stacked_events());
327    }
328
329    #[tokio::test]
330    async fn test_download_local_file_send_feedback() {
331        let target_dir = TempDir::create(
332            "client-http-downloader",
333            "test_download_local_file_send_feedback",
334        );
335        let content = "Hello, world!";
336        let size = content.len() as u64;
337
338        let source_file_path = target_dir.join("snapshot.txt");
339        let mut file = std::fs::File::create(&source_file_path).unwrap();
340        file.write_all(content.as_bytes()).unwrap();
341
342        let feedback_receiver = Arc::new(StackFeedbackReceiver::new());
343        let feedback_receiver_clone = feedback_receiver.clone() as Arc<dyn FeedbackReceiver>;
344        let http_file_downloader = HttpFileDownloader::new(
345            FeedbackSender::new(&[feedback_receiver_clone]),
346            TestLogger::stdout(),
347        )
348        .unwrap();
349        let download_id = "id".to_string();
350
351        http_file_downloader
352            .download_unpack(
353                &local_file_uri(&source_file_path),
354                0,
355                &target_dir,
356                None,
357                DownloadEvent::Digest {
358                    download_id: download_id.clone(),
359                },
360            )
361            .await
362            .unwrap();
363
364        let expected_events = vec![
365            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadStarted {
366                download_id: download_id.clone(),
367                size,
368            }),
369            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadProgress {
370                download_id: download_id.clone(),
371                downloaded_bytes: size,
372                size,
373            }),
374            MithrilEvent::CardanoDatabase(MithrilEventCardanoDatabase::DigestDownloadCompleted {
375                download_id: download_id.clone(),
376            }),
377        ];
378        assert_eq!(expected_events, feedback_receiver.stacked_events());
379    }
380}