mithril_persistence/sqlite/
connection_extensions.rs

1use anyhow::Context;
2use sqlite::{ReadableWithIndex, Value};
3
4use mithril_common::StdResult;
5
6use crate::sqlite::{EntityCursor, Query, SqliteConnection, Transaction};
7
8/// Extension trait for the [SqliteConnection] type.
9pub trait ConnectionExtensions {
10    /// Begin a transaction on the connection.
11    fn begin_transaction(&self) -> StdResult<Transaction>;
12
13    /// Execute the given sql query and return the value of the first cell read.
14    fn query_single_cell<Q: AsRef<str>, T: ReadableWithIndex>(
15        &self,
16        sql: Q,
17        params: &[Value],
18    ) -> StdResult<T>;
19
20    /// Fetch entities from the database using the given query.
21    fn fetch<Q: Query>(&self, query: Q) -> StdResult<EntityCursor<Q::Entity>>;
22
23    /// Fetch the first entity from the database returned using the given query.
24    fn fetch_first<Q: Query>(&self, query: Q) -> StdResult<Option<Q::Entity>> {
25        let mut cursor = self.fetch(query)?;
26        Ok(cursor.next())
27    }
28
29    /// Fetch entities from the database using the given query and collect the result in a collection.
30    fn fetch_collect<Q: Query, B: FromIterator<Q::Entity>>(&self, query: Q) -> StdResult<B> {
31        Ok(self.fetch(query)?.collect::<B>())
32    }
33
34    /// Apply a query that do not return data from the database(ie: insert, delete, ...).
35    fn apply<Q: Query>(&self, query: Q) -> StdResult<()> {
36        self.fetch(query)?.count();
37        Ok(())
38    }
39}
40
41impl ConnectionExtensions for SqliteConnection {
42    fn begin_transaction(&self) -> StdResult<Transaction> {
43        Ok(Transaction::begin(self)?)
44    }
45
46    fn query_single_cell<Q: AsRef<str>, T: ReadableWithIndex>(
47        &self,
48        sql: Q,
49        params: &[Value],
50    ) -> StdResult<T> {
51        let mut statement = prepare_statement(self, sql.as_ref())?;
52        statement.bind(params)?;
53        statement.next()?;
54        statement
55            .read::<T, _>(0)
56            .with_context(|| "Read query error")
57    }
58
59    fn fetch<Q: Query>(&self, query: Q) -> StdResult<EntityCursor<Q::Entity>> {
60        let (condition, params) = query.filters().expand();
61        let sql = query.get_definition(&condition);
62        let cursor = prepare_statement(self, &sql)?
63            .into_iter()
64            .bind(&params[..])?;
65
66        let iterator = EntityCursor::new(cursor);
67
68        Ok(iterator)
69    }
70}
71
72fn prepare_statement<'conn>(
73    sqlite_connection: &'conn SqliteConnection,
74    sql: &str,
75) -> StdResult<sqlite::Statement<'conn>> {
76    sqlite_connection.prepare(sql).with_context(|| {
77        format!(
78            "Prepare query error: SQL=`{}`",
79            &sql.replace('\n', " ").trim()
80        )
81    })
82}
83
#[cfg(test)]
mod tests {
    use sqlite::Connection;

    use crate::sqlite::{HydrationError, SqLiteEntity, WhereCondition};

    use super::*;

    #[test]
    fn test_query_string() {
        let connection = Connection::open_thread_safe(":memory:").unwrap();

        let text: String = connection.query_single_cell("select 'test'", &[]).unwrap();

        assert_eq!(text, "test");
    }

    #[test]
    fn test_query_max_number() {
        let connection = Connection::open_thread_safe(":memory:").unwrap();
        let sql = "select max(a) from (select 10 a union select 90 a union select 45 a)";

        let maximum: i64 = connection.query_single_cell(sql, &[]).unwrap();

        assert_eq!(maximum, 90);
    }

    #[test]
    fn test_query_with_params() {
        let connection = Connection::open_thread_safe(":memory:").unwrap();
        let sql = "select max(a) from (select 10 a union select 45 a union select 90 a) \
            where a > ? and a < ?";
        // Exclusive bounds: only 45 lies strictly between them.
        let bounds = [Value::Integer(10), Value::Integer(90)];

        let maximum: i64 = connection.query_single_cell(sql, &bounds).unwrap();

        assert_eq!(maximum, 45);
    }

    #[test]
    fn test_apply_execute_the_query() {
        // Entity that is never hydrated: the applied query is an insert yielding no rows.
        struct DummySqLiteEntity {}

        impl SqLiteEntity for DummySqLiteEntity {
            fn hydrate(_row: sqlite::Row) -> Result<Self, HydrationError>
            where
                Self: Sized,
            {
                unimplemented!()
            }

            fn get_projection() -> crate::sqlite::Projection {
                unimplemented!()
            }
        }

        // Query whose definition is a fixed SQL string, ignoring any condition.
        struct FakeQuery {
            sql: String,
        }

        impl Query for FakeQuery {
            type Entity = DummySqLiteEntity;

            fn filters(&self) -> WhereCondition {
                WhereCondition::default()
            }

            fn get_definition(&self, _condition: &str) -> String {
                self.sql.clone()
            }
        }

        let connection = Connection::open_thread_safe(":memory:").unwrap();
        connection
            .execute("create table query_test(text_data);")
            .unwrap();
        let count_rows = || -> i64 {
            connection
                .query_single_cell("select count(*) from query_test", &[])
                .unwrap()
        };

        assert_eq!(count_rows(), 0);

        connection
            .apply(FakeQuery {
                sql: "insert into query_test(text_data) values ('row 1')".to_string(),
            })
            .unwrap();

        assert_eq!(count_rows(), 1);
    }
}