1use anyhow::{Context, anyhow};
2use chrono::Utc;
3use slog::{Logger, debug, error, info};
4use std::{cmp::Ordering, collections::BTreeSet};
5
6use mithril_common::{StdError, StdResult, logging::LoggerExtensions};
7
8use super::{
9 ApplicationNodeType, DatabaseVersion, DbVersion, GetDatabaseVersionQuery,
10 UpdateDatabaseVersionQuery,
11};
12
13use crate::sqlite::{
14 ConnectionExtensions, OptimizeMode, SqliteCleaner, SqliteCleaningTask, SqliteConnection,
15};
16
17pub struct DatabaseVersionChecker<'conn> {
19 connection: &'conn SqliteConnection,
21
22 application_type: ApplicationNodeType,
24
25 logger: Logger,
27
28 migrations: BTreeSet<SqlMigration>,
30}
31
32impl<'conn> DatabaseVersionChecker<'conn> {
33 pub fn new(
35 logger: Logger,
36 application_type: ApplicationNodeType,
37 connection: &'conn SqliteConnection,
38 ) -> Self {
39 let migrations = BTreeSet::new();
40
41 Self {
42 connection,
43 application_type,
44 logger: logger.new_with_component_name::<Self>(),
45 migrations,
46 }
47 }
48
49 pub fn add_migration(&mut self, migration: SqlMigration) -> &mut Self {
51 let _ = self.migrations.insert(migration);
52
53 self
54 }
55
56 pub fn apply(&self) -> StdResult<()> {
58 debug!(&self.logger, "Check database version",);
59 self.create_table_if_not_exists(&self.application_type)
60 .with_context(|| "Can not create table 'db_version' while applying migrations")?;
61 let db_version = self
62 .connection
63 .fetch_first(GetDatabaseVersionQuery::get_application_version(
64 &self.application_type,
65 ))
66 .with_context(|| "Can not get application version while applying migrations")?
67 .unwrap(); let migration_version = self.migrations.iter().map(|m| m.version).max().unwrap_or(0);
73
74 match migration_version.cmp(&db_version.version) {
75 Ordering::Greater => {
76 debug!(
77 &self.logger,
78 "Database needs upgrade from version '{}' to version '{}', applying new migrations…",
79 db_version.version,
80 migration_version
81 );
82 self.apply_migrations(&db_version, self.connection)?;
83 info!(
84 &self.logger,
85 "Database upgraded to version '{migration_version}'"
86 );
87
88 SqliteCleaner::new(self.connection)
90 .with_logger(self.logger.clone())
91 .with_tasks(&[SqliteCleaningTask::Optimize(OptimizeMode::Default)])
92 .run()
93 .with_context(|| "Database optimization error")?;
94 }
95 Ordering::Less => {
96 error!(
97 &self.logger,
98 "Software version '{}' is older than database structure version '{}'.",
99 db_version.version,
100 migration_version,
101 );
102
103 Err(anyhow!(
104 "This software version is older than the database structure. Aborting launch to prevent possible data corruption."
105 ))?;
106 }
107 Ordering::Equal => {
108 debug!(&self.logger, "database up to date");
109 }
110 };
111
112 Ok(())
113 }
114
115 fn apply_migrations(
116 &self,
117 starting_version: &DatabaseVersion,
118 connection: &SqliteConnection,
119 ) -> StdResult<()> {
120 for migration in &self
121 .migrations
122 .iter()
123 .filter(|&m| m.version > starting_version.version)
124 .collect::<Vec<&SqlMigration>>()
125 {
126 self.check_minimum_required_version(starting_version.version, migration)?;
127 connection.execute(&migration.alterations)?;
128 let db_version = DatabaseVersion {
129 version: migration.version,
130 application_type: self.application_type.clone(),
131 updated_at: Utc::now(),
132 };
133 let _ = connection
134 .fetch_first(UpdateDatabaseVersionQuery::one(db_version))
135 .with_context(|| {
136 format!(
137 "Can not save database version when applying migration: '{}'",
138 migration.version
139 )
140 })?;
141 }
142
143 Ok(())
144 }
145
146 pub fn create_table_if_not_exists(
149 &self,
150 application_type: &ApplicationNodeType,
151 ) -> StdResult<()> {
152 let connection = self.connection;
153 let table_exists = connection.query_single_cell::<_, i64>(
154 "select exists(select name from sqlite_master where type='table' and name='db_version') as table_exists",
155 &[],
156 )? == 1;
157
158 if !table_exists {
159 let sql = format!("
160create table db_version (application_type text not null primary key, version integer not null, updated_at text not null);
161insert into db_version (application_type, version, updated_at) values ('{application_type}', 0, '{}');
162", Utc::now().to_rfc3339());
163 connection.execute(sql)?;
164 }
165
166 Ok(())
167 }
168
169 fn check_minimum_required_version(
174 &self,
175 db_version: DbVersion,
176 migration: &SqlMigration,
177 ) -> StdResult<()> {
178 if db_version == 0 {
179 return Ok(());
180 }
181
182 if let Some(fallback_distribution_version) = &migration.fallback_distribution_version {
183 let min_required_version = migration.version - 1;
184 if db_version < min_required_version {
185 return Err(self.generate_fallback_migration_error(
186 migration.version,
187 fallback_distribution_version,
188 ));
189 }
190 }
191
192 Ok(())
193 }
194
195 fn generate_fallback_migration_error(
196 &self,
197 migration_version: i64,
198 fallback_distribution_version: &str,
199 ) -> StdError {
200 anyhow!(
201 r#"
202 Minimum required database version is not met to apply migration '{}'.
203 Please migrate your {} node database with the minimum node version compatible available in the distribution: '{}'.
204
205 First, download the required node version in your current directory by running the following command:
206 curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/input-output-hk/mithril/refs/heads/main/mithril-install.sh | sh -s -- -c mithril-{} -d {} -p $(pwd)
207
208 Then run the database migrate command:
209 mithril-{} database migrate --stores-directory /path/to/stores-directory
210 "#,
211 migration_version,
212 self.application_type,
213 fallback_distribution_version,
214 self.application_type,
215 fallback_distribution_version,
216 self.application_type
217 )
218 }
219}
220
221#[derive(Debug, Clone)]
223pub struct SqlMigration {
224 pub version: DbVersion,
226
227 pub alterations: String,
229
230 pub fallback_distribution_version: Option<String>,
232}
233
234impl SqlMigration {
235 pub fn new<T: Into<String>>(version: DbVersion, alteration: T) -> Self {
237 Self {
238 version,
239 alterations: alteration.into(),
240 fallback_distribution_version: None,
241 }
242 }
243
244 pub fn new_squashed<T: Into<String>>(
246 version: DbVersion,
247 fallback_distribution_version: T,
248 alteration: T,
249 ) -> Self {
250 Self {
251 version,
252 alterations: alteration.into(),
253 fallback_distribution_version: Some(fallback_distribution_version.into()),
254 }
255 }
256}
257
258impl PartialOrd for SqlMigration {
259 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
260 Some(self.cmp(other))
261 }
262}
263
264impl Ord for SqlMigration {
265 fn cmp(&self, other: &Self) -> Ordering {
266 self.version.cmp(&other.version)
267 }
268}
269
270impl PartialEq for SqlMigration {
271 fn eq(&self, other: &Self) -> bool {
272 self.version.eq(&other.version)
273 }
274}
275
276impl Eq for SqlMigration {}
277
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use mithril_common::test::TempDir;
    use mithril_common::{StdResult, current_function};
    use sqlite::{Connection, ConnectionThreadSafe};
    use std::path::PathBuf;

    use super::*;

    /// Creates the `whatever` table, with a single column, and inserts four rows in it.
    const CREATE_TABLE_SQL_REQUEST: &str = "create table whatever (thing_id integer); insert into whatever (thing_id) values (1), (2), (3), (4);";
    /// Adds a second column to the `whatever` table and fills it.
    const ALTER_TABLE_SQL_REQUEST: &str = "alter table whatever add column thing_content text; update whatever set thing_content = 'some content'";
    /// Fallback distribution version used by the squashed migrations of these tests.
    const FALLBACK_DISTRIBUTION_VERSION: &str = "2511.0";
    /// Install command expected in the error raised when a squashed migration's minimum
    /// required version is not met (aggregator node, fallback distribution `2511.0`).
    const EXPECTED_FALLBACK_INSTALL_COMMAND: &str = "curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/input-output-hk/mithril/refs/heads/main/mithril-install.sh | sh -s -- -c mithril-aggregator -d 2511.0 -p $(pwd)";

    fn discard_logger() -> Logger {
        Logger::root(slog::Discard, slog::o!())
    }

    /// Build a regular (non-squashed) migration.
    fn regular_migration(version: DbVersion, alterations: &str) -> SqlMigration {
        SqlMigration {
            version,
            alterations: alterations.to_string(),
            fallback_distribution_version: None,
        }
    }

    /// Build a squashed migration falling back to [`FALLBACK_DISTRIBUTION_VERSION`].
    fn squashed_migration(version: DbVersion, alterations: &str) -> SqlMigration {
        SqlMigration {
            version,
            alterations: alterations.to_string(),
            fallback_distribution_version: Some(FALLBACK_DISTRIBUTION_VERSION.to_string()),
        }
    }

    /// Assert that the version stored for the aggregator is `db_version`.
    fn check_database_version(connection: &SqliteConnection, db_version: DbVersion) {
        let version = connection
            .fetch_first(GetDatabaseVersionQuery::get_application_version(
                &ApplicationNodeType::Aggregator,
            ))
            .unwrap()
            .unwrap();

        assert_eq!(db_version, version.version);
    }

    fn create_sqlite_file(test_name: &str) -> StdResult<(PathBuf, SqliteConnection)> {
        let dirpath = TempDir::create("mithril_test_database", test_name);
        let filepath = dirpath.join("db.sqlite3");

        let connection = Connection::open_thread_safe(&filepath)
            .with_context(|| "connection to sqlite file failure")?;

        Ok((filepath, connection))
    }

    /// Number of columns of the `whatever` table (0 while the table does not exist).
    fn get_table_whatever_column_count(connection: &SqliteConnection) -> i64 {
        let sql = "select count(*) as column_count from pragma_table_info('whatever');";

        connection
            .prepare(sql)
            .unwrap()
            .iter()
            .next()
            .unwrap()
            .unwrap()
            .read::<i64, _>(0)
    }

    fn create_db_checker(connection: &ConnectionThreadSafe) -> DatabaseVersionChecker<'_> {
        DatabaseVersionChecker::new(
            discard_logger(),
            ApplicationNodeType::Aggregator,
            connection,
        )
    }

    #[test]
    fn test_upgrade_with_migration() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.apply().unwrap();
        assert_eq!(0, get_table_whatever_column_count(&connection));

        db_checker.apply().unwrap();
        assert_eq!(0, get_table_whatever_column_count(&connection));

        db_checker.add_migration(regular_migration(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);

        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);

        db_checker.add_migration(regular_migration(2, ALTER_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(2, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 2);

        // Migration 4 is registered before migration 3 but reads the `more_thing` column that
        // migration 3 creates: this only succeeds if migrations are applied in version order.
        db_checker.add_migration(regular_migration(
            4,
            "alter table whatever add column one_last_thing text; update whatever set one_last_thing = more_thing",
        ));
        db_checker.add_migration(regular_migration(
            3,
            "alter table whatever add column more_thing text; update whatever set more_thing = 'more thing'",
        ));
        db_checker.apply().unwrap();
        assert_eq!(4, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 4);
    }

    #[test]
    fn test_upgrade_with_migration_with_a_version_gap() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.add_migration(regular_migration(3, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 3);

        db_checker.add_migration(regular_migration(10, ALTER_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(2, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 10);
    }

    #[test]
    fn starting_with_migration() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.add_migration(regular_migration(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        assert_eq!(1, get_table_whatever_column_count(&connection));
        check_database_version(&connection, 1);
    }

    #[test]
    fn test_failing_migration() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);
        db_checker.add_migration(regular_migration(1, CREATE_TABLE_SQL_REQUEST));
        // Migration 2 alters a table that does not exist: it fails and stops the upgrade, so
        // the database must stay at the version reached by migration 1.
        db_checker.add_migration(regular_migration(
            2,
            "alter table wrong add column thing_content text; update whatever set thing_content = 'some content'",
        ));
        db_checker.add_migration(regular_migration(3, ALTER_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap_err();
        check_database_version(&connection, 1);
    }

    #[test]
    fn test_fail_downgrading() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);
        db_checker.add_migration(regular_migration(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        check_database_version(&connection, 1);

        // A checker without any migration plays the role of an older software version.
        let db_checker = create_db_checker(&connection);
        assert!(
            db_checker.apply().is_err(),
            "using an old version with an up to date database should fail"
        );
        check_database_version(&connection, 1);
    }

    #[test]
    fn create_table_if_not_exists_is_idempotent() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        db_checker
            .create_table_if_not_exists(&ApplicationNodeType::Aggregator)
            .unwrap();
        db_checker
            .create_table_if_not_exists(&ApplicationNodeType::Aggregator)
            .unwrap();

        check_database_version(&connection, 0);
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_no_fallback_distribution_version() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = regular_migration(3, CREATE_TABLE_SQL_REQUEST);

        db_checker.check_minimum_required_version(1, &migration).expect(
            "Check minimum required version should not fail when no fallback distribution version",
        );
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_fallback_distribution_version_with_fresh_database()
    {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = squashed_migration(2, CREATE_TABLE_SQL_REQUEST);

        db_checker
            .check_minimum_required_version(0, &migration)
            .expect("Check minimum required version should not fail with fresh database");
    }

    #[test]
    fn check_minimum_required_version_does_not_fail_when_no_gap_between_db_version_and_migration_version()
    {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = squashed_migration(2, CREATE_TABLE_SQL_REQUEST);

        db_checker
            .check_minimum_required_version(1, &migration)
            .expect("Check minimum required version should not fail when no gap between db version and migration version");
    }

    #[test]
    fn check_minimum_required_version_fails_when_gap_between_db_version_and_migration_version() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let db_checker = create_db_checker(&connection);

        let migration = squashed_migration(3, CREATE_TABLE_SQL_REQUEST);

        let error = db_checker
            .check_minimum_required_version(1, &migration)
            .expect_err("Check minimum required version should fail when gap between db version and migration version");

        assert!(error.to_string().contains(EXPECTED_FALLBACK_INSTALL_COMMAND));
    }

    #[test]
    fn apply_fails_when_trying_to_apply_squashed_migration_on_old_database() {
        let (_filepath, connection) = create_sqlite_file(current_function!()).unwrap();
        let mut db_checker = create_db_checker(&connection);

        db_checker.add_migration(regular_migration(1, CREATE_TABLE_SQL_REQUEST));
        db_checker.apply().unwrap();
        check_database_version(&connection, 1);

        db_checker.add_migration(squashed_migration(3, ALTER_TABLE_SQL_REQUEST));

        let error = db_checker
            .apply()
            .expect_err("Should fail when applying squashed migration on old database");

        assert!(error.to_string().contains(EXPECTED_FALLBACK_INSTALL_COMMAND));
    }
}