use std::ops::Not;
use std::path::{Path, PathBuf};
use anyhow::Context;
use slog::{debug, Logger};
use sqlite::{Connection, ConnectionThreadSafe};
use mithril_common::logging::LoggerExtensions;
use mithril_common::StdResult;
use crate::database::{ApplicationNodeType, DatabaseVersionChecker, SqlMigration};
/// Builder for a thread-safe SQLite connection.
///
/// Collects the database location, the migrations to apply, the connection options,
/// the node type and the logger; `build` then opens and prepares the connection.
pub struct ConnectionBuilder {
    /// Path of the database file (`:memory:` for an in-memory database).
    connection_path: PathBuf,
    /// Migrations applied when the connection is built.
    sql_migrations: Vec<SqlMigration>,
    /// Options selecting which pragmas are set when the connection is built.
    options: Vec<ConnectionOptions>,
    /// Node type handed to the database version checker when migrating.
    node_type: ApplicationNodeType,
    /// Logger from which the component loggers are derived.
    base_logger: Logger,
}
/// Options applied to a SQLite connection by [`ConnectionBuilder`].
// Variant declaration order defines the derived `Ord`; do not reorder variants casually.
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)]
pub enum ConnectionOptions {
    /// Set `journal_mode = wal` and `synchronous = normal` when the connection is opened.
    EnableWriteAheadLog,
    /// Set `foreign_keys = true` before the migrations are applied.
    EnableForeignKeys,
    /// Set `foreign_keys = false` after the migrations are applied, overriding both
    /// `EnableForeignKeys` and any migration that enables foreign keys.
    ForceDisableForeignKeys,
}
impl ConnectionBuilder {
    /// Start building a connection to the SQLite database stored at `path`.
    ///
    /// Defaults: no migrations, no connection options, [`ApplicationNodeType::Signer`] as
    /// node type, and a logger that discards every record.
    pub fn open_file(path: &Path) -> Self {
        Self {
            connection_path: path.to_path_buf(),
            sql_migrations: vec![],
            options: vec![],
            node_type: ApplicationNodeType::Signer,
            base_logger: Logger::root(slog::Discard, slog::o!()),
        }
    }

    /// Start building a connection to an in-memory SQLite database.
    ///
    /// Uses SQLite's special `:memory:` filename; other defaults are those of
    /// [`Self::open_file`].
    pub fn open_memory() -> Self {
        Self::open_file(":memory:".as_ref())
    }

    /// Set the migrations applied when the connection is built, replacing any previous ones.
    pub fn with_migrations(mut self, migrations: Vec<SqlMigration>) -> Self {
        self.sql_migrations = migrations;
        self
    }

    /// Append `options` to the options already configured on this builder.
    pub fn with_options(mut self, options: &[ConnectionOptions]) -> Self {
        self.options.extend_from_slice(options);
        self
    }

    /// Set the logger from which this builder's component loggers are derived.
    pub fn with_logger(mut self, logger: Logger) -> Self {
        self.base_logger = logger;
        self
    }

    /// Set the node type handed to the database version checker when migrating.
    pub fn with_node_type(mut self, node_type: ApplicationNodeType) -> Self {
        self.node_type = node_type;
        self
    }

    /// Open the connection, apply the configured pragmas and migrations, and return it.
    ///
    /// Steps, in order:
    /// 1. open the database at the configured path;
    /// 2. with [`ConnectionOptions::EnableWriteAheadLog`]: `journal_mode = wal` and
    ///    `synchronous = normal`;
    /// 3. with [`ConnectionOptions::EnableForeignKeys`]: `foreign_keys = true`;
    /// 4. apply the configured migrations (see [`Self::apply_migrations`]);
    /// 5. with [`ConnectionOptions::ForceDisableForeignKeys`]: `foreign_keys = false`. This
    ///    runs after the migrations so it wins even over a migration enabling foreign keys.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection cannot be opened, if a pragma fails, or if a
    /// migration fails.
    pub fn build(mut self) -> StdResult<ConnectionThreadSafe> {
        let logger = self.base_logger.new_with_component_name::<Self>();

        debug!(logger, "Opening SQLite connection"; "path" => self.connection_path.display());
        let connection =
            Connection::open_thread_safe(&self.connection_path).with_context(|| {
                format!(
                    "SQLite initialization: could not open connection with string '{}'.",
                    self.connection_path.display()
                )
            })?;

        if self
            .options
            .contains(&ConnectionOptions::EnableWriteAheadLog)
        {
            debug!(logger, "Enabling SQLite Write Ahead Log journal mode");
            connection
                .execute("pragma journal_mode = wal; pragma synchronous = normal;")
                .with_context(|| "SQLite initialization: could not enable WAL.")?;
        }

        if self.options.contains(&ConnectionOptions::EnableForeignKeys) {
            debug!(logger, "Enabling SQLite foreign key support");
            connection
                .execute("pragma foreign_keys=true")
                .with_context(|| "SQLite initialization: could not enable FOREIGN KEY support.")?;
        }

        // The builder is consumed by `build`, so move the migrations out instead of cloning them.
        let migrations = std::mem::take(&mut self.sql_migrations);
        self.apply_migrations(&connection, migrations)?;

        // Done after the migrations so that a migration enabling foreign keys cannot undo it.
        if self
            .options
            .contains(&ConnectionOptions::ForceDisableForeignKeys)
        {
            debug!(logger, "Force disabling SQLite foreign key support");
            connection
                .execute("pragma foreign_keys=false")
                .with_context(|| "SQLite initialization: could not disable FOREIGN KEY support.")?;
        }

        Ok(connection)
    }

    /// Apply `sql_migrations` to `connection` through a [`DatabaseVersionChecker`]
    /// configured with this builder's logger and node type.
    ///
    /// Does nothing when `sql_migrations` is empty.
    ///
    /// # Errors
    ///
    /// Returns an error, with a "Database migration error" context, if the version checker
    /// fails to apply the migrations.
    pub fn apply_migrations(
        &self,
        connection: &ConnectionThreadSafe,
        sql_migrations: Vec<SqlMigration>,
    ) -> StdResult<()> {
        let logger = self.base_logger.new_with_component_name::<Self>();

        if sql_migrations.is_empty().not() {
            debug!(logger, "Applying database migrations");
            let mut db_checker = DatabaseVersionChecker::new(
                self.base_logger.clone(),
                self.node_type.clone(),
                connection,
            );

            // `sql_migrations` is owned, so each migration can be moved in without cloning.
            for migration in sql_migrations {
                db_checker.add_migration(migration);
            }

            db_checker
                .apply()
                .with_context(|| "Database migration error")?;
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use sqlite::Value;

    use mithril_common::test_utils::TempDir;

    use super::*;

    /// Journal mode reported for a file database when WAL has not been enabled.
    const DEFAULT_SQLITE_JOURNAL_MODE: &str = "delete";
    /// Value reported by `pragma synchronous` for the `NORMAL` setting.
    const NORMAL_SYNCHRONOUS_FLAG: i64 = 1;

    /// Run `query` and return the first column of its first row.
    fn execute_single_cell_query(connection: &Connection, query: &str) -> Value {
        let mut statement = connection.prepare(query).unwrap();
        let mut row = statement.iter().next().unwrap().unwrap();

        row.take(0)
    }

    #[test]
    fn test_open_in_memory_without_foreign_key() {
        let connection = ConnectionBuilder::open_memory().build().unwrap();

        let journal_mode = execute_single_cell_query(&connection, "pragma journal_mode;");
        let foreign_keys = execute_single_cell_query(&connection, "pragma foreign_keys;");

        assert_eq!(Value::String("memory".to_string()), journal_mode);
        assert_eq!(Value::Integer(false.into()), foreign_keys);
    }

    #[test]
    fn test_open_with_foreign_key() {
        let connection = ConnectionBuilder::open_memory()
            .with_options(&[ConnectionOptions::EnableForeignKeys])
            .build()
            .unwrap();

        let journal_mode = execute_single_cell_query(&connection, "pragma journal_mode;");
        let foreign_keys = execute_single_cell_query(&connection, "pragma foreign_keys;");

        assert_eq!(Value::String("memory".to_string()), journal_mode);
        assert_eq!(Value::Integer(true.into()), foreign_keys);
    }

    #[test]
    fn test_open_file_without_wal_and_foreign_keys() {
        let dirpath = TempDir::create(
            "mithril_test_database",
            "test_open_file_without_wal_and_foreign_keys",
        );
        let filepath = dirpath.join("db.sqlite3");
        assert!(!filepath.exists());

        let connection = ConnectionBuilder::open_file(&filepath).build().unwrap();

        let journal_mode = execute_single_cell_query(&connection, "pragma journal_mode;");
        let foreign_keys = execute_single_cell_query(&connection, "pragma foreign_keys;");

        assert!(filepath.exists());
        assert_eq!(
            Value::String(DEFAULT_SQLITE_JOURNAL_MODE.to_string()),
            journal_mode
        );
        assert_eq!(Value::Integer(false.into()), foreign_keys);
    }

    #[test]
    fn test_open_file_with_wal_and_foreign_keys() {
        let dirpath = TempDir::create(
            "mithril_test_database",
            "test_open_file_with_wal_and_foreign_keys",
        );
        let filepath = dirpath.join("db.sqlite3");
        assert!(!filepath.exists());

        let connection = ConnectionBuilder::open_file(&filepath)
            .with_options(&[
                ConnectionOptions::EnableForeignKeys,
                ConnectionOptions::EnableWriteAheadLog,
            ])
            .build()
            .unwrap();

        let journal_mode = execute_single_cell_query(&connection, "pragma journal_mode;");
        let foreign_keys = execute_single_cell_query(&connection, "pragma foreign_keys;");

        assert!(filepath.exists());
        assert_eq!(Value::String("wal".to_string()), journal_mode);
        assert_eq!(Value::Integer(true.into()), foreign_keys);
    }

    #[test]
    fn enabling_wal_option_also_set_synchronous_flag_to_normal() {
        let dirpath = TempDir::create(
            "mithril_test_database",
            "enabling_wal_option_also_set_synchronous_flag_to_normal",
        );

        let connection = ConnectionBuilder::open_file(&dirpath.join("db.sqlite3"))
            .with_options(&[ConnectionOptions::EnableWriteAheadLog])
            .build()
            .unwrap();

        let synchronous_flag = execute_single_cell_query(&connection, "pragma synchronous;");

        assert_eq!(Value::Integer(NORMAL_SYNCHRONOUS_FLAG), synchronous_flag);
    }

    #[test]
    fn builder_apply_given_migrations() {
        let connection = ConnectionBuilder::open_memory()
            .with_migrations(vec![
                SqlMigration::new(1, "create table first(id integer);"),
                SqlMigration::new(2, "create table second(id integer);"),
            ])
            .build()
            .unwrap();

        // `group_concat` does not guarantee the order of the values it concatenates (an outer
        // `ORDER BY` only sorts the single aggregate row), so read one ordered row per table.
        let mut statement = connection
            .prepare(
                "SELECT name FROM sqlite_schema \
                WHERE type = 'table' AND name NOT LIKE 'sqlite_%' AND name != 'db_version' \
                ORDER BY name;",
            )
            .unwrap();
        let tables_list: Vec<Value> = statement.iter().map(|row| row.unwrap().take(0)).collect();

        assert_eq!(
            vec![
                Value::String("first".to_string()),
                Value::String("second".to_string()),
            ],
            tables_list
        );
    }

    #[test]
    fn can_disable_foreign_keys_even_if_a_migration_enable_them() {
        let connection = ConnectionBuilder::open_memory()
            .with_migrations(vec![SqlMigration::new(1, "pragma foreign_keys=true;")])
            .with_options(&[ConnectionOptions::ForceDisableForeignKeys])
            .build()
            .unwrap();

        let foreign_keys = execute_single_cell_query(&connection, "pragma foreign_keys;");

        assert_eq!(Value::Integer(false.into()), foreign_keys);
    }

    #[test]
    fn test_apply_a_partial_migrations() {
        let migrations = vec![
            SqlMigration::new(1, "create table first(id integer);"),
            SqlMigration::new(2, "create table second(id integer);"),
        ];

        let connection = ConnectionBuilder::open_memory().build().unwrap();

        assert!(connection.prepare("select * from first;").is_err());
        assert!(connection.prepare("select * from second;").is_err());

        ConnectionBuilder::open_memory()
            .apply_migrations(&connection, migrations[0..1].to_vec())
            .unwrap();

        assert!(connection.prepare("select * from first;").is_ok());
        assert!(connection.prepare("select * from second;").is_err());

        ConnectionBuilder::open_memory()
            .apply_migrations(&connection, migrations)
            .unwrap();

        assert!(connection.prepare("select * from first;").is_ok());
        assert!(connection.prepare("select * from second;").is_ok());
    }
}