// mithril_persistence/sqlite/connection_extensions.rs
use anyhow::Context;
use sqlite::{ReadableWithIndex, Value};

use mithril_common::StdResult;

use crate::sqlite::{EntityCursor, Query, SqliteConnection, Transaction};
7
8pub trait ConnectionExtensions {
10 fn begin_transaction(&self) -> StdResult<Transaction>;
12
13 fn query_single_cell<Q: AsRef<str>, T: ReadableWithIndex>(
15 &self,
16 sql: Q,
17 params: &[Value],
18 ) -> StdResult<T>;
19
20 fn fetch<Q: Query>(&self, query: Q) -> StdResult<EntityCursor<Q::Entity>>;
22
23 fn fetch_first<Q: Query>(&self, query: Q) -> StdResult<Option<Q::Entity>> {
25 let mut cursor = self.fetch(query)?;
26 Ok(cursor.next())
27 }
28
29 fn fetch_collect<Q: Query, B: FromIterator<Q::Entity>>(&self, query: Q) -> StdResult<B> {
31 Ok(self.fetch(query)?.collect::<B>())
32 }
33
34 fn apply<Q: Query>(&self, query: Q) -> StdResult<()> {
36 self.fetch(query)?.count();
37 Ok(())
38 }
39}
40
41impl ConnectionExtensions for SqliteConnection {
42 fn begin_transaction(&self) -> StdResult<Transaction> {
43 Ok(Transaction::begin(self)?)
44 }
45
46 fn query_single_cell<Q: AsRef<str>, T: ReadableWithIndex>(
47 &self,
48 sql: Q,
49 params: &[Value],
50 ) -> StdResult<T> {
51 let mut statement = prepare_statement(self, sql.as_ref())?;
52 statement.bind(params)?;
53 statement.next()?;
54 statement
55 .read::<T, _>(0)
56 .with_context(|| "Read query error")
57 }
58
59 fn fetch<Q: Query>(&self, query: Q) -> StdResult<EntityCursor<Q::Entity>> {
60 let (condition, params) = query.filters().expand();
61 let sql = query.get_definition(&condition);
62 let cursor = prepare_statement(self, &sql)?
63 .into_iter()
64 .bind(¶ms[..])?;
65
66 let iterator = EntityCursor::new(cursor);
67
68 Ok(iterator)
69 }
70}
71
72fn prepare_statement<'conn>(
73 sqlite_connection: &'conn SqliteConnection,
74 sql: &str,
75) -> StdResult<sqlite::Statement<'conn>> {
76 sqlite_connection.prepare(sql).with_context(|| {
77 format!(
78 "Prepare query error: SQL=`{}`",
79 &sql.replace('\n', " ").trim()
80 )
81 })
82}
83
#[cfg(test)]
mod tests {
    use sqlite::Connection;

    use crate::sqlite::{HydrationError, SqLiteEntity, WhereCondition};

    use super::*;

    #[test]
    fn test_query_string() {
        let conn = Connection::open_thread_safe(":memory:").unwrap();

        let text: String = conn.query_single_cell("select 'test'", &[]).unwrap();

        assert_eq!(text, "test");
    }

    #[test]
    fn test_query_max_number() {
        let conn = Connection::open_thread_safe(":memory:").unwrap();
        let sql = "select max(a) from (select 10 a union select 90 a union select 45 a)";

        let max: i64 = conn.query_single_cell(sql, &[]).unwrap();

        assert_eq!(max, 90);
    }

    #[test]
    fn test_query_with_params() {
        let conn = Connection::open_thread_safe(":memory:").unwrap();
        let sql = "select max(a) from (select 10 a union select 45 a union select 90 a) \
                   where a > ? and a < ?";
        // Exclusive bounds: only 45 lies strictly between 10 and 90.
        let bounds = [Value::Integer(10), Value::Integer(90)];

        let max_in_bounds: i64 = conn.query_single_cell(sql, &bounds).unwrap();

        assert_eq!(max_in_bounds, 45);
    }

    #[test]
    fn test_apply_execute_the_query() {
        /// Entity that is never hydrated, because the applied insert returns no rows.
        struct DummySqLiteEntity {}

        impl SqLiteEntity for DummySqLiteEntity {
            fn hydrate(_row: sqlite::Row) -> Result<Self, HydrationError>
            where
                Self: Sized,
            {
                unimplemented!()
            }

            fn get_projection() -> crate::sqlite::Projection {
                unimplemented!()
            }
        }

        /// Query that ignores filtering and always returns its fixed SQL text.
        struct FakeQuery {
            sql: String,
        }

        impl Query for FakeQuery {
            type Entity = DummySqLiteEntity;

            fn filters(&self) -> WhereCondition {
                WhereCondition::default()
            }

            fn get_definition(&self, _condition: &str) -> String {
                self.sql.clone()
            }
        }

        let conn = Connection::open_thread_safe(":memory:").unwrap();
        conn.execute("create table query_test(text_data);")
            .unwrap();
        let count_rows = || -> i64 {
            conn.query_single_cell("select count(*) from query_test", &[])
                .unwrap()
        };

        assert_eq!(count_rows(), 0);

        let insert_one_row = FakeQuery {
            sql: "insert into query_test(text_data) values ('row 1')".to_string(),
        };
        conn.apply(insert_one_row).unwrap();

        assert_eq!(count_rows(), 1);
    }
}