mithril_persistence/database/version_checker.rs

1use anyhow::{anyhow, Context};
2use chrono::Utc;
3use slog::{debug, error, info, Logger};
4use std::{cmp::Ordering, collections::BTreeSet};
5
6use mithril_common::{logging::LoggerExtensions, StdError, StdResult};
7
8use super::{
9    ApplicationNodeType, DatabaseVersion, DbVersion, GetDatabaseVersionQuery,
10    UpdateDatabaseVersionQuery,
11};
12
13use crate::sqlite::{ConnectionExtensions, SqliteConnection};
14
15/// Struct to perform application version check in the database.
16pub struct DatabaseVersionChecker<'conn> {
17    /// Pathbuf to the SQLite3 file.
18    connection: &'conn SqliteConnection,
19
20    /// Application type which vesion is verified.
21    application_type: ApplicationNodeType,
22
23    /// logger
24    logger: Logger,
25
26    /// known migrations
27    migrations: BTreeSet<SqlMigration>,
28}
29
30impl<'conn> DatabaseVersionChecker<'conn> {
31    /// constructor
32    pub fn new(
33        logger: Logger,
34        application_type: ApplicationNodeType,
35        connection: &'conn SqliteConnection,
36    ) -> Self {
37        let migrations = BTreeSet::new();
38
39        Self {
40            connection,
41            application_type,
42            logger: logger.new_with_component_name::<Self>(),
43            migrations,
44        }
45    }
46
47    /// Register a migration.
48    pub fn add_migration(&mut self, migration: SqlMigration) -> &mut Self {
49        let _ = self.migrations.insert(migration);
50
51        self
52    }
53
54    /// Apply migrations
55    pub fn apply(&self) -> StdResult<()> {
56        debug!(&self.logger, "Check database version",);
57        self.create_table_if_not_exists(&self.application_type)
58            .with_context(|| "Can not create table 'db_version' while applying migrations")?;
59        let db_version = self
60            .connection
61            .fetch_first(GetDatabaseVersionQuery::get_application_version(
62                &self.application_type,
63            ))
64            .with_context(|| "Can not get application version while applying migrations")?
65            .unwrap(); // At least a record exists.
66
67        // the current database version is equal to the maximum migration
68        // version present in this software.
69        // If no migration registered then version = 0.
70        let migration_version = self.migrations.iter().map(|m| m.version).max().unwrap_or(0);
71
72        match migration_version.cmp(&db_version.version) {
73            Ordering::Greater => {
74                debug!(
75                    &self.logger,
76                    "Database needs upgrade from version '{}' to version '{}', applying new migrations…",
77                    db_version.version, migration_version
78                );
79                self.apply_migrations(&db_version, self.connection)?;
80                info!(
81                    &self.logger,
82                    "Database upgraded to version '{migration_version}'"
83                );
84            }
85            Ordering::Less => {
86                error!(
87                    &self.logger,
88                    "Software version '{}' is older than database structure version '{}'.",
89                    db_version.version,
90                    migration_version,
91                );
92
93                Err(anyhow!("This software version is older than the database structure. Aborting launch to prevent possible data corruption."))?;
94            }
95            Ordering::Equal => {
96                debug!(&self.logger, "database up to date");
97            }
98        };
99
100        Ok(())
101    }
102
103    fn apply_migrations(
104        &self,
105        starting_version: &DatabaseVersion,
106        connection: &SqliteConnection,
107    ) -> StdResult<()> {
108        for migration in &self
109            .migrations
110            .iter()
111            .filter(|&m| m.version > starting_version.version)
112            .collect::<Vec<&SqlMigration>>()
113        {
114            self.check_minimum_required_version(starting_version.version, migration)?;
115            connection.execute(&migration.alterations)?;
116            let db_version = DatabaseVersion {
117                version: migration.version,
118                application_type: self.application_type.clone(),
119                updated_at: Utc::now(),
120            };
121            let _ = connection
122                .fetch_first(UpdateDatabaseVersionQuery::one(db_version))
123                .with_context(|| {
124                    format!(
125                        "Can not save database version when applying migration: '{}'",
126                        migration.version
127                    )
128                })?;
129        }
130
131        Ok(())
132    }
133
134    /// Method to create the table at the beginning of the migration procedure.
135    /// This code is temporary and should not last.
136    pub fn create_table_if_not_exists(
137        &self,
138        application_type: &ApplicationNodeType,
139    ) -> StdResult<()> {
140        let connection = self.connection;
141        let table_exists = connection.query_single_cell::<_, i64>(
142            "select exists(select name from sqlite_master where type='table' and name='db_version') as table_exists",
143            &[],
144        )? == 1;
145
146        if !table_exists {
147            let sql = format!("
148create table db_version (application_type text not null primary key, version integer not null, updated_at text not null);
149insert into db_version (application_type, version, updated_at) values ('{application_type}', 0, '{}');
150", Utc::now().to_rfc3339());
151            connection.execute(sql)?;
152        }
153
154        Ok(())
155    }
156
157    /// Checks if the database version meets the minimum required version to apply a squashed migration.
158    /// If the database version 0 or if the migration doesn't specify a fallback distribution version, the check passes.
159    /// For migrations with a fallback distribution version, the check passes if the database version is exactly
160    /// one less than the migration version (i.e., there's no gap between them).
161    fn check_minimum_required_version(
162        &self,
163        db_version: DbVersion,
164        migration: &SqlMigration,
165    ) -> StdResult<()> {
166        if db_version == 0 {
167            return Ok(());
168        }
169
170        if let Some(fallback_distribution_version) = &migration.fallback_distribution_version {
171            let min_required_version = migration.version - 1;
172            if db_version < min_required_version {
173                return Err(self.generate_fallback_migration_error(
174                    migration.version,
175                    fallback_distribution_version,
176                ));
177            }
178        }
179
180        Ok(())
181    }
182
183    fn generate_fallback_migration_error(
184        &self,
185        migration_version: i64,
186        fallback_distribution_version: &str,
187    ) -> StdError {
188        anyhow!(
189            r#"
190                Minimum required database version is not met to apply migration '{}'.
191                Please migrate your {} node database with the minimum node version compatible available in the distribution: '{}'.
192
193                First, download the required node version in your current directory by running the following command:
194                curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/input-output-hk/mithril/refs/heads/main/mithril-install.sh | sh -s -- -c mithril-{} -d {} -p $(pwd)
195
196                Then run the database migrate command:
197                mithril-{} database migrate --stores-directory /path/to/stores-directory
198            "#,
199            migration_version,
200            self.application_type.to_string(),
201            fallback_distribution_version,
202            self.application_type.to_string(),
203            fallback_distribution_version,
204            self.application_type.to_string()
205        )
206    }
207}
208
209/// Represent a file containing SQL structure or data alterations.
210#[derive(Debug, Clone)]
211pub struct SqlMigration {
212    /// The semver version this migration targets.
213    pub version: DbVersion,
214
215    /// SQL statements to alter the database.
216    pub alterations: String,
217
218    /// The distribution version the user can fallback to in order to update their database before updating to the latest node.
219    pub fallback_distribution_version: Option<String>,
220}
221
222impl SqlMigration {
223    /// Create a new SQL migration instance.
224    pub fn new<T: Into<String>>(version: DbVersion, alteration: T) -> Self {
225        Self {
226            version,
227            alterations: alteration.into(),
228            fallback_distribution_version: None,
229        }
230    }
231
232    /// Create a new squashed SQL migration instance with the fallback distribution version.
233    pub fn new_squashed<T: Into<String>>(
234        version: DbVersion,
235        fallback_distribution_version: T,
236        alteration: T,
237    ) -> Self {
238        Self {
239            version,
240            alterations: alteration.into(),
241            fallback_distribution_version: Some(fallback_distribution_version.into()),
242        }
243    }
244}
245
246impl PartialOrd for SqlMigration {
247    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
248        Some(self.cmp(other))
249    }
250}
251
252impl Ord for SqlMigration {
253    fn cmp(&self, other: &Self) -> Ordering {
254        self.version.cmp(&other.version)
255    }
256}
257
258impl PartialEq for SqlMigration {
259    fn eq(&self, other: &Self) -> bool {
260        self.version.eq(&other.version)
261    }
262}
263
264impl Eq for SqlMigration {}
265
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use mithril_common::test_utils::TempDir;
    use mithril_common::{current_function, StdResult};
    use sqlite::Connection;
    use std::path::PathBuf;

    use super::*;

    const CREATE_TABLE_SQL_REQUEST: &str = "create table whatever (thing_id integer); insert into whatever (thing_id) values (1), (2), (3), (4);";
    const ALTER_TABLE_SQL_REQUEST: &str = "alter table whatever add column thing_content text; update whatever set thing_content = 'some content'";

    /// Fallback distribution version used by the squashed migrations of these tests.
    const FALLBACK_DISTRIBUTION_VERSION: &str = "2511.0";

    /// Install command expected in the fallback error of an aggregator falling back to
    /// the `2511.0` distribution.
    const EXPECTED_FALLBACK_INSTALL_COMMAND: &str = "curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/input-output-hk/mithril/refs/heads/main/mithril-install.sh | sh -s -- -c mithril-aggregator -d 2511.0 -p $(pwd)";

    fn discard_logger() -> Logger {
        Logger::root(slog::Discard, slog::o!())
    }

    /// Build a regular migration (no fallback distribution version).
    fn regular_migration(version: DbVersion, alterations: &str) -> SqlMigration {
        SqlMigration::new(version, alterations)
    }

    /// Build a squashed migration falling back to [`FALLBACK_DISTRIBUTION_VERSION`].
    fn squashed_migration(version: DbVersion, alterations: &str) -> SqlMigration {
        SqlMigration::new_squashed(version, FALLBACK_DISTRIBUTION_VERSION, alterations)
    }

    /// Assert that the aggregator version recorded in the database is `db_version`.
    fn check_database_version(connection: &SqliteConnection, db_version: DbVersion) {
        let version = connection
            .fetch_first(GetDatabaseVersionQuery::get_application_version(
                &ApplicationNodeType::Aggregator,
            ))
            .unwrap()
            .unwrap();

        assert_eq!(db_version, version.version);
    }

    fn create_sqlite_file(test_name: &str) -> StdResult<(PathBuf, SqliteConnection)> {
        let dirpath = TempDir::create("mithril_test_database", test_name);
        let filepath = dirpath.join("db.sqlite3");

        let connection = Connection::open_thread_safe(&filepath)
            .with_context(|| "connection to sqlite file failure")?;

        Ok((filepath, connection))
    }

    /// Number of columns of table `whatever` (0 when the table does not exist).
    fn get_table_whatever_column_count(connection: &SqliteConnection) -> i64 {
        let sql = "select count(*) as column_count from pragma_table_info('whatever');";
        connection
            .prepare(sql)
            .unwrap()
            .iter()
            .next()
            .unwrap()
            .unwrap()
            .read::<i64, _>(0)
    }

    fn create_db_checker(connection: &SqliteConnection) -> DatabaseVersionChecker<'_> {
        DatabaseVersionChecker::new(
            discard_logger(),
            ApplicationNodeType::Aggregator,
            connection,
        )
    }

    #[test]
    fn test_upgrade_with_migration() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.apply().unwrap();
        assert_eq!(0, get_table_whatever_column_count(&connection));

        db_checker.apply().unwrap();
        assert_eq!(0, get_table_whatever_column_count(&connection));

        db_checker.add_migration(regular_migration(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);

        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);

        db_checker.add_migration(regular_migration(2, ALTER_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(2, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 2);

        // Both migrations below are registered in reverse order to ensure they are
        // played by ascending version: the version 4 migration reads the column added
        // by the version 3 migration.
        db_checker.add_migration(regular_migration(
            4,
            "alter table whatever add column one_last_thing text; update whatever set one_last_thing = more_thing",
        ));
        db_checker.add_migration(regular_migration(
            3,
            "alter table whatever add column more_thing text; update whatever set more_thing = 'more thing'",
        ));
        db_checker.apply().unwrap();
        assert_eq!(4, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 4);
    }

    #[test]
    fn test_upgrade_with_migration_with_a_version_gap() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.add_migration(regular_migration(3, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 3);

        let migration_with_version_gap = regular_migration(10, ALTER_TABLE_SQL_REQUEST);
        db_checker.add_migration(migration_with_version_gap);
        db_checker.apply().unwrap();
        assert_eq!(2, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 10);
    }

    #[test]
    fn starting_with_migration() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.add_migration(regular_migration(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);
    }

    #[test]
    /// This test case ensures that when multiple migrations are played and one fails:
    /// * previous migrations are ok and the database version is updated
    /// * further migrations are not played.
    fn test_failing_migration() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);
        db_checker.add_migration(regular_migration(1, CREATE_TABLE_SQL_REQUEST));
        // Table `wrong` does not exist: this migration fails with an error.
        db_checker.add_migration(regular_migration(
            2,
            "alter table wrong add column thing_content text; update whatever set thing_content = 'some content'",
        ));
        db_checker.add_migration(regular_migration(3, ALTER_TABLE_SQL_REQUEST));

        db_checker.apply().unwrap_err();
        check_database_version(&connection, 1);
    }

    #[test]
    fn test_fail_downgrading() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);
        db_checker.add_migration(regular_migration(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        check_database_version(&connection, 1);

        // Re-instantiate a checker with no migration registered (software version 0).
        let db_checker = create_db_checker(&connection);
        assert!(
            db_checker.apply().is_err(),
            "using an old version with an up to date database should fail"
        );
        check_database_version(&connection, 1);
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_no_fallback_distribution_version() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = regular_migration(3, CREATE_TABLE_SQL_REQUEST);

        db_checker
            .check_minimum_required_version(1, &migration)
            .expect(
                "Check minimum required version should not fail when no fallback distribution version",
            );
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_fallback_distribution_version_with_fresh_database(
    ) {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = squashed_migration(2, CREATE_TABLE_SQL_REQUEST);

        db_checker
            .check_minimum_required_version(0, &migration)
            .expect("Check minimum required version should not fail with fresh database");
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_no_gap_between_db_version_and_migration_version(
    ) {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = squashed_migration(2, CREATE_TABLE_SQL_REQUEST);

        db_checker
            .check_minimum_required_version(1, &migration)
            .expect("Check minimum required version should not fail when no gap between db version and migration version");
    }

    #[test]
    fn check_minimum_required_version_fails_when_gap_between_db_version_and_migration_version() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = squashed_migration(3, CREATE_TABLE_SQL_REQUEST);

        let error = db_checker
            .check_minimum_required_version(1, &migration)
            .expect_err("Check minimum required version should fail when gap between db version and migration version");

        assert!(error.to_string().contains(EXPECTED_FALLBACK_INSTALL_COMMAND));
    }

    #[test]
    fn apply_fails_when_trying_to_apply_squashed_migration_on_old_database() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.add_migration(regular_migration(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        check_database_version(&connection, 1);

        db_checker.add_migration(squashed_migration(3, ALTER_TABLE_SQL_REQUEST));

        let error = db_checker
            .apply()
            .expect_err("Should fail when applying squashed migration on old database");

        assert!(error.to_string().contains(EXPECTED_FALLBACK_INSTALL_COMMAND));
    }
}