mithril_persistence/sqlite/
transaction.rs1use crate::sqlite::SqliteConnection;
2
3pub struct Transaction<'a> {
8 connection: &'a SqliteConnection,
9 is_active: bool,
11}
12
13impl<'a> Transaction<'a> {
14 const BEGIN_TRANSACTION: &'static str = "BEGIN TRANSACTION";
15 const COMMIT_TRANSACTION: &'static str = "COMMIT TRANSACTION";
16 const ROLLBACK_TRANSACTION: &'static str = "ROLLBACK TRANSACTION";
17
18 pub fn begin(connection: &'a SqliteConnection) -> Result<Self, sqlite::Error> {
20 connection.execute(Self::BEGIN_TRANSACTION)?;
21 Ok(Self {
22 connection,
23 is_active: true,
24 })
25 }
26
27 pub fn commit(mut self) -> Result<(), sqlite::Error> {
29 self.execute(Self::COMMIT_TRANSACTION)
30 }
31
32 pub fn rollback(mut self) -> Result<(), sqlite::Error> {
34 self.execute(Self::ROLLBACK_TRANSACTION)
35 }
36
37 fn execute(&mut self, command: &str) -> Result<(), sqlite::Error> {
38 if self.is_active {
39 self.is_active = false;
40 self.connection.execute(command)?;
41 }
42 Ok(())
43 }
44}
45
46impl Drop for Transaction<'_> {
47 fn drop(&mut self) {
48 self.execute(Self::ROLLBACK_TRANSACTION).unwrap();
51 }
52}
53
#[cfg(test)]
mod tests {
    use anyhow::anyhow;
    use sqlite::Connection;

    use mithril_common::StdResult;

    use crate::sqlite::ConnectionExtensions;

    use super::*;

    /// Open an in-memory database holding a single, empty `query_test` table.
    fn init_database() -> SqliteConnection {
        let connection = Connection::open_thread_safe(":memory:").unwrap();
        connection
            .execute("create table query_test(text_data text not null primary key);")
            .unwrap();

        connection
    }

    /// Insert the row `'row 1'` into `query_test` directly through the connection.
    fn insert_row(connection: &SqliteConnection) {
        connection
            .execute("insert into query_test(text_data) values ('row 1')")
            .unwrap();
    }

    fn get_number_of_rows(connection: &SqliteConnection) -> i64 {
        connection
            .query_single_cell("select count(*) from query_test", &[])
            .unwrap()
    }

    #[test]
    fn test_commit() {
        let connection = init_database();

        assert_eq!(0, get_number_of_rows(&connection));
        {
            let transaction = Transaction::begin(&connection).unwrap();
            insert_row(&connection);
            transaction.commit().unwrap();
        }
        assert_eq!(1, get_number_of_rows(&connection));
    }

    #[test]
    fn test_rollback() {
        let connection = init_database();

        assert_eq!(0, get_number_of_rows(&connection));
        {
            let transaction = Transaction::begin(&connection).unwrap();
            insert_row(&connection);
            transaction.rollback().unwrap();
        }
        assert_eq!(0, get_number_of_rows(&connection));
    }

    #[test]
    fn test_auto_rollback_when_dropping() {
        let connection = init_database();

        assert_eq!(0, get_number_of_rows(&connection));
        {
            let _transaction = Transaction::begin(&connection).unwrap();
            insert_row(&connection);
        }
        assert_eq!(0, get_number_of_rows(&connection));
    }

    #[test]
    fn test_auto_rollback_when_dropping_because_of_an_error() {
        fn failing_function() -> StdResult<()> {
            Err(anyhow!("This is an error"))
        }
        fn failing_function_that_insert_data(connection: &SqliteConnection) -> StdResult<()> {
            let transaction = Transaction::begin(connection).unwrap();
            insert_row(connection);
            // The early return skips the commit, so the guard is dropped while still active.
            failing_function()?;
            transaction.commit().unwrap();
            Ok(())
        }

        let connection = init_database();

        assert_eq!(0, get_number_of_rows(&connection));
        let _err = failing_function_that_insert_data(&connection).unwrap_err();
        assert_eq!(0, get_number_of_rows(&connection));
    }

    #[test]
    fn test_drop_dont_panic_if_previous_commit_failed() {
        let connection = init_database();

        {
            let transaction = Transaction::begin(&connection).unwrap();
            insert_row(&connection);

            // End the transaction behind the guard's back so that its own commit fails.
            connection.execute(Transaction::COMMIT_TRANSACTION).unwrap();
            // `commit` consumes the guard: its drop runs here and must not panic.
            transaction.commit().expect_err("Commit should have failed");
        }
        // The row committed directly on the connection is kept.
        assert_eq!(1, get_number_of_rows(&connection));
    }

    #[test]
    fn test_drop_dont_panic_if_previous_rollback_failed() {
        let connection = init_database();

        {
            let transaction = Transaction::begin(&connection).unwrap();
            insert_row(&connection);

            // End the transaction behind the guard's back so that its own rollback fails.
            connection.execute(Transaction::COMMIT_TRANSACTION).unwrap();
            // `rollback` consumes the guard: its drop runs here and must not panic.
            transaction
                .rollback()
                .expect_err("Rollback should have failed");
        }
        // The row committed directly on the connection is kept.
        assert_eq!(1, get_number_of_rows(&connection));
    }
}