1use anyhow::{Context, anyhow};
2use chrono::Utc;
3use slog::{Logger, debug, error, info};
4use std::{cmp::Ordering, collections::BTreeSet};
5
6use mithril_common::{StdError, StdResult, logging::LoggerExtensions};
7
8use super::{
9 ApplicationNodeType, DatabaseVersion, DbVersion, GetDatabaseVersionQuery,
10 UpdateDatabaseVersionQuery,
11};
12
13use crate::sqlite::{ConnectionExtensions, SqliteConnection};
14
15pub struct DatabaseVersionChecker<'conn> {
17 connection: &'conn SqliteConnection,
19
20 application_type: ApplicationNodeType,
22
23 logger: Logger,
25
26 migrations: BTreeSet<SqlMigration>,
28}
29
30impl<'conn> DatabaseVersionChecker<'conn> {
31 pub fn new(
33 logger: Logger,
34 application_type: ApplicationNodeType,
35 connection: &'conn SqliteConnection,
36 ) -> Self {
37 let migrations = BTreeSet::new();
38
39 Self {
40 connection,
41 application_type,
42 logger: logger.new_with_component_name::<Self>(),
43 migrations,
44 }
45 }
46
47 pub fn add_migration(&mut self, migration: SqlMigration) -> &mut Self {
49 let _ = self.migrations.insert(migration);
50
51 self
52 }
53
54 pub fn apply(&self) -> StdResult<()> {
56 debug!(&self.logger, "Check database version",);
57 self.create_table_if_not_exists(&self.application_type)
58 .with_context(|| "Can not create table 'db_version' while applying migrations")?;
59 let db_version = self
60 .connection
61 .fetch_first(GetDatabaseVersionQuery::get_application_version(
62 &self.application_type,
63 ))
64 .with_context(|| "Can not get application version while applying migrations")?
65 .unwrap(); let migration_version = self.migrations.iter().map(|m| m.version).max().unwrap_or(0);
71
72 match migration_version.cmp(&db_version.version) {
73 Ordering::Greater => {
74 debug!(
75 &self.logger,
76 "Database needs upgrade from version '{}' to version '{}', applying new migrations…",
77 db_version.version,
78 migration_version
79 );
80 self.apply_migrations(&db_version, self.connection)?;
81 info!(
82 &self.logger,
83 "Database upgraded to version '{migration_version}'"
84 );
85 }
86 Ordering::Less => {
87 error!(
88 &self.logger,
89 "Software version '{}' is older than database structure version '{}'.",
90 db_version.version,
91 migration_version,
92 );
93
94 Err(anyhow!(
95 "This software version is older than the database structure. Aborting launch to prevent possible data corruption."
96 ))?;
97 }
98 Ordering::Equal => {
99 debug!(&self.logger, "database up to date");
100 }
101 };
102
103 Ok(())
104 }
105
106 fn apply_migrations(
107 &self,
108 starting_version: &DatabaseVersion,
109 connection: &SqliteConnection,
110 ) -> StdResult<()> {
111 for migration in &self
112 .migrations
113 .iter()
114 .filter(|&m| m.version > starting_version.version)
115 .collect::<Vec<&SqlMigration>>()
116 {
117 self.check_minimum_required_version(starting_version.version, migration)?;
118 connection.execute(&migration.alterations)?;
119 let db_version = DatabaseVersion {
120 version: migration.version,
121 application_type: self.application_type.clone(),
122 updated_at: Utc::now(),
123 };
124 let _ = connection
125 .fetch_first(UpdateDatabaseVersionQuery::one(db_version))
126 .with_context(|| {
127 format!(
128 "Can not save database version when applying migration: '{}'",
129 migration.version
130 )
131 })?;
132 }
133
134 Ok(())
135 }
136
137 pub fn create_table_if_not_exists(
140 &self,
141 application_type: &ApplicationNodeType,
142 ) -> StdResult<()> {
143 let connection = self.connection;
144 let table_exists = connection.query_single_cell::<_, i64>(
145 "select exists(select name from sqlite_master where type='table' and name='db_version') as table_exists",
146 &[],
147 )? == 1;
148
149 if !table_exists {
150 let sql = format!("
151create table db_version (application_type text not null primary key, version integer not null, updated_at text not null);
152insert into db_version (application_type, version, updated_at) values ('{application_type}', 0, '{}');
153", Utc::now().to_rfc3339());
154 connection.execute(sql)?;
155 }
156
157 Ok(())
158 }
159
160 fn check_minimum_required_version(
165 &self,
166 db_version: DbVersion,
167 migration: &SqlMigration,
168 ) -> StdResult<()> {
169 if db_version == 0 {
170 return Ok(());
171 }
172
173 if let Some(fallback_distribution_version) = &migration.fallback_distribution_version {
174 let min_required_version = migration.version - 1;
175 if db_version < min_required_version {
176 return Err(self.generate_fallback_migration_error(
177 migration.version,
178 fallback_distribution_version,
179 ));
180 }
181 }
182
183 Ok(())
184 }
185
186 fn generate_fallback_migration_error(
187 &self,
188 migration_version: i64,
189 fallback_distribution_version: &str,
190 ) -> StdError {
191 anyhow!(
192 r#"
193 Minimum required database version is not met to apply migration '{}'.
194 Please migrate your {} node database with the minimum node version compatible available in the distribution: '{}'.
195
196 First, download the required node version in your current directory by running the following command:
197 curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/input-output-hk/mithril/refs/heads/main/mithril-install.sh | sh -s -- -c mithril-{} -d {} -p $(pwd)
198
199 Then run the database migrate command:
200 mithril-{} database migrate --stores-directory /path/to/stores-directory
201 "#,
202 migration_version,
203 self.application_type.to_string(),
204 fallback_distribution_version,
205 self.application_type.to_string(),
206 fallback_distribution_version,
207 self.application_type.to_string()
208 )
209 }
210}
211
212#[derive(Debug, Clone)]
214pub struct SqlMigration {
215 pub version: DbVersion,
217
218 pub alterations: String,
220
221 pub fallback_distribution_version: Option<String>,
223}
224
225impl SqlMigration {
226 pub fn new<T: Into<String>>(version: DbVersion, alteration: T) -> Self {
228 Self {
229 version,
230 alterations: alteration.into(),
231 fallback_distribution_version: None,
232 }
233 }
234
235 pub fn new_squashed<T: Into<String>>(
237 version: DbVersion,
238 fallback_distribution_version: T,
239 alteration: T,
240 ) -> Self {
241 Self {
242 version,
243 alterations: alteration.into(),
244 fallback_distribution_version: Some(fallback_distribution_version.into()),
245 }
246 }
247}
248
// Delegates to `Ord` so that `PartialOrd` and `Ord` always agree (migrations are ordered
// by version only).
impl PartialOrd for SqlMigration {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
254
255impl Ord for SqlMigration {
256 fn cmp(&self, other: &Self) -> Ordering {
257 self.version.cmp(&other.version)
258 }
259}
260
261impl PartialEq for SqlMigration {
262 fn eq(&self, other: &Self) -> bool {
263 self.version.eq(&other.version)
264 }
265}
266
// Equality compares versions only, which is a total equivalence relation, so `Eq` holds.
impl Eq for SqlMigration {}
268
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use mithril_common::test_utils::TempDir;
    use mithril_common::{StdResult, current_function};
    use sqlite::{Connection, ConnectionThreadSafe};
    use std::path::PathBuf;

    use super::*;

    /// Creates the single-column `whatever` table and inserts four rows.
    const CREATE_TABLE_SQL_REQUEST: &str = "create table whatever (thing_id integer); insert into whatever (thing_id) values (1), (2), (3), (4);";
    /// Adds a second column to the `whatever` table and fills it.
    const ALTER_TABLE_SQL_REQUEST: &str = "alter table whatever add column thing_content text; update whatever set thing_content = 'some content'";

    fn discard_logger() -> Logger {
        Logger::root(slog::Discard, slog::o!())
    }

    /// Assert that the version stored for the aggregator in `db_version` equals `db_version`.
    fn check_database_version(connection: &SqliteConnection, db_version: DbVersion) {
        let version = connection
            .fetch_first(GetDatabaseVersionQuery::get_application_version(
                &ApplicationNodeType::Aggregator,
            ))
            .unwrap()
            .unwrap();

        assert_eq!(db_version, version.version);
    }

    fn create_sqlite_file(test_name: &str) -> StdResult<(PathBuf, SqliteConnection)> {
        let dirpath = TempDir::create("mithril_test_database", test_name);
        let filepath = dirpath.join("db.sqlite3");

        let connection = Connection::open_thread_safe(&filepath)
            .with_context(|| "connection to sqlite file failure")?;

        Ok((filepath, connection))
    }

    /// Number of columns of the `whatever` table (0 while the table does not exist).
    fn get_table_whatever_column_count(connection: &SqliteConnection) -> i64 {
        let sql = "select count(*) as column_count from pragma_table_info('whatever');";

        connection
            .prepare(sql)
            .unwrap()
            .iter()
            .next()
            .unwrap()
            .unwrap()
            .read::<i64, _>(0)
    }

    fn create_db_checker(connection: &ConnectionThreadSafe) -> DatabaseVersionChecker<'_> {
        DatabaseVersionChecker::new(
            discard_logger(),
            ApplicationNodeType::Aggregator,
            connection,
        )
    }

    #[test]
    fn test_upgrade_with_migration() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.apply().unwrap();
        assert_eq!(0, get_table_whatever_column_count(&connection));

        db_checker.apply().unwrap();
        assert_eq!(0, get_table_whatever_column_count(&connection));

        let migration = SqlMigration {
            version: 1,
            alterations: CREATE_TABLE_SQL_REQUEST.to_string(),
            fallback_distribution_version: None,
        };
        db_checker.add_migration(migration);
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);

        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);

        let migration = SqlMigration {
            version: 2,
            alterations: ALTER_TABLE_SQL_REQUEST.to_string(),
            fallback_distribution_version: None,
        };
        db_checker.add_migration(migration);
        db_checker.apply().unwrap();
        assert_eq!(2, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 2);

        // Migration 4 is registered before migration 3 but reads the column added by
        // migration 3: this only succeeds if migrations are applied in version order.
        let alterations = "alter table whatever add column one_last_thing text; update whatever set one_last_thing = more_thing";
        let migration = SqlMigration {
            version: 4,
            alterations: alterations.to_string(),
            fallback_distribution_version: None,
        };
        db_checker.add_migration(migration);
        let alterations = "alter table whatever add column more_thing text; update whatever set more_thing = 'more thing'";
        let migration = SqlMigration {
            version: 3,
            alterations: alterations.to_string(),
            fallback_distribution_version: None,
        };
        db_checker.add_migration(migration);
        db_checker.apply().unwrap();
        assert_eq!(4, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 4);
    }

    #[test]
    fn test_upgrade_with_migration_with_a_version_gap() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        let migration = SqlMigration {
            version: 3,
            alterations: CREATE_TABLE_SQL_REQUEST.to_string(),
            fallback_distribution_version: None,
        };
        db_checker.add_migration(migration);
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 3);

        let migration_with_version_gap = SqlMigration {
            version: 10,
            alterations: ALTER_TABLE_SQL_REQUEST.to_string(),
            fallback_distribution_version: None,
        };
        db_checker.add_migration(migration_with_version_gap);
        db_checker.apply().unwrap();
        assert_eq!(2, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 10);
    }

    #[test]
    fn starting_with_migration() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        let migration = SqlMigration {
            version: 1,
            alterations: CREATE_TABLE_SQL_REQUEST.to_string(),
            fallback_distribution_version: None,
        };
        db_checker.add_migration(migration);
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);
    }

    #[test]
    fn test_failing_migration() {
        // Migration 2 targets a table that does not exist: applying must fail and leave the
        // database at the version of the last successful migration (1).
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);
        let migration = SqlMigration {
            version: 1,
            alterations: CREATE_TABLE_SQL_REQUEST.to_string(),
            fallback_distribution_version: None,
        };
        db_checker.add_migration(migration);
        let alterations = "alter table wrong add column thing_content text; update whatever set thing_content = 'some content'";
        let migration = SqlMigration {
            version: 2,
            alterations: alterations.to_string(),
            fallback_distribution_version: None,
        };
        db_checker.add_migration(migration);
        let migration = SqlMigration {
            version: 3,
            alterations: ALTER_TABLE_SQL_REQUEST.to_string(),
            fallback_distribution_version: None,
        };
        db_checker.add_migration(migration);
        db_checker.apply().unwrap_err();
        check_database_version(&connection, 1);
    }

    #[test]
    fn test_fail_downgrading() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);
        let migration = SqlMigration {
            version: 1,
            alterations: CREATE_TABLE_SQL_REQUEST.to_string(),
            fallback_distribution_version: None,
        };
        db_checker.add_migration(migration);
        db_checker.apply().unwrap();
        check_database_version(&connection, 1);

        // A checker without any migration simulates an older software version.
        let db_checker = create_db_checker(&connection);
        assert!(
            db_checker.apply().is_err(),
            "using an old version with an up to date database should fail"
        );
        check_database_version(&connection, 1);
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_no_fallback_distribution_version() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = SqlMigration {
            version: 3,
            alterations: CREATE_TABLE_SQL_REQUEST.to_string(),
            fallback_distribution_version: None,
        };

        db_checker.check_minimum_required_version(1, &migration).expect(
            "Check minimum required version should not fail when no fallback distribution version",
        );
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_fallback_distribution_version_with_fresh_database()
    {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = SqlMigration {
            version: 2,
            alterations: CREATE_TABLE_SQL_REQUEST.to_string(),
            fallback_distribution_version: Some("2511.0".to_string()),
        };

        db_checker
            .check_minimum_required_version(0, &migration)
            .expect("Check minimum required version should not fail with fresh database");
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_no_gap_between_db_version_and_migration_version()
    {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = SqlMigration {
            version: 2,
            alterations: CREATE_TABLE_SQL_REQUEST.to_string(),
            fallback_distribution_version: Some("2511.0".to_string()),
        };

        db_checker
            .check_minimum_required_version(1, &migration)
            .expect("Check minimum required version should not fail when no gap between db version and migration version");
    }

    #[test]
    fn check_minimum_required_version_fails_when_gap_between_db_version_and_migration_version() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = SqlMigration {
            version: 3,
            alterations: CREATE_TABLE_SQL_REQUEST.to_string(),
            fallback_distribution_version: Some("2511.0".to_string()),
        };

        let error = db_checker
            .check_minimum_required_version(1, &migration)
            .expect_err("Check minimum required version should fail when gap between db version and migration version");

        assert!(error.to_string().contains("curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/input-output-hk/mithril/refs/heads/main/mithril-install.sh | sh -s -- -c mithril-aggregator -d 2511.0 -p $(pwd)"));
    }

    #[test]
    fn apply_fails_when_trying_to_apply_squashed_migration_on_old_database() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        let migration = SqlMigration {
            version: 1,
            alterations: CREATE_TABLE_SQL_REQUEST.to_string(),
            fallback_distribution_version: None,
        };
        db_checker.add_migration(migration);
        db_checker.apply().unwrap();
        check_database_version(&connection, 1);

        let squashed_migration = SqlMigration {
            version: 3,
            alterations: ALTER_TABLE_SQL_REQUEST.to_string(),
            fallback_distribution_version: Some("2511.0".to_string()),
        };
        db_checker.add_migration(squashed_migration);

        let error = db_checker
            .apply()
            .expect_err("Should fail when applying squashed migration on old database");

        assert!(error.to_string().contains("curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/input-output-hk/mithril/refs/heads/main/mithril-install.sh | sh -s -- -c mithril-aggregator -d 2511.0 -p $(pwd)"));
    }
}