use anyhow::{anyhow, Context};
use chrono::Utc;
use slog::{debug, error, info, warn, Logger};
use std::{cmp::Ordering, collections::BTreeSet};

use mithril_common::{logging::LoggerExtensions, StdError, StdResult};

use super::{
    ApplicationNodeType, DatabaseVersion, DbVersion, GetDatabaseVersionQuery,
    UpdateDatabaseVersionQuery,
};

use crate::sqlite::{ConnectionExtensions, SqliteConnection};
14
15pub struct DatabaseVersionChecker<'conn> {
17 connection: &'conn SqliteConnection,
19
20 application_type: ApplicationNodeType,
22
23 logger: Logger,
25
26 migrations: BTreeSet<SqlMigration>,
28}
29
30impl<'conn> DatabaseVersionChecker<'conn> {
31 pub fn new(
33 logger: Logger,
34 application_type: ApplicationNodeType,
35 connection: &'conn SqliteConnection,
36 ) -> Self {
37 let migrations = BTreeSet::new();
38
39 Self {
40 connection,
41 application_type,
42 logger: logger.new_with_component_name::<Self>(),
43 migrations,
44 }
45 }
46
47 pub fn add_migration(&mut self, migration: SqlMigration) -> &mut Self {
49 let _ = self.migrations.insert(migration);
50
51 self
52 }
53
54 pub fn apply(&self) -> StdResult<()> {
56 debug!(&self.logger, "Check database version",);
57 self.create_table_if_not_exists(&self.application_type)
58 .with_context(|| "Can not create table 'db_version' while applying migrations")?;
59 let db_version = self
60 .connection
61 .fetch_first(GetDatabaseVersionQuery::get_application_version(
62 &self.application_type,
63 ))
64 .with_context(|| "Can not get application version while applying migrations")?
65 .unwrap(); let migration_version = self.migrations.iter().map(|m| m.version).max().unwrap_or(0);
71
72 match migration_version.cmp(&db_version.version) {
73 Ordering::Greater => {
74 debug!(
75 &self.logger,
76 "Database needs upgrade from version '{}' to version '{}', applying new migrations…",
77 db_version.version, migration_version
78 );
79 self.apply_migrations(&db_version, self.connection)?;
80 info!(
81 &self.logger,
82 "Database upgraded to version '{migration_version}'"
83 );
84 }
85 Ordering::Less => {
86 error!(
87 &self.logger,
88 "Software version '{}' is older than database structure version '{}'.",
89 db_version.version,
90 migration_version,
91 );
92
93 Err(anyhow!("This software version is older than the database structure. Aborting launch to prevent possible data corruption."))?;
94 }
95 Ordering::Equal => {
96 debug!(&self.logger, "database up to date");
97 }
98 };
99
100 Ok(())
101 }
102
103 fn apply_migrations(
104 &self,
105 starting_version: &DatabaseVersion,
106 connection: &SqliteConnection,
107 ) -> StdResult<()> {
108 for migration in &self
109 .migrations
110 .iter()
111 .filter(|&m| m.version > starting_version.version)
112 .collect::<Vec<&SqlMigration>>()
113 {
114 self.check_minimum_required_version(starting_version.version, migration)?;
115 connection.execute(&migration.alterations)?;
116 let db_version = DatabaseVersion {
117 version: migration.version,
118 application_type: self.application_type.clone(),
119 updated_at: Utc::now(),
120 };
121 let _ = connection
122 .fetch_first(UpdateDatabaseVersionQuery::one(db_version))
123 .with_context(|| {
124 format!(
125 "Can not save database version when applying migration: '{}'",
126 migration.version
127 )
128 })?;
129 }
130
131 Ok(())
132 }
133
134 pub fn create_table_if_not_exists(
137 &self,
138 application_type: &ApplicationNodeType,
139 ) -> StdResult<()> {
140 let connection = self.connection;
141 let table_exists = connection.query_single_cell::<_, i64>(
142 "select exists(select name from sqlite_master where type='table' and name='db_version') as table_exists",
143 &[],
144 )? == 1;
145
146 if !table_exists {
147 let sql = format!("
148create table db_version (application_type text not null primary key, version integer not null, updated_at text not null);
149insert into db_version (application_type, version, updated_at) values ('{application_type}', 0, '{}');
150", Utc::now().to_rfc3339());
151 connection.execute(sql)?;
152 }
153
154 Ok(())
155 }
156
157 fn check_minimum_required_version(
162 &self,
163 db_version: DbVersion,
164 migration: &SqlMigration,
165 ) -> StdResult<()> {
166 if db_version == 0 {
167 return Ok(());
168 }
169
170 if let Some(fallback_distribution_version) = &migration.fallback_distribution_version {
171 let min_required_version = migration.version - 1;
172 if db_version < min_required_version {
173 return Err(self.generate_fallback_migration_error(
174 migration.version,
175 fallback_distribution_version,
176 ));
177 }
178 }
179
180 Ok(())
181 }
182
183 fn generate_fallback_migration_error(
184 &self,
185 migration_version: i64,
186 fallback_distribution_version: &str,
187 ) -> StdError {
188 anyhow!(
189 r#"
190 Minimum required database version is not met to apply migration '{}'.
191 Please migrate your {} node database with the minimum node version compatible available in the distribution: '{}'.
192
193 First, download the required node version in your current directory by running the following command:
194 curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/input-output-hk/mithril/refs/heads/main/mithril-install.sh | sh -s -- -c mithril-{} -d {} -p $(pwd)
195
196 Then run the database migrate command:
197 mithril-{} database migrate --stores-directory /path/to/stores-directory
198 "#,
199 migration_version,
200 self.application_type.to_string(),
201 fallback_distribution_version,
202 self.application_type.to_string(),
203 fallback_distribution_version,
204 self.application_type.to_string()
205 )
206 }
207}
208
209#[derive(Debug, Clone)]
211pub struct SqlMigration {
212 pub version: DbVersion,
214
215 pub alterations: String,
217
218 pub fallback_distribution_version: Option<String>,
220}
221
222impl SqlMigration {
223 pub fn new<T: Into<String>>(version: DbVersion, alteration: T) -> Self {
225 Self {
226 version,
227 alterations: alteration.into(),
228 fallback_distribution_version: None,
229 }
230 }
231
232 pub fn new_squashed<T: Into<String>>(
234 version: DbVersion,
235 fallback_distribution_version: T,
236 alteration: T,
237 ) -> Self {
238 Self {
239 version,
240 alterations: alteration.into(),
241 fallback_distribution_version: Some(fallback_distribution_version.into()),
242 }
243 }
244}
245
246impl PartialOrd for SqlMigration {
247 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
248 Some(self.cmp(other))
249 }
250}
251
252impl Ord for SqlMigration {
253 fn cmp(&self, other: &Self) -> Ordering {
254 self.version.cmp(&other.version)
255 }
256}
257
258impl PartialEq for SqlMigration {
259 fn eq(&self, other: &Self) -> bool {
260 self.version.eq(&other.version)
261 }
262}
263
264impl Eq for SqlMigration {}
265
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use mithril_common::test_utils::TempDir;
    use mithril_common::{current_function, StdResult};
    use sqlite::{Connection, ConnectionThreadSafe};
    use std::path::PathBuf;

    use super::*;

    const CREATE_TABLE_SQL_REQUEST: &str = "create table whatever (thing_id integer); insert into whatever (thing_id) values (1), (2), (3), (4);";
    const ALTER_TABLE_SQL_REQUEST: &str = "alter table whatever add column thing_content text; update whatever set thing_content = 'some content'";
    /// Fallback distribution version used by the squashed migrations of these tests.
    const FALLBACK_DISTRIBUTION_VERSION: &str = "2511.0";
    /// Install command expected in the error raised when a squashed migration can't be applied.
    const EXPECTED_INSTALL_COMMAND: &str = "curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/input-output-hk/mithril/refs/heads/main/mithril-install.sh | sh -s -- -c mithril-aggregator -d 2511.0 -p $(pwd)";

    fn discard_logger() -> Logger {
        Logger::root(slog::Discard, slog::o!())
    }

    /// Assert that the version recorded for the aggregator is `db_version`.
    fn check_database_version(connection: &SqliteConnection, db_version: DbVersion) {
        let version = connection
            .fetch_first(GetDatabaseVersionQuery::get_application_version(
                &ApplicationNodeType::Aggregator,
            ))
            .unwrap()
            .unwrap();

        assert_eq!(db_version, version.version);
    }

    /// Open a SQLite database stored in a temporary directory dedicated to `test_name`.
    fn create_sqlite_file(test_name: &str) -> StdResult<(PathBuf, SqliteConnection)> {
        let dirpath = TempDir::create("mithril_test_database", test_name);
        let filepath = dirpath.join("db.sqlite3");

        let connection = Connection::open_thread_safe(&filepath)
            .with_context(|| "connection to sqlite file failure")?;

        Ok((filepath, connection))
    }

    /// Number of columns of the `whatever` table (0 when the table does not exist).
    fn get_table_whatever_column_count(connection: &SqliteConnection) -> i64 {
        let sql = "select count(*) as column_count from pragma_table_info('whatever');";
        connection
            .prepare(sql)
            .unwrap()
            .iter()
            .next()
            .unwrap()
            .unwrap()
            .read::<i64, _>(0)
    }

    fn create_db_checker(connection: &ConnectionThreadSafe) -> DatabaseVersionChecker<'_> {
        DatabaseVersionChecker::new(
            discard_logger(),
            ApplicationNodeType::Aggregator,
            connection,
        )
    }

    #[test]
    fn test_upgrade_with_migration() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.apply().unwrap();
        assert_eq!(0, get_table_whatever_column_count(&connection));

        db_checker.apply().unwrap();
        assert_eq!(0, get_table_whatever_column_count(&connection));

        db_checker.add_migration(SqlMigration::new(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);

        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);

        db_checker.add_migration(SqlMigration::new(2, ALTER_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(2, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 2);

        // Registered out of order on purpose: version 4 reads the column added by version 3,
        // so this only succeeds if migrations are applied by ascending version.
        db_checker.add_migration(SqlMigration::new(
            4,
            "alter table whatever add column one_last_thing text; update whatever set one_last_thing = more_thing",
        ));
        db_checker.add_migration(SqlMigration::new(
            3,
            "alter table whatever add column more_thing text; update whatever set more_thing = 'more thing'",
        ));
        db_checker.apply().unwrap();
        assert_eq!(4, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 4);
    }

    #[test]
    fn test_upgrade_with_migration_with_a_version_gap() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.add_migration(SqlMigration::new(3, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 3);

        db_checker.add_migration(SqlMigration::new(10, ALTER_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(2, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 10);
    }

    #[test]
    fn starting_with_migration() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.add_migration(SqlMigration::new(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);
    }

    #[test]
    fn test_failing_migration() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);
        db_checker.add_migration(SqlMigration::new(1, CREATE_TABLE_SQL_REQUEST));
        // Targets a table that does not exist: applying it must fail.
        db_checker.add_migration(SqlMigration::new(
            2,
            "alter table wrong add column thing_content text; update whatever set thing_content = 'some content'",
        ));
        db_checker.add_migration(SqlMigration::new(3, ALTER_TABLE_SQL_REQUEST));

        db_checker.apply().unwrap_err();
        check_database_version(&connection, 1);
    }

    #[test]
    fn test_fail_downgrading() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);
        db_checker.add_migration(SqlMigration::new(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        check_database_version(&connection, 1);

        let db_checker = create_db_checker(&connection);
        assert!(
            db_checker.apply().is_err(),
            "using an old version with an up to date database should fail"
        );
        check_database_version(&connection, 1);
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_no_fallback_distribution_version() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = SqlMigration::new(3, CREATE_TABLE_SQL_REQUEST);

        db_checker
            .check_minimum_required_version(1, &migration)
            .expect(
                "Check minimum required version should not fail when no fallback distribution version",
            );
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_fallback_distribution_version_with_fresh_database(
    ) {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration =
            SqlMigration::new_squashed(2, FALLBACK_DISTRIBUTION_VERSION, CREATE_TABLE_SQL_REQUEST);

        db_checker
            .check_minimum_required_version(0, &migration)
            .expect("Check minimum required version should not fail with fresh database");
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_no_gap_between_db_version_and_migration_version(
    ) {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration =
            SqlMigration::new_squashed(2, FALLBACK_DISTRIBUTION_VERSION, CREATE_TABLE_SQL_REQUEST);

        db_checker
            .check_minimum_required_version(1, &migration)
            .expect("Check minimum required version should not fail when no gap between db version and migration version");
    }

    #[test]
    fn check_minimum_required_version_fails_when_gap_between_db_version_and_migration_version() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration =
            SqlMigration::new_squashed(3, FALLBACK_DISTRIBUTION_VERSION, CREATE_TABLE_SQL_REQUEST);

        let error = db_checker
            .check_minimum_required_version(1, &migration)
            .expect_err("Check minimum required version should fail when gap between db version and migration version");

        assert!(error.to_string().contains(EXPECTED_INSTALL_COMMAND));
    }

    #[test]
    fn apply_fails_when_trying_to_apply_squashed_migration_on_old_database() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.add_migration(SqlMigration::new(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        check_database_version(&connection, 1);

        db_checker.add_migration(SqlMigration::new_squashed(
            3,
            FALLBACK_DISTRIBUTION_VERSION,
            ALTER_TABLE_SQL_REQUEST,
        ));

        let error = db_checker
            .apply()
            .expect_err("Should fail when applying squashed migration on old database");

        assert!(error.to_string().contains(EXPECTED_INSTALL_COMMAND));
    }
}