mithril_persistence/sqlite/
connection_extensions.rs1use anyhow::Context;
2use sqlite::{ReadableWithIndex, Value};
3
4use mithril_common::StdResult;
5
6use crate::sqlite::{EntityCursor, Query, SqliteConnection, Transaction};
7
8pub trait ConnectionExtensions {
10 fn begin_transaction(&self) -> StdResult<Transaction>;
12
13 fn query_single_cell<Q: AsRef<str>, T: ReadableWithIndex>(
15 &self,
16 sql: Q,
17 params: &[Value],
18 ) -> StdResult<T>;
19
20 fn fetch<Q: Query>(&self, query: Q) -> StdResult<EntityCursor<Q::Entity>>;
22
23 fn fetch_first<Q: Query>(&self, query: Q) -> StdResult<Option<Q::Entity>> {
25 let mut cursor = self.fetch(query)?;
26 Ok(cursor.next())
27 }
28
29 fn fetch_collect<Q: Query, B: FromIterator<Q::Entity>>(&self, query: Q) -> StdResult<B> {
31 Ok(self.fetch(query)?.collect::<B>())
32 }
33
34 fn apply<Q: Query>(&self, query: Q) -> StdResult<()> {
36 self.fetch(query)?.count();
37 Ok(())
38 }
39}
40
41impl ConnectionExtensions for SqliteConnection {
42 fn begin_transaction(&self) -> StdResult<Transaction> {
43 Ok(Transaction::begin(self)?)
44 }
45
46 fn query_single_cell<Q: AsRef<str>, T: ReadableWithIndex>(
47 &self,
48 sql: Q,
49 params: &[Value],
50 ) -> StdResult<T> {
51 let mut statement = prepare_statement(self, sql.as_ref())?;
52 statement.bind(params)?;
53 statement.next()?;
54 statement.read::<T, _>(0).with_context(|| "Read query error")
55 }
56
57 fn fetch<Q: Query>(&self, query: Q) -> StdResult<EntityCursor<Q::Entity>> {
58 let (condition, params) = query.filters().expand();
59 let sql = query.get_definition(&condition);
60 let cursor = prepare_statement(self, &sql)?.into_iter().bind(¶ms[..])?;
61
62 let iterator = EntityCursor::new(cursor);
63
64 Ok(iterator)
65 }
66}
67
68fn prepare_statement<'conn>(
69 sqlite_connection: &'conn SqliteConnection,
70 sql: &str,
71) -> StdResult<sqlite::Statement<'conn>> {
72 sqlite_connection.prepare(sql).with_context(|| {
73 format!(
74 "Prepare query error: SQL=`{}`",
75 &sql.replace('\n', " ").trim()
76 )
77 })
78}
79
#[cfg(test)]
mod tests {
    use sqlite::Connection;

    use crate::sqlite::{HydrationError, SqLiteEntity, WhereCondition};

    use super::*;

    // A text cell can be read back as a `String`.
    #[test]
    fn test_query_string() {
        let connection = Connection::open_thread_safe(":memory:").unwrap();

        let text = connection
            .query_single_cell::<_, String>("select 'test'", &[])
            .unwrap();

        assert_eq!(text, "test");
    }

    // An aggregate over an inline dataset can be read back as an `i64`.
    #[test]
    fn test_query_max_number() {
        let connection = Connection::open_thread_safe(":memory:").unwrap();
        let sql = "select max(a) from (select 10 a union select 90 a union select 45 a)";

        let maximum = connection.query_single_cell::<_, i64>(sql, &[]).unwrap();

        assert_eq!(maximum, 90);
    }

    // Positional parameters are bound before the cell is read.
    #[test]
    fn test_query_with_params() {
        let connection = Connection::open_thread_safe(":memory:").unwrap();
        let sql = "select max(a) from (select 10 a union select 45 a union select 90 a) \
                   where a > ? and a < ?";
        let bounds = [Value::Integer(10), Value::Integer(90)];

        let maximum = connection
            .query_single_cell::<_, i64>(sql, &bounds)
            .unwrap();

        assert_eq!(maximum, 45);
    }

    // `apply` must actually run the statement, even though it returns no data.
    #[test]
    fn test_apply_execute_the_query() {
        // Entity that is never expected to be hydrated by this test.
        struct NeverHydratedEntity {}
        impl SqLiteEntity for NeverHydratedEntity {
            fn hydrate(_row: sqlite::Row) -> Result<Self, HydrationError>
            where
                Self: Sized,
            {
                unimplemented!()
            }

            fn get_projection() -> crate::sqlite::Projection {
                unimplemented!()
            }
        }

        // Query whose definition is a fixed SQL string, with no filters.
        struct RawSqlQuery {
            sql: String,
        }
        impl Query for RawSqlQuery {
            type Entity = NeverHydratedEntity;

            fn filters(&self) -> WhereCondition {
                WhereCondition::default()
            }

            fn get_definition(&self, _condition: &str) -> String {
                self.sql.clone()
            }
        }

        let connection = Connection::open_thread_safe(":memory:").unwrap();
        connection
            .execute("create table query_test(text_data);")
            .unwrap();

        let count_rows = || -> i64 {
            connection
                .query_single_cell("select count(*) from query_test", &[])
                .unwrap()
        };

        assert_eq!(count_rows(), 0);

        let insert_one_row = RawSqlQuery {
            sql: "insert into query_test(text_data) values ('row 1')".to_string(),
        };
        connection.apply(insert_one_row).unwrap();

        assert_eq!(count_rows(), 1);
    }
}