1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
use anyhow::Context;
use sqlite::{ReadableWithIndex, Value};

use mithril_common::StdResult;

use crate::sqlite::{EntityCursor, Query, SqliteConnection, Transaction};

/// Extension trait for the [SqliteConnection] type.
///
/// Implementors provide `begin_transaction`, `query_single_cell` and `fetch`; the
/// `fetch_first` and `fetch_collect` helpers are derived from `fetch`.
pub trait ConnectionExtensions {
    /// Begin a transaction on the connection.
    fn begin_transaction(&self) -> StdResult<Transaction>;

    /// Execute the given sql query, binding `params`, and return the value of the first cell
    /// read, converted to `T`.
    fn query_single_cell<Q: AsRef<str>, T: ReadableWithIndex>(
        &self,
        sql: Q,
        params: &[Value],
    ) -> StdResult<T>;

    /// Fetch entities from the database using the given query, as a lazy cursor.
    fn fetch<Q: Query>(&self, query: Q) -> StdResult<EntityCursor<Q::Entity>>;

    /// Fetch only the first entity returned by the given query.
    ///
    /// Yields `Ok(None)` when the cursor produces no entity; errors from [Self::fetch] are
    /// propagated unchanged.
    fn fetch_first<Q: Query>(&self, query: Q) -> StdResult<Option<Q::Entity>> {
        self.fetch(query).map(|mut entities| entities.next())
    }

    /// Fetch every entity returned by the given query and gather them into any collection `B`
    /// that can be built from an iterator of entities (e.g. a `Vec`).
    fn fetch_collect<Q: Query, B: FromIterator<Q::Entity>>(&self, query: Q) -> StdResult<B> {
        let entities = self.fetch(query)?;
        Ok(entities.collect())
    }
}

impl ConnectionExtensions for SqliteConnection {
    fn begin_transaction(&self) -> StdResult<Transaction> {
        Ok(Transaction::begin(self)?)
    }

    /// Prepare `sql`, bind `params`, step once and read column `0` as `T`.
    ///
    /// Each failing stage (prepare, bind, step, read) is reported with the SQL text, flattened
    /// to a single line, in its error context so the failing query can be identified.
    ///
    /// NOTE(review): the state returned by the single `next()` call is not inspected, so a
    /// query that yields no row is not rejected here; what `read` then produces depends on the
    /// `ReadableWithIndex` impl of `T`. Confirm callers only use queries that always return a
    /// row (e.g. aggregates such as `max(...)`).
    fn query_single_cell<Q: AsRef<str>, T: ReadableWithIndex>(
        &self,
        sql: Q,
        params: &[Value],
    ) -> StdResult<T> {
        let sql = sql.as_ref();
        let mut statement = prepare_statement(self, sql)?;
        statement.bind(params).with_context(|| {
            format!(
                "Bind query parameters error: SQL=`{}`",
                single_line_sql(sql)
            )
        })?;
        statement
            .next()
            .with_context(|| format!("Execute query error: SQL=`{}`", single_line_sql(sql)))?;
        statement
            .read::<T, _>(0)
            .with_context(|| format!("Read query error: SQL=`{}`", single_line_sql(sql)))
    }

    /// Build the SQL of `query` from its expanded filters, prepare it, bind the filter
    /// parameters and wrap the resulting statement cursor into an [EntityCursor].
    ///
    /// Prepare and bind failures carry the single-line SQL in their error context.
    fn fetch<Q: Query>(&self, query: Q) -> StdResult<EntityCursor<Q::Entity>> {
        let (condition, params) = query.filters().expand();
        let sql = query.get_definition(&condition);
        let cursor = prepare_statement(self, &sql)?
            .into_iter()
            .bind(&params[..])
            .with_context(|| {
                format!(
                    "Bind query parameters error: SQL=`{}`",
                    single_line_sql(&sql)
                )
            })?;

        Ok(EntityCursor::new(cursor))
    }
}

/// Flatten `sql` for inclusion in an error message: newlines become spaces and surrounding
/// whitespace is trimmed.
fn single_line_sql(sql: &str) -> String {
    sql.replace('\n', " ").trim().to_string()
}

/// Prepare `sql` on the given connection.
///
/// On failure the error is wrapped with a context that embeds the query text, with newlines
/// replaced by spaces and surrounding whitespace trimmed so it reads on a single line.
fn prepare_statement<'conn>(
    sqlite_connection: &'conn SqliteConnection,
    sql: &str,
) -> StdResult<sqlite::Statement<'conn>> {
    let prepared = sqlite_connection.prepare(sql);
    prepared.with_context(|| {
        let flattened = sql.replace('\n', " ");
        format!("Prepare query error: SQL=`{}`", flattened.trim())
    })
}

#[cfg(test)]
mod tests {
    use sqlite::Connection;

    use super::*;

    #[test]
    fn test_query_string() {
        let connection = Connection::open_thread_safe(":memory:").unwrap();

        let value = connection
            .query_single_cell::<_, String>("select 'test'", &[])
            .unwrap();

        assert_eq!("test", value);
    }

    #[test]
    fn test_query_max_number() {
        let connection = Connection::open_thread_safe(":memory:").unwrap();
        let sql = "select max(a) from (select 10 a union select 90 a union select 45 a)";

        let value = connection.query_single_cell::<_, i64>(sql, &[]).unwrap();

        assert_eq!(90, value);
    }

    #[test]
    fn test_query_with_params() {
        let connection = Connection::open_thread_safe(":memory:").unwrap();
        // Positional `?` parameters exclude both bounds, leaving only 45.
        let sql = "select max(a) from (select 10 a union select 45 a union select 90 a) \
                   where a > ? and a < ?";
        let params = [Value::Integer(10), Value::Integer(90)];

        let value = connection.query_single_cell::<_, i64>(sql, &params).unwrap();

        assert_eq!(45, value);
    }
}