mithril_persistence/sqlite/connection_pool.rs

1use std::ops::Deref;
2use std::sync::{Arc, RwLock};
3use std::time::Duration;
4
5use anyhow::anyhow;
6
7use mithril_common::StdResult;
8use mithril_resource_pool::{Reset, ResourcePool, ResourcePoolItem};
9
10use crate::sqlite::{ConnectionBuilder, SqliteConnection};
11
12/// SqliteConnection wrapper for a pooled connection
13pub struct SqlitePooledConnection {
14    connection: SqliteConnection,
15    actual_version: u64,
16    pool_version: Arc<RwLock<u64>>,
17    builder: Arc<ConnectionBuilder>,
18}
19
20impl Reset for SqlitePooledConnection {
21    fn reset(&mut self) -> StdResult<()> {
22        let pool_version = self
23            .pool_version
24            .read()
25            .map_err(|e| anyhow!(e.to_string()).context("Failed to acquire pool version lock"))?;
26        if self.actual_version < *pool_version {
27            self.connection = self.builder.build_without_migrations()?;
28            self.actual_version = *pool_version;
29        }
30
31        Ok(())
32    }
33}
34
35impl SqlitePooledConnection {
36    /// Create a new SqlitePooledConnection
37    fn new(
38        connection: SqliteConnection,
39        initial_version: u64,
40        pool_version: Arc<RwLock<u64>>,
41        builder: Arc<ConnectionBuilder>,
42    ) -> Self {
43        Self {
44            connection,
45            actual_version: initial_version,
46            pool_version,
47            builder,
48        }
49    }
50}
51
52impl Deref for SqlitePooledConnection {
53    type Target = SqliteConnection;
54
55    fn deref(&self) -> &Self::Target {
56        &self.connection
57    }
58}
59
60/// Pool of Sqlite connections
61pub struct SqliteConnectionPool {
62    connections: ResourcePool<SqlitePooledConnection>,
63    pool_version: Arc<RwLock<u64>>,
64}
65
66impl SqliteConnectionPool {
67    /// Create a new pool with the given size by calling the given builder function
68    pub fn build(size: usize, builder: ConnectionBuilder) -> StdResult<Self> {
69        let mut connections: Vec<SqlitePooledConnection> = Vec::with_capacity(size);
70        let initial_version = 0;
71        let pool_version = Arc::new(RwLock::new(initial_version));
72        let builder = Arc::new(builder);
73
74        for _count in 0..size {
75            connections.push(SqlitePooledConnection::new(
76                builder.build_without_migrations()?,
77                initial_version,
78                pool_version.clone(),
79                builder.clone(),
80            ));
81        }
82
83        Ok(Self {
84            connections: ResourcePool::new(connections.len(), connections),
85            pool_version,
86        })
87    }
88
89    /// Get a connection from the pool
90    pub fn connection(&self) -> StdResult<ResourcePoolItem<'_, SqlitePooledConnection>> {
91        let timeout = Duration::from_millis(1000);
92        let connection = self.connections.acquire_resource(timeout)?;
93
94        Ok(connection)
95    }
96
97    /// Renew all connections in this pool
98    ///
99    /// Renewing a connection means that it will be closed and replaced with a new connection.
100    ///
101    /// - Idle connections are renewed immediately.
102    /// - Pooled connections will be renewed when they are returned to the pool.
103    pub fn renew_connections(&self) -> StdResult<()> {
104        self.schedule_renew_for_all_connections()?;
105        self.connections.reset_available_resources()?;
106
107        Ok(())
108    }
109
110    fn schedule_renew_for_all_connections(&self) -> StdResult<()> {
111        let mut pool_version = self.pool_version.write().map_err(|e| {
112            anyhow!(e.to_string()).context("Failed to schedule connections renewal")
113        })?;
114        *pool_version += 1;
115
116        Ok(())
117    }
118
119    #[cfg(test)]
120    fn pool_version(&self) -> u64 {
121        *self.pool_version.read().unwrap()
122    }
123}
124
#[cfg(test)]
mod tests {
    use slog::{Drain, Logger};

    use mithril_common::temp_dir_create;
    use mithril_common::test::logging::MemoryDrainForTest;

    use crate::database::SqlMigration;

    use super::*;

    #[test]
    fn can_build_pool_of_given_size() {
        let pool = SqliteConnectionPool::build(10, ConnectionBuilder::open_memory()).unwrap();

        assert_eq!(pool.connections.size(), 10);
    }

    #[test]
    fn pooled_connection_is_released_when_dropped() {
        let connection_pool =
            SqliteConnectionPool::build(1, ConnectionBuilder::open_memory()).unwrap();

        {
            let _connection = connection_pool.connection().unwrap();
            assert_eq!(0, connection_pool.connections.count().unwrap());
        }

        assert_eq!(1, connection_pool.connections.count().unwrap());
    }

    #[test]
    fn renew_connections_rebuilds_pooled_connections_when_they_return_to_pool() {
        let db_path = temp_dir_create!().join("test.db");
        let pool = SqliteConnectionPool::build(2, ConnectionBuilder::open_file(&db_path)).unwrap();

        // This connection is held for the whole test, so it is never returned nor renewed
        let first_connection = pool.connection().unwrap();
        {
            // This connection will be renewed when it goes out of scope
            let second_connection = pool.connection().unwrap();
            assert_eq!(0, pool.pool_version());
            assert_eq!(0, first_connection.actual_version);
            assert_eq!(0, second_connection.actual_version);

            pool.renew_connections().unwrap();

            assert_eq!(1, pool.pool_version());
            assert_eq!(0, first_connection.actual_version);
            assert_eq!(0, second_connection.actual_version);
        }

        // The pool has two connections: this re-acquires the second one, which should have been
        // renewed
        let second_connection = pool.connection().unwrap();
        assert_eq!(1, pool.pool_version());
        assert_eq!(0, first_connection.actual_version);
        assert_eq!(1, second_connection.actual_version);
    }

    #[test]
    fn renew_connections_immediately_renews_idle_connections() {
        let db_path = temp_dir_create!().join("test.db");
        let pool = SqliteConnectionPool::build(2, ConnectionBuilder::open_file(&db_path)).unwrap();

        let connection_that_should_not_be_renewed = pool.connection().unwrap();
        pool.renew_connections().unwrap();
        let connection_that_should_have_been_renewed = pool.connection().unwrap();

        assert_eq!(1, pool.pool_version());
        assert_eq!(0, connection_that_should_not_be_renewed.actual_version);
        assert_eq!(1, connection_that_should_have_been_renewed.actual_version);
    }

    #[test]
    fn multiple_renew_connections_increment_version() {
        let pool = SqliteConnectionPool::build(1, ConnectionBuilder::open_memory()).unwrap();
        assert_eq!(0, pool.pool_version());

        pool.renew_connections().unwrap();
        assert_eq!(1, pool.pool_version());

        pool.renew_connections().unwrap();
        assert_eq!(2, pool.pool_version());

        pool.renew_connections().unwrap();
        assert_eq!(3, pool.pool_version());
    }

    #[test]
    fn do_not_apply_migrations_when_pooling_a_connection() {
        let temp_dir = temp_dir_create!();
        let (memory_drain, inspector) = MemoryDrainForTest::new();
        let logger = Logger::root(memory_drain.fuse(), slog::o!());
        let builder = ConnectionBuilder::open_file(&temp_dir.join("db.sqlite"))
            .with_logger(logger)
            .with_migrations(vec![SqlMigration::new(1, "")]);
        let pool = SqliteConnectionPool::build(10, builder).unwrap();

        pool.connection().unwrap();
        pool.connection().unwrap();

        let number_of_times_migrations_run =
            inspector.search_logs(ConnectionBuilder::APPLY_MIGRATIONS_LOG).len();
        assert_eq!(number_of_times_migrations_run, 0);
    }
}