// mithril_persistence/sqlite/connection_extensions.rs

1use anyhow::Context;
2use sqlite::{ReadableWithIndex, Value};
3
4use mithril_common::StdResult;
5
6use crate::sqlite::{EntityCursor, Query, SqliteConnection, Transaction};
7
8/// Extension trait for the [SqliteConnection] type.
9pub trait ConnectionExtensions {
10    /// Begin a transaction on the connection.
11    fn begin_transaction(&self) -> StdResult<Transaction>;
12
13    /// Execute the given sql query and return the value of the first cell read.
14    fn query_single_cell<Q: AsRef<str>, T: ReadableWithIndex>(
15        &self,
16        sql: Q,
17        params: &[Value],
18    ) -> StdResult<T>;
19
20    /// Fetch entities from the database using the given query.
21    fn fetch<Q: Query>(&self, query: Q) -> StdResult<EntityCursor<Q::Entity>>;
22
23    /// Fetch the first entity from the database returned using the given query.
24    fn fetch_first<Q: Query>(&self, query: Q) -> StdResult<Option<Q::Entity>> {
25        let mut cursor = self.fetch(query)?;
26        Ok(cursor.next())
27    }
28
29    /// Fetch entities from the database using the given query and collect the result in a collection.
30    fn fetch_collect<Q: Query, B: FromIterator<Q::Entity>>(&self, query: Q) -> StdResult<B> {
31        Ok(self.fetch(query)?.collect::<B>())
32    }
33
34    /// Apply a query that do not return data from the database(ie: insert, delete, ...).
35    fn apply<Q: Query>(&self, query: Q) -> StdResult<()> {
36        self.fetch(query)?.count();
37        Ok(())
38    }
39}
40
41impl ConnectionExtensions for SqliteConnection {
42    fn begin_transaction(&self) -> StdResult<Transaction> {
43        Ok(Transaction::begin(self)?)
44    }
45
46    fn query_single_cell<Q: AsRef<str>, T: ReadableWithIndex>(
47        &self,
48        sql: Q,
49        params: &[Value],
50    ) -> StdResult<T> {
51        let mut statement = prepare_statement(self, sql.as_ref())?;
52        statement.bind(params)?;
53        statement.next()?;
54        statement.read::<T, _>(0).with_context(|| "Read query error")
55    }
56
57    fn fetch<Q: Query>(&self, query: Q) -> StdResult<EntityCursor<Q::Entity>> {
58        let (condition, params) = query.filters().expand();
59        let sql = query.get_definition(&condition);
60        let cursor = prepare_statement(self, &sql)?.into_iter().bind(&params[..])?;
61
62        let iterator = EntityCursor::new(cursor);
63
64        Ok(iterator)
65    }
66}
67
68fn prepare_statement<'conn>(
69    sqlite_connection: &'conn SqliteConnection,
70    sql: &str,
71) -> StdResult<sqlite::Statement<'conn>> {
72    sqlite_connection.prepare(sql).with_context(|| {
73        format!(
74            "Prepare query error: SQL=`{}`",
75            &sql.replace('\n', " ").trim()
76        )
77    })
78}
79
#[cfg(test)]
mod tests {
    use sqlite::Connection;

    use crate::sqlite::{HydrationError, SqLiteEntity, WhereCondition};

    use super::*;

    /// Open a fresh in-memory database for a single test.
    fn open_in_memory() -> SqliteConnection {
        Connection::open_thread_safe(":memory:").unwrap()
    }

    #[test]
    fn test_query_string() {
        let connection = open_in_memory();

        let text: String = connection.query_single_cell("select 'test'", &[]).unwrap();

        assert_eq!(text, "test");
    }

    #[test]
    fn test_query_max_number() {
        let connection = open_in_memory();
        let sql = "select max(a) from (select 10 a union select 90 a union select 45 a)";

        let max: i64 = connection.query_single_cell(sql, &[]).unwrap();

        assert_eq!(max, 90);
    }

    #[test]
    fn test_query_with_params() {
        let connection = open_in_memory();
        let sql = "select max(a) from (select 10 a union select 45 a union select 90 a) \
                   where a > ? and a < ?";
        let exclusive_bounds = [Value::Integer(10), Value::Integer(90)];

        let max_within_bounds: i64 = connection
            .query_single_cell(sql, &exclusive_bounds)
            .unwrap();

        assert_eq!(max_within_bounds, 45);
    }

    #[test]
    fn test_apply_execute_the_query() {
        // Entity that is never hydrated: the query under test produces no rows.
        struct DummySqLiteEntity {}

        impl SqLiteEntity for DummySqLiteEntity {
            fn hydrate(_row: sqlite::Row) -> Result<Self, HydrationError>
            where
                Self: Sized,
            {
                unimplemented!()
            }

            fn get_projection() -> crate::sqlite::Projection {
                unimplemented!()
            }
        }

        // Query with a fixed SQL text and no filter.
        struct FakeQuery {
            sql: String,
        }

        impl Query for FakeQuery {
            type Entity = DummySqLiteEntity;

            fn filters(&self) -> WhereCondition {
                WhereCondition::default()
            }

            fn get_definition(&self, _condition: &str) -> String {
                self.sql.clone()
            }
        }

        let count_rows = |conn: &SqliteConnection| -> i64 {
            conn.query_single_cell("select count(*) from query_test", &[])
                .unwrap()
        };

        let connection = open_in_memory();
        connection.execute("create table query_test(text_data);").unwrap();
        assert_eq!(count_rows(&connection), 0);

        let insert_one_row = FakeQuery {
            sql: "insert into query_test(text_data) values ('row 1')".to_string(),
        };
        connection.apply(insert_one_row).unwrap();

        assert_eq!(count_rows(&connection), 1);
    }
}