mithril_aggregator/services/snapshotter/test_doubles.rs

1use async_trait::async_trait;
2use std::fs;
3use std::fs::File;
4use std::path::{Path, PathBuf};
5use std::sync::RwLock;
6
7use mithril_common::entities::{CompressionAlgorithm, ImmutableFileNumber};
8use mithril_common::StdResult;
9
10use crate::services::Snapshotter;
11use crate::tools::file_archiver::FileArchive;
12
13/// Snapshotter that does nothing. It is mainly used for test purposes.
14pub struct DumbSnapshotter {
15    last_snapshot: RwLock<Option<FileArchive>>,
16    compression_algorithm: CompressionAlgorithm,
17    archive_size: u64,
18}
19
20impl DumbSnapshotter {
21    /// Create a new instance of DumbSnapshotter.
22    ///
23    /// The `compression_algorithm` parameter is used for the output file name extension.
24    pub fn new(compression_algorithm: CompressionAlgorithm) -> Self {
25        Self {
26            last_snapshot: RwLock::new(None),
27            compression_algorithm,
28            archive_size: 0,
29        }
30    }
31
32    /// Set the size assigned to the produced snapshots.
33    pub fn with_archive_size(mut self, size: u64) -> Self {
34        self.archive_size = size;
35        self
36    }
37
38    /// Return the last fake snapshot produced.
39    pub fn get_last_snapshot(&self) -> StdResult<Option<FileArchive>> {
40        let value = self.last_snapshot.read().unwrap().as_ref().cloned();
41
42        Ok(value)
43    }
44}
45
46impl Default for DumbSnapshotter {
47    fn default() -> Self {
48        Self {
49            last_snapshot: RwLock::new(None),
50            compression_algorithm: CompressionAlgorithm::Gzip,
51            archive_size: 0,
52        }
53    }
54}
55
56#[async_trait]
57impl Snapshotter for DumbSnapshotter {
58    async fn snapshot_all_completed_immutables(
59        &self,
60        archive_name_without_extension: &str,
61    ) -> StdResult<FileArchive> {
62        let mut value = self.last_snapshot.write().unwrap();
63        let snapshot = FileArchive::new(
64            PathBuf::from(format!(
65                "{archive_name_without_extension}.{}",
66                self.compression_algorithm.tar_file_extension()
67            )),
68            self.archive_size,
69            0,
70            self.compression_algorithm,
71        );
72        *value = Some(snapshot.clone());
73
74        Ok(snapshot)
75    }
76
77    async fn snapshot_ancillary(
78        &self,
79        _immutable_file_number: ImmutableFileNumber,
80        archive_name_without_extension: &str,
81    ) -> StdResult<FileArchive> {
82        self.snapshot_all_completed_immutables(archive_name_without_extension)
83            .await
84    }
85
86    async fn snapshot_immutable_trio(
87        &self,
88        _immutable_file_number: ImmutableFileNumber,
89        archive_name_without_extension: &str,
90    ) -> StdResult<FileArchive> {
91        self.snapshot_all_completed_immutables(archive_name_without_extension)
92            .await
93    }
94
95    async fn compute_immutable_files_total_uncompressed_size(
96        &self,
97        _up_to_immutable_file_number: ImmutableFileNumber,
98    ) -> StdResult<u64> {
99        Ok(0)
100    }
101
102    fn compression_algorithm(&self) -> CompressionAlgorithm {
103        self.compression_algorithm
104    }
105}
106
107/// Snapshotter that writes empty files to the filesystem. Used for testing purposes.
108pub struct FakeSnapshotter {
109    work_dir: PathBuf,
110    compression_algorithm: CompressionAlgorithm,
111}
112
113impl FakeSnapshotter {
114    /// `FakeSnapshotter` factory, with a default compression algorithm of `Gzip`.
115    pub fn new<T: AsRef<Path>>(work_dir: T) -> Self {
116        Self {
117            work_dir: work_dir.as_ref().to_path_buf(),
118            compression_algorithm: CompressionAlgorithm::Gzip,
119        }
120    }
121
122    /// Set the compression algorithm used to for the output file name extension.
123    pub fn with_compression_algorithm(
124        mut self,
125        compression_algorithm: CompressionAlgorithm,
126    ) -> Self {
127        self.compression_algorithm = compression_algorithm;
128        self
129    }
130}
131
132#[async_trait]
133impl Snapshotter for FakeSnapshotter {
134    async fn snapshot_all_completed_immutables(
135        &self,
136        archive_name_without_extension: &str,
137    ) -> StdResult<FileArchive> {
138        let fake_archive_path = self.work_dir.join(format!(
139            "{archive_name_without_extension}.{}",
140            self.compression_algorithm.tar_file_extension()
141        ));
142        if let Some(archive_dir) = fake_archive_path.parent() {
143            fs::create_dir_all(archive_dir).unwrap();
144        }
145        File::create(&fake_archive_path).unwrap();
146
147        Ok(FileArchive::new(
148            fake_archive_path,
149            0,
150            0,
151            self.compression_algorithm,
152        ))
153    }
154
155    async fn snapshot_ancillary(
156        &self,
157        _immutable_file_number: ImmutableFileNumber,
158        archive_name_without_extension: &str,
159    ) -> StdResult<FileArchive> {
160        self.snapshot_all_completed_immutables(archive_name_without_extension)
161            .await
162    }
163
164    async fn snapshot_immutable_trio(
165        &self,
166        _immutable_file_number: ImmutableFileNumber,
167        archive_name_without_extension: &str,
168    ) -> StdResult<FileArchive> {
169        self.snapshot_all_completed_immutables(archive_name_without_extension)
170            .await
171    }
172
173    async fn compute_immutable_files_total_uncompressed_size(
174        &self,
175        _up_to_immutable_file_number: ImmutableFileNumber,
176    ) -> StdResult<u64> {
177        Ok(0)
178    }
179
180    fn compression_algorithm(&self) -> CompressionAlgorithm {
181        self.compression_algorithm
182    }
183}
184
#[cfg(test)]
mod tests {
    use mithril_common::temp_dir_create;

    use super::*;

    mod dumb_snapshotter {
        use super::*;

        /// Read the last recorded snapshot, failing the test with `failure_message` on error.
        fn last_snapshot(snapshotter: &DumbSnapshotter, failure_message: &str) -> Option<FileArchive> {
            snapshotter.get_last_snapshot().expect(failure_message)
        }

        #[test]
        fn return_parametrized_compression_algorithm() {
            let expected_algorithm = CompressionAlgorithm::Zstandard;
            let snapshotter = DumbSnapshotter::new(expected_algorithm);

            assert_eq!(expected_algorithm, snapshotter.compression_algorithm());
        }

        #[tokio::test]
        async fn test_dumb_snapshotter_snapshot_return_archive_named_with_compression_algorithm_and_size_of_0(
        ) {
            let snapshotter = DumbSnapshotter::new(CompressionAlgorithm::Gzip);

            // (expected file name, archive produced by the snapshotter)
            let produced = [
                (
                    "archive_full_immutables.tar.gz",
                    snapshotter
                        .snapshot_all_completed_immutables("archive_full_immutables")
                        .await
                        .unwrap(),
                ),
                (
                    "archive_ancillary.tar.gz",
                    snapshotter
                        .snapshot_ancillary(3, "archive_ancillary")
                        .await
                        .unwrap(),
                ),
                (
                    "archive_immutable_trio.tar.gz",
                    snapshotter
                        .snapshot_immutable_trio(4, "archive_immutable_trio")
                        .await
                        .unwrap(),
                ),
            ];

            for (expected_file_name, snapshot) in produced {
                assert_eq!(PathBuf::from(expected_file_name), *snapshot.get_file_path());
                assert_eq!(0, snapshot.get_archive_size());
            }
        }

        #[tokio::test]
        async fn test_dumb_snapshotter() {
            let snapshotter = DumbSnapshotter::new(CompressionAlgorithm::Zstandard);
            let some_last_snapshot_message =
                "Dumb snapshotter::get_last_snapshot should not fail when some last snapshot.";

            assert!(last_snapshot(
                &snapshotter,
                "Dumb snapshotter::get_last_snapshot should not fail when no last snapshot."
            )
            .is_none());

            let full_immutables_snapshot = snapshotter
                .snapshot_all_completed_immutables("whatever")
                .await
                .expect("Dumb snapshotter::snapshot_all_completed_immutables should not fail.");
            assert_eq!(
                Some(full_immutables_snapshot),
                last_snapshot(&snapshotter, some_last_snapshot_message)
            );

            let ancillary_snapshot = snapshotter
                .snapshot_ancillary(3, "whatever")
                .await
                .expect("Dumb snapshotter::snapshot_ancillary should not fail.");
            assert_eq!(
                Some(ancillary_snapshot),
                last_snapshot(&snapshotter, some_last_snapshot_message)
            );

            let immutable_snapshot = snapshotter
                .snapshot_immutable_trio(4, "whatever")
                .await
                .expect("Dumb snapshotter::snapshot_immutable_trio should not fail.");
            assert_eq!(
                Some(immutable_snapshot),
                last_snapshot(&snapshotter, some_last_snapshot_message)
            );
        }

        #[tokio::test]
        async fn set_dumb_snapshotter_archive_size() {
            let default_snapshotter = DumbSnapshotter::new(CompressionAlgorithm::Gzip);

            // Default size is 0
            let default_size = default_snapshotter
                .snapshot_all_completed_immutables("whatever")
                .await
                .unwrap()
                .get_archive_size();
            assert_eq!(0, default_size);

            let sized_snapshotter = default_snapshotter.with_archive_size(42);
            let custom_size = sized_snapshotter
                .snapshot_all_completed_immutables("whatever")
                .await
                .unwrap()
                .get_archive_size();
            assert_eq!(42, custom_size);
        }
    }

    mod fake_snapshotter {
        use super::*;

        #[test]
        fn return_parametrized_compression_algorithm() {
            let expected_algorithm = CompressionAlgorithm::Zstandard;
            let snapshotter =
                FakeSnapshotter::new("whatever").with_compression_algorithm(expected_algorithm);

            assert_eq!(expected_algorithm, snapshotter.compression_algorithm());
        }

        #[tokio::test]
        async fn test_fake_snapshotter() {
            let test_dir = temp_dir_create!();
            let fake_snapshotter = FakeSnapshotter::new(&test_dir)
                .with_compression_algorithm(CompressionAlgorithm::Gzip);

            let filenames = [
                "direct_child",
                "one_level_subdir/child",
                "two_levels/subdir/child",
            ];

            for filename in filenames {
                let expected_path = test_dir.join(filename).with_extension("tar.gz");

                // Every snapshot method writes the very same file for a given name.
                let snapshots = [
                    fake_snapshotter
                        .snapshot_all_completed_immutables(filename)
                        .await
                        .unwrap(),
                    fake_snapshotter
                        .snapshot_ancillary(3, filename)
                        .await
                        .unwrap(),
                    fake_snapshotter
                        .snapshot_immutable_trio(5, filename)
                        .await
                        .unwrap(),
                ];

                for snapshot in &snapshots {
                    assert_eq!(snapshot.get_file_path(), &expected_path);
                    assert!(snapshot.get_file_path().is_file());
                }
            }
        }
    }
}