// mithril_aggregator/services/snapshotter/test_doubles.rs

1use async_trait::async_trait;
2use std::fs;
3use std::fs::File;
4use std::path::{Path, PathBuf};
5use std::sync::RwLock;
6
7use mithril_common::entities::{CompressionAlgorithm, ImmutableFileNumber};
8use mithril_common::StdResult;
9
10use crate::services::Snapshotter;
11use crate::tools::file_archiver::FileArchive;
12
13/// Snapshotter that does nothing. It is mainly used for test purposes.
14pub struct DumbSnapshotter {
15    last_snapshot: RwLock<Option<FileArchive>>,
16    compression_algorithm: CompressionAlgorithm,
17    archive_size: u64,
18}
19
20impl DumbSnapshotter {
21    /// Create a new instance of DumbSnapshotter.
22    ///
23    /// The `compression_algorithm` parameter is used for the output file name extension.
24    pub fn new(compression_algorithm: CompressionAlgorithm) -> Self {
25        Self {
26            last_snapshot: RwLock::new(None),
27            compression_algorithm,
28            archive_size: 0,
29        }
30    }
31
32    /// Set the size assigned to the produced snapshots.
33    pub fn with_archive_size(mut self, size: u64) -> Self {
34        self.archive_size = size;
35        self
36    }
37
38    /// Return the last fake snapshot produced.
39    pub fn get_last_snapshot(&self) -> StdResult<Option<FileArchive>> {
40        let value = self.last_snapshot.read().unwrap().as_ref().cloned();
41
42        Ok(value)
43    }
44}
45
46impl Default for DumbSnapshotter {
47    fn default() -> Self {
48        Self {
49            last_snapshot: RwLock::new(None),
50            compression_algorithm: CompressionAlgorithm::Gzip,
51            archive_size: 0,
52        }
53    }
54}
55
56#[async_trait]
57impl Snapshotter for DumbSnapshotter {
58    async fn snapshot_all_completed_immutables(
59        &self,
60        archive_name_without_extension: &str,
61    ) -> StdResult<FileArchive> {
62        let mut value = self.last_snapshot.write().unwrap();
63        let snapshot = FileArchive::new(
64            PathBuf::from(format!(
65                "{archive_name_without_extension}.{}",
66                self.compression_algorithm.tar_file_extension()
67            )),
68            self.archive_size,
69            0,
70            self.compression_algorithm,
71        );
72        *value = Some(snapshot.clone());
73
74        Ok(snapshot)
75    }
76
77    async fn snapshot_ancillary(
78        &self,
79        _immutable_file_number: ImmutableFileNumber,
80        archive_name_without_extension: &str,
81    ) -> StdResult<FileArchive> {
82        self.snapshot_all_completed_immutables(archive_name_without_extension)
83            .await
84    }
85
86    async fn snapshot_immutable_trio(
87        &self,
88        _immutable_file_number: ImmutableFileNumber,
89        archive_name_without_extension: &str,
90    ) -> StdResult<FileArchive> {
91        self.snapshot_all_completed_immutables(archive_name_without_extension)
92            .await
93    }
94
95    fn compression_algorithm(&self) -> CompressionAlgorithm {
96        self.compression_algorithm
97    }
98}
99
100/// Snapshotter that writes empty files to the filesystem. Used for testing purposes.
101pub struct FakeSnapshotter {
102    work_dir: PathBuf,
103    compression_algorithm: CompressionAlgorithm,
104}
105
106impl FakeSnapshotter {
107    /// `FakeSnapshotter` factory, with a default compression algorithm of `Gzip`.
108    pub fn new<T: AsRef<Path>>(work_dir: T) -> Self {
109        Self {
110            work_dir: work_dir.as_ref().to_path_buf(),
111            compression_algorithm: CompressionAlgorithm::Gzip,
112        }
113    }
114
115    /// Set the compression algorithm used to for the output file name extension.
116    pub fn with_compression_algorithm(
117        mut self,
118        compression_algorithm: CompressionAlgorithm,
119    ) -> Self {
120        self.compression_algorithm = compression_algorithm;
121        self
122    }
123}
124
125#[async_trait]
126impl Snapshotter for FakeSnapshotter {
127    async fn snapshot_all_completed_immutables(
128        &self,
129        archive_name_without_extension: &str,
130    ) -> StdResult<FileArchive> {
131        let fake_archive_path = self.work_dir.join(format!(
132            "{archive_name_without_extension}.{}",
133            self.compression_algorithm.tar_file_extension()
134        ));
135        if let Some(archive_dir) = fake_archive_path.parent() {
136            fs::create_dir_all(archive_dir).unwrap();
137        }
138        File::create(&fake_archive_path).unwrap();
139
140        Ok(FileArchive::new(
141            fake_archive_path,
142            0,
143            0,
144            self.compression_algorithm,
145        ))
146    }
147
148    async fn snapshot_ancillary(
149        &self,
150        _immutable_file_number: ImmutableFileNumber,
151        archive_name_without_extension: &str,
152    ) -> StdResult<FileArchive> {
153        self.snapshot_all_completed_immutables(archive_name_without_extension)
154            .await
155    }
156
157    async fn snapshot_immutable_trio(
158        &self,
159        _immutable_file_number: ImmutableFileNumber,
160        archive_name_without_extension: &str,
161    ) -> StdResult<FileArchive> {
162        self.snapshot_all_completed_immutables(archive_name_without_extension)
163            .await
164    }
165
166    fn compression_algorithm(&self) -> CompressionAlgorithm {
167        self.compression_algorithm
168    }
169}
170
#[cfg(test)]
mod tests {
    use crate::services::snapshotter::test_tools::*;

    use super::*;

    mod dumb_snapshotter {
        use super::*;

        #[test]
        fn return_parametrized_compression_algorithm() {
            let snapshotter = DumbSnapshotter::new(CompressionAlgorithm::Zstandard);
            assert_eq!(
                CompressionAlgorithm::Zstandard,
                snapshotter.compression_algorithm()
            );
        }

        #[test]
        fn default_uses_gzip_compression_algorithm() {
            let snapshotter = DumbSnapshotter::default();
            assert_eq!(
                CompressionAlgorithm::Gzip,
                snapshotter.compression_algorithm()
            );
        }

        #[tokio::test]
        async fn test_dumb_snapshotter_snapshot_return_archive_named_with_compression_algorithm_and_size_of_0(
        ) {
            let snapshotter = DumbSnapshotter::new(CompressionAlgorithm::Gzip);

            let snapshot = snapshotter
                .snapshot_all_completed_immutables("archive_full_immutables")
                .await
                .unwrap();
            assert_eq!(
                PathBuf::from("archive_full_immutables.tar.gz"),
                *snapshot.get_file_path()
            );
            assert_eq!(0, snapshot.get_archive_size());

            let snapshot = snapshotter
                .snapshot_ancillary(3, "archive_ancillary")
                .await
                .unwrap();
            assert_eq!(
                PathBuf::from("archive_ancillary.tar.gz"),
                *snapshot.get_file_path()
            );
            assert_eq!(0, snapshot.get_archive_size());

            let snapshot = snapshotter
                .snapshot_immutable_trio(4, "archive_immutable_trio")
                .await
                .unwrap();
            assert_eq!(
                PathBuf::from("archive_immutable_trio.tar.gz"),
                *snapshot.get_file_path()
            );
            assert_eq!(0, snapshot.get_archive_size());
        }

        #[tokio::test]
        async fn test_dumb_snapshotter() {
            let snapshotter = DumbSnapshotter::new(CompressionAlgorithm::Zstandard);
            assert!(snapshotter
                .get_last_snapshot()
                .expect(
                    "Dumb snapshotter::get_last_snapshot should not fail when no last snapshot."
                )
                .is_none());

            {
                let full_immutables_snapshot = snapshotter
                    .snapshot_all_completed_immutables("whatever")
                    .await
                    .expect("Dumb snapshotter::snapshot_all_completed_immutables should not fail.");
                assert_eq!(
                    Some(full_immutables_snapshot),
                    snapshotter.get_last_snapshot().expect(
                        "Dumb snapshotter::get_last_snapshot should not fail when some last snapshot."
                    )
                );
            }
            {
                let ancillary_snapshot = snapshotter
                    .snapshot_ancillary(3, "whatever")
                    .await
                    .expect("Dumb snapshotter::snapshot_ancillary should not fail.");
                assert_eq!(
                    Some(ancillary_snapshot),
                    snapshotter.get_last_snapshot().expect(
                        "Dumb snapshotter::get_last_snapshot should not fail when some last snapshot."
                    )
                );
            }
            {
                let immutable_snapshot = snapshotter
                    .snapshot_immutable_trio(4, "whatever")
                    .await
                    .expect("Dumb snapshotter::snapshot_immutable_trio should not fail.");
                assert_eq!(
                    Some(immutable_snapshot),
                    snapshotter.get_last_snapshot().expect(
                        "Dumb snapshotter::get_last_snapshot should not fail when some last snapshot."
                    )
                );
            }
        }

        #[tokio::test]
        async fn set_dumb_snapshotter_archive_size() {
            let snapshotter = DumbSnapshotter::new(CompressionAlgorithm::Gzip);

            // Default size is 0
            let snapshot = snapshotter
                .snapshot_all_completed_immutables("whatever")
                .await
                .unwrap();
            assert_eq!(0, snapshot.get_archive_size());

            let snapshotter = snapshotter.with_archive_size(42);
            let snapshot = snapshotter
                .snapshot_all_completed_immutables("whatever")
                .await
                .unwrap();
            assert_eq!(42, snapshot.get_archive_size());
        }
    }

    mod fake_snapshotter {
        use super::*;

        #[test]
        fn return_parametrized_compression_algorithm() {
            let snapshotter = FakeSnapshotter::new("whatever")
                .with_compression_algorithm(CompressionAlgorithm::Zstandard);
            assert_eq!(
                CompressionAlgorithm::Zstandard,
                snapshotter.compression_algorithm()
            );
        }

        #[tokio::test]
        async fn test_fake_snapshotter() {
            let test_dir = get_test_directory("test_fake_snapshotter");
            let fake_snapshotter = FakeSnapshotter::new(&test_dir)
                .with_compression_algorithm(CompressionAlgorithm::Gzip);

            for filename in [
                "direct_child",
                "one_level_subdir/child",
                "two_levels/subdir/child",
            ] {
                {
                    let full_immutables_snapshot = fake_snapshotter
                        .snapshot_all_completed_immutables(filename)
                        .await
                        .unwrap();

                    assert_eq!(
                        full_immutables_snapshot.get_file_path(),
                        &test_dir.join(filename).with_extension("tar.gz")
                    );
                    assert!(full_immutables_snapshot.get_file_path().is_file());
                }
                {
                    let ancillary_snapshot = fake_snapshotter
                        .snapshot_ancillary(3, filename)
                        .await
                        .unwrap();

                    assert_eq!(
                        ancillary_snapshot.get_file_path(),
                        &test_dir.join(filename).with_extension("tar.gz")
                    );
                    assert!(ancillary_snapshot.get_file_path().is_file());
                }
                {
                    let immutable_snapshot = fake_snapshotter
                        .snapshot_immutable_trio(5, filename)
                        .await
                        .unwrap();

                    assert_eq!(
                        immutable_snapshot.get_file_path(),
                        &test_dir.join(filename).with_extension("tar.gz")
                    );
                    assert!(immutable_snapshot.get_file_path().is_file());
                }
            }
        }

        #[tokio::test]
        async fn archive_extension_follows_configured_compression_algorithm() {
            let test_dir = get_test_directory(
                "fake_snapshotter_archive_extension_follows_configured_compression_algorithm",
            );
            let fake_snapshotter = FakeSnapshotter::new(&test_dir)
                .with_compression_algorithm(CompressionAlgorithm::Zstandard);

            let snapshot = fake_snapshotter
                .snapshot_all_completed_immutables("archive")
                .await
                .unwrap();

            let expected_path = test_dir.join(format!(
                "archive.{}",
                CompressionAlgorithm::Zstandard.tar_file_extension()
            ));
            assert_eq!(snapshot.get_file_path(), &expected_path);
            assert!(snapshot.get_file_path().is_file());
        }
    }
}