mithril_persistence/sqlite/
transaction.rs

1use crate::sqlite::SqliteConnection;
2
3/// Sqlite transaction wrapper.
4///
5/// Transactions are automatically rolled back if this struct object is dropped and
6/// the transaction was not committed.
7pub struct Transaction<'a> {
8    connection: &'a SqliteConnection,
9    // An active transaction is one that has yet to be committed or rolled back.
10    is_active: bool,
11}
12
13impl<'a> Transaction<'a> {
14    const BEGIN_TRANSACTION: &'static str = "BEGIN TRANSACTION";
15    const COMMIT_TRANSACTION: &'static str = "COMMIT TRANSACTION";
16    const ROLLBACK_TRANSACTION: &'static str = "ROLLBACK TRANSACTION";
17
18    /// Begin a new transaction.
19    pub fn begin(connection: &'a SqliteConnection) -> Result<Self, sqlite::Error> {
20        connection.execute(Self::BEGIN_TRANSACTION)?;
21        Ok(Self {
22            connection,
23            is_active: true,
24        })
25    }
26
27    /// Commit the transaction.
28    pub fn commit(mut self) -> Result<(), sqlite::Error> {
29        self.execute(Self::COMMIT_TRANSACTION)
30    }
31
32    /// Rollback the transaction.
33    pub fn rollback(mut self) -> Result<(), sqlite::Error> {
34        self.execute(Self::ROLLBACK_TRANSACTION)
35    }
36
37    fn execute(&mut self, command: &str) -> Result<(), sqlite::Error> {
38        if self.is_active {
39            self.is_active = false;
40            self.connection.execute(command)?;
41        }
42        Ok(())
43    }
44}
45
46impl Drop for Transaction<'_> {
47    fn drop(&mut self) {
48        // Unwrap should not happen here, otherwise it would mean that we have not handled
49        // correctly the transaction "active" state or that the connection was closed.
50        self.execute(Self::ROLLBACK_TRANSACTION).unwrap();
51    }
52}
53
#[cfg(test)]
mod tests {
    use anyhow::anyhow;
    use sqlite::Connection;

    use mithril_common::StdResult;

    use crate::sqlite::ConnectionExtensions;

    use super::*;

    /// Statement adding a single row to the `query_test` table.
    const INSERT_ROW: &str = "insert into query_test(text_data) values ('row 1')";

    /// Open an in-memory database containing an empty `query_test` table.
    fn init_database() -> SqliteConnection {
        let database = Connection::open_thread_safe(":memory:").unwrap();
        database
            .execute("create table query_test(text_data text not null primary key);")
            .unwrap();

        database
    }

    /// Count the rows of the `query_test` table.
    fn get_number_of_rows(connection: &SqliteConnection) -> i64 {
        connection
            .query_single_cell("select count(*) from query_test", &[])
            .unwrap()
    }

    /// Insert one row in the `query_test` table, directly through the connection.
    fn insert_row(connection: &SqliteConnection) {
        connection.execute(INSERT_ROW).unwrap();
    }

    #[test]
    fn test_commit() {
        let connection = init_database();
        assert_eq!(0, get_number_of_rows(&connection));

        let transaction = Transaction::begin(&connection).unwrap();
        insert_row(&connection);
        transaction.commit().unwrap();

        assert_eq!(1, get_number_of_rows(&connection));
    }

    #[test]
    fn test_rollback() {
        let connection = init_database();
        assert_eq!(0, get_number_of_rows(&connection));

        let transaction = Transaction::begin(&connection).unwrap();
        insert_row(&connection);
        transaction.rollback().unwrap();

        assert_eq!(0, get_number_of_rows(&connection));
    }

    #[test]
    fn test_auto_rollback_when_dropping() {
        let connection = init_database();
        assert_eq!(0, get_number_of_rows(&connection));

        let transaction = Transaction::begin(&connection).unwrap();
        insert_row(&connection);
        // Neither committed nor rolled back: dropping must discard the inserted row.
        drop(transaction);

        assert_eq!(0, get_number_of_rows(&connection));
    }

    #[test]
    fn test_auto_rollback_when_dropping_because_of_an_error() {
        fn failing_function() -> StdResult<()> {
            Err(anyhow!("This is an error"))
        }
        fn failing_function_that_insert_data(connection: &SqliteConnection) -> StdResult<()> {
            let transaction = Transaction::begin(connection).unwrap();
            insert_row(connection);
            // The early return below drops the transaction before the commit is reached.
            failing_function()?;
            transaction.commit().unwrap();
            Ok(())
        }

        let connection = init_database();
        assert_eq!(0, get_number_of_rows(&connection));

        assert!(failing_function_that_insert_data(&connection).is_err());

        assert_eq!(0, get_number_of_rows(&connection));
    }

    #[test]
    fn test_drop_dont_panic_if_previous_commit_failed() {
        let connection = init_database();
        let transaction = Transaction::begin(&connection).unwrap();
        insert_row(&connection);

        // Committing directly on the connection terminates the transaction behind the
        // wrapper's back, which makes the wrapper's own commit fail.
        connection.execute(Transaction::COMMIT_TRANSACTION).unwrap();
        transaction.commit().expect_err("Commit should have fail");

        // The wrapper was consumed and dropped by `commit`: reaching the end of the test
        // means that the drop did not panic.
    }

    #[test]
    fn test_drop_dont_panic_if_previous_rollback_failed() {
        let connection = init_database();
        let transaction = Transaction::begin(&connection).unwrap();
        insert_row(&connection);

        // Committing directly on the connection terminates the transaction behind the
        // wrapper's back, which makes the wrapper's rollback fail.
        connection.execute(Transaction::COMMIT_TRANSACTION).unwrap();
        transaction
            .rollback()
            .expect_err("Rollback should have fail");

        // The wrapper was consumed and dropped by `rollback`: reaching the end of the test
        // means that the drop did not panic.
    }
}