mithril_persistence/database/
version_checker.rs

1use anyhow::{Context, anyhow};
2use chrono::Utc;
3use slog::{Logger, debug, error, info};
4use std::{cmp::Ordering, collections::BTreeSet};
5
6use mithril_common::{StdError, StdResult, logging::LoggerExtensions};
7
8use super::{
9    ApplicationNodeType, DatabaseVersion, DbVersion, GetDatabaseVersionQuery,
10    UpdateDatabaseVersionQuery,
11};
12
13use crate::sqlite::{ConnectionExtensions, SqliteConnection};
14
15/// Struct to perform application version check in the database.
16pub struct DatabaseVersionChecker<'conn> {
17    /// Pathbuf to the SQLite3 file.
18    connection: &'conn SqliteConnection,
19
20    /// Application type which vesion is verified.
21    application_type: ApplicationNodeType,
22
23    /// logger
24    logger: Logger,
25
26    /// known migrations
27    migrations: BTreeSet<SqlMigration>,
28}
29
30impl<'conn> DatabaseVersionChecker<'conn> {
31    /// constructor
32    pub fn new(
33        logger: Logger,
34        application_type: ApplicationNodeType,
35        connection: &'conn SqliteConnection,
36    ) -> Self {
37        let migrations = BTreeSet::new();
38
39        Self {
40            connection,
41            application_type,
42            logger: logger.new_with_component_name::<Self>(),
43            migrations,
44        }
45    }
46
47    /// Register a migration.
48    pub fn add_migration(&mut self, migration: SqlMigration) -> &mut Self {
49        let _ = self.migrations.insert(migration);
50
51        self
52    }
53
54    /// Apply migrations
55    pub fn apply(&self) -> StdResult<()> {
56        debug!(&self.logger, "Check database version",);
57        self.create_table_if_not_exists(&self.application_type)
58            .with_context(|| "Can not create table 'db_version' while applying migrations")?;
59        let db_version = self
60            .connection
61            .fetch_first(GetDatabaseVersionQuery::get_application_version(
62                &self.application_type,
63            ))
64            .with_context(|| "Can not get application version while applying migrations")?
65            .unwrap(); // At least a record exists.
66
67        // the current database version is equal to the maximum migration
68        // version present in this software.
69        // If no migration registered then version = 0.
70        let migration_version = self.migrations.iter().map(|m| m.version).max().unwrap_or(0);
71
72        match migration_version.cmp(&db_version.version) {
73            Ordering::Greater => {
74                debug!(
75                    &self.logger,
76                    "Database needs upgrade from version '{}' to version '{}', applying new migrations…",
77                    db_version.version,
78                    migration_version
79                );
80                self.apply_migrations(&db_version, self.connection)?;
81                info!(
82                    &self.logger,
83                    "Database upgraded to version '{migration_version}'"
84                );
85            }
86            Ordering::Less => {
87                error!(
88                    &self.logger,
89                    "Software version '{}' is older than database structure version '{}'.",
90                    db_version.version,
91                    migration_version,
92                );
93
94                Err(anyhow!(
95                    "This software version is older than the database structure. Aborting launch to prevent possible data corruption."
96                ))?;
97            }
98            Ordering::Equal => {
99                debug!(&self.logger, "database up to date");
100            }
101        };
102
103        Ok(())
104    }
105
106    fn apply_migrations(
107        &self,
108        starting_version: &DatabaseVersion,
109        connection: &SqliteConnection,
110    ) -> StdResult<()> {
111        for migration in &self
112            .migrations
113            .iter()
114            .filter(|&m| m.version > starting_version.version)
115            .collect::<Vec<&SqlMigration>>()
116        {
117            self.check_minimum_required_version(starting_version.version, migration)?;
118            connection.execute(&migration.alterations)?;
119            let db_version = DatabaseVersion {
120                version: migration.version,
121                application_type: self.application_type.clone(),
122                updated_at: Utc::now(),
123            };
124            let _ = connection
125                .fetch_first(UpdateDatabaseVersionQuery::one(db_version))
126                .with_context(|| {
127                    format!(
128                        "Can not save database version when applying migration: '{}'",
129                        migration.version
130                    )
131                })?;
132        }
133
134        Ok(())
135    }
136
137    /// Method to create the table at the beginning of the migration procedure.
138    /// This code is temporary and should not last.
139    pub fn create_table_if_not_exists(
140        &self,
141        application_type: &ApplicationNodeType,
142    ) -> StdResult<()> {
143        let connection = self.connection;
144        let table_exists = connection.query_single_cell::<_, i64>(
145            "select exists(select name from sqlite_master where type='table' and name='db_version') as table_exists",
146            &[],
147        )? == 1;
148
149        if !table_exists {
150            let sql = format!("
151create table db_version (application_type text not null primary key, version integer not null, updated_at text not null);
152insert into db_version (application_type, version, updated_at) values ('{application_type}', 0, '{}');
153", Utc::now().to_rfc3339());
154            connection.execute(sql)?;
155        }
156
157        Ok(())
158    }
159
160    /// Checks if the database version meets the minimum required version to apply a squashed migration.
161    /// If the database version 0 or if the migration doesn't specify a fallback distribution version, the check passes.
162    /// For migrations with a fallback distribution version, the check passes if the database version is exactly
163    /// one less than the migration version (i.e., there's no gap between them).
164    fn check_minimum_required_version(
165        &self,
166        db_version: DbVersion,
167        migration: &SqlMigration,
168    ) -> StdResult<()> {
169        if db_version == 0 {
170            return Ok(());
171        }
172
173        if let Some(fallback_distribution_version) = &migration.fallback_distribution_version {
174            let min_required_version = migration.version - 1;
175            if db_version < min_required_version {
176                return Err(self.generate_fallback_migration_error(
177                    migration.version,
178                    fallback_distribution_version,
179                ));
180            }
181        }
182
183        Ok(())
184    }
185
186    fn generate_fallback_migration_error(
187        &self,
188        migration_version: i64,
189        fallback_distribution_version: &str,
190    ) -> StdError {
191        anyhow!(
192            r#"
193                Minimum required database version is not met to apply migration '{}'.
194                Please migrate your {} node database with the minimum node version compatible available in the distribution: '{}'.
195
196                First, download the required node version in your current directory by running the following command:
197                curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/input-output-hk/mithril/refs/heads/main/mithril-install.sh | sh -s -- -c mithril-{} -d {} -p $(pwd)
198
199                Then run the database migrate command:
200                mithril-{} database migrate --stores-directory /path/to/stores-directory
201            "#,
202            migration_version,
203            self.application_type.to_string(),
204            fallback_distribution_version,
205            self.application_type.to_string(),
206            fallback_distribution_version,
207            self.application_type.to_string()
208        )
209    }
210}
211
212/// Represent a file containing SQL structure or data alterations.
213#[derive(Debug, Clone)]
214pub struct SqlMigration {
215    /// The semver version this migration targets.
216    pub version: DbVersion,
217
218    /// SQL statements to alter the database.
219    pub alterations: String,
220
221    /// The distribution version the user can fallback to in order to update their database before updating to the latest node.
222    pub fallback_distribution_version: Option<String>,
223}
224
225impl SqlMigration {
226    /// Create a new SQL migration instance.
227    pub fn new<T: Into<String>>(version: DbVersion, alteration: T) -> Self {
228        Self {
229            version,
230            alterations: alteration.into(),
231            fallback_distribution_version: None,
232        }
233    }
234
235    /// Create a new squashed SQL migration instance with the fallback distribution version.
236    pub fn new_squashed<T: Into<String>>(
237        version: DbVersion,
238        fallback_distribution_version: T,
239        alteration: T,
240    ) -> Self {
241        Self {
242            version,
243            alterations: alteration.into(),
244            fallback_distribution_version: Some(fallback_distribution_version.into()),
245        }
246    }
247}
248
249impl PartialOrd for SqlMigration {
250    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
251        Some(self.cmp(other))
252    }
253}
254
255impl Ord for SqlMigration {
256    fn cmp(&self, other: &Self) -> Ordering {
257        self.version.cmp(&other.version)
258    }
259}
260
261impl PartialEq for SqlMigration {
262    fn eq(&self, other: &Self) -> bool {
263        self.version.eq(&other.version)
264    }
265}
266
267impl Eq for SqlMigration {}
268
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use mithril_common::test_utils::TempDir;
    use mithril_common::{StdResult, current_function};
    use sqlite::{Connection, ConnectionThreadSafe};
    use std::path::PathBuf;

    use super::*;

    /// Creates table `whatever` with a single column and four rows.
    const CREATE_TABLE_SQL_REQUEST: &str = "create table whatever (thing_id integer); insert into whatever (thing_id) values (1), (2), (3), (4);";
    /// Adds a second column to table `whatever`.
    const ALTER_TABLE_SQL_REQUEST: &str = "alter table whatever add column thing_content text; update whatever set thing_content = 'some content'";
    /// Install command expected in the error raised when a squashed migration (fallback
    /// distribution `2511.0`) can not be applied on an aggregator database.
    const EXPECTED_FALLBACK_INSTALL_COMMAND: &str = "curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/input-output-hk/mithril/refs/heads/main/mithril-install.sh | sh -s -- -c mithril-aggregator -d 2511.0 -p $(pwd)";

    fn discard_logger() -> Logger {
        Logger::root(slog::Discard, slog::o!())
    }

    /// Assert that the aggregator version stored in the database is `db_version`.
    fn check_database_version(connection: &SqliteConnection, db_version: DbVersion) {
        let version = connection
            .fetch_first(GetDatabaseVersionQuery::get_application_version(
                &ApplicationNodeType::Aggregator,
            ))
            .unwrap()
            .expect("a database version record should exist for the aggregator");

        assert_eq!(db_version, version.version);
    }

    /// Open a SQLite file located in a temporary directory dedicated to `test_name`.
    fn create_sqlite_file(test_name: &str) -> StdResult<(PathBuf, SqliteConnection)> {
        let dirpath = TempDir::create("mithril_test_database", test_name);
        let filepath = dirpath.join("db.sqlite3");

        let connection = Connection::open_thread_safe(&filepath)
            .with_context(|| "connection to sqlite file failure")?;

        Ok((filepath, connection))
    }

    /// Number of columns of table `whatever` (0 when the table does not exist).
    fn get_table_whatever_column_count(connection: &SqliteConnection) -> i64 {
        let sql = "select count(*) as column_count from pragma_table_info('whatever');";

        connection
            .prepare(sql)
            .unwrap()
            .iter()
            .next()
            .unwrap()
            .unwrap()
            .read::<i64, _>(0)
    }

    fn create_db_checker(connection: &ConnectionThreadSafe) -> DatabaseVersionChecker<'_> {
        DatabaseVersionChecker::new(
            discard_logger(),
            ApplicationNodeType::Aggregator,
            connection,
        )
    }

    #[test]
    fn test_upgrade_with_migration() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.apply().unwrap();
        assert_eq!(0, get_table_whatever_column_count(&connection));

        db_checker.apply().unwrap();
        assert_eq!(0, get_table_whatever_column_count(&connection));

        db_checker.add_migration(SqlMigration::new(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);

        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);

        db_checker.add_migration(SqlMigration::new(2, ALTER_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(2, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 2);

        // Both migrations below are registered in reversed order to ensure they are played
        // in version order: version 4 uses the column added by version 3.
        db_checker.add_migration(SqlMigration::new(
            4,
            "alter table whatever add column one_last_thing text; update whatever set one_last_thing = more_thing",
        ));
        db_checker.add_migration(SqlMigration::new(
            3,
            "alter table whatever add column more_thing text; update whatever set more_thing = 'more thing'",
        ));
        db_checker.apply().unwrap();
        assert_eq!(4, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 4);
    }

    #[test]
    fn test_upgrade_with_migration_with_a_version_gap() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.add_migration(SqlMigration::new(3, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 3);

        let migration_with_version_gap = SqlMigration::new(10, ALTER_TABLE_SQL_REQUEST);
        db_checker.add_migration(migration_with_version_gap);
        db_checker.apply().unwrap();
        assert_eq!(2, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 10);
    }

    #[test]
    fn starting_with_migration() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.add_migration(SqlMigration::new(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);
    }

    #[test]
    /// This test case ensure that when multiple migrations are played and one fails:
    /// * previous migrations are ok and the database version is updated
    /// * further migrations are not played.
    fn test_failing_migration() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);
        db_checker.add_migration(SqlMigration::new(1, CREATE_TABLE_SQL_REQUEST));
        // Table `wrong` does not exist: this migration fails with an error.
        db_checker.add_migration(SqlMigration::new(
            2,
            "alter table wrong add column thing_content text; update whatever set thing_content = 'some content'",
        ));
        db_checker.add_migration(SqlMigration::new(3, ALTER_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap_err();
        check_database_version(&connection, 1);
        // Migration 3 (which would add a second column) must not have been played.
        assert_eq!(1, get_table_whatever_column_count(&connection));
    }

    #[test]
    fn test_fail_downgrading() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);
        db_checker.add_migration(SqlMigration::new(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        check_database_version(&connection, 1);

        // re instantiate a new checker with no migration registered (version 0).
        let db_checker = create_db_checker(&connection);
        assert!(
            db_checker.apply().is_err(),
            "using an old version with an up to date database should fail"
        );
        check_database_version(&connection, 1);
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_no_fallback_distribution_version() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = SqlMigration::new(3, CREATE_TABLE_SQL_REQUEST);

        db_checker.check_minimum_required_version(1, &migration).expect(
            "Check minimum required version should not fail when no fallback distribution version",
        );
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_fallback_distribution_version_with_fresh_database()
     {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = SqlMigration::new_squashed(2, "2511.0", CREATE_TABLE_SQL_REQUEST);

        db_checker
            .check_minimum_required_version(0, &migration)
            .expect("Check minimum required version should not fail with fresh database");
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_no_gap_between_db_version_and_migration_version()
     {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = SqlMigration::new_squashed(2, "2511.0", CREATE_TABLE_SQL_REQUEST);

        db_checker
            .check_minimum_required_version(1, &migration)
            .expect("Check minimum required version should not fail when no gap between db version and migration version");
    }

    #[test]
    fn check_minimum_required_version_fails_when_gap_between_db_version_and_migration_version() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = SqlMigration::new_squashed(3, "2511.0", CREATE_TABLE_SQL_REQUEST);

        let error = db_checker
            .check_minimum_required_version(1, &migration)
            .expect_err("Check minimum required version should fail when gap between db version and migration version");

        assert!(error.to_string().contains(EXPECTED_FALLBACK_INSTALL_COMMAND));
    }

    #[test]
    fn apply_fails_when_trying_to_apply_squashed_migration_on_old_database() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.add_migration(SqlMigration::new(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        check_database_version(&connection, 1);

        let squashed_migration = SqlMigration::new_squashed(3, "2511.0", ALTER_TABLE_SQL_REQUEST);
        db_checker.add_migration(squashed_migration);

        let error = db_checker
            .apply()
            .expect_err("Should fail when applying squashed migration on old database");

        assert!(error.to_string().contains(EXPECTED_FALLBACK_INSTALL_COMMAND));
    }
}