mithril_persistence/sqlite/
connection_pool.rs1use std::ops::Deref;
2use std::sync::{Arc, RwLock};
3use std::time::Duration;
4
5use anyhow::anyhow;
6
7use mithril_common::StdResult;
8use mithril_resource_pool::{Reset, ResourcePool, ResourcePoolItem};
9
10use crate::sqlite::{ConnectionBuilder, SqliteConnection};
11
12pub struct SqlitePooledConnection {
14 connection: SqliteConnection,
15 actual_version: u64,
16 pool_version: Arc<RwLock<u64>>,
17 builder: Arc<ConnectionBuilder>,
18}
19
20impl Reset for SqlitePooledConnection {
21 fn reset(&mut self) -> StdResult<()> {
22 let pool_version = self
23 .pool_version
24 .read()
25 .map_err(|e| anyhow!(e.to_string()).context("Failed to acquire pool version lock"))?;
26 if self.actual_version < *pool_version {
27 self.connection = self.builder.build_without_migrations()?;
28 self.actual_version = *pool_version;
29 }
30
31 Ok(())
32 }
33}
34
35impl SqlitePooledConnection {
36 fn new(
38 connection: SqliteConnection,
39 initial_version: u64,
40 pool_version: Arc<RwLock<u64>>,
41 builder: Arc<ConnectionBuilder>,
42 ) -> Self {
43 Self {
44 connection,
45 actual_version: initial_version,
46 pool_version,
47 builder,
48 }
49 }
50}
51
52impl Deref for SqlitePooledConnection {
53 type Target = SqliteConnection;
54
55 fn deref(&self) -> &Self::Target {
56 &self.connection
57 }
58}
59
60pub struct SqliteConnectionPool {
62 connections: ResourcePool<SqlitePooledConnection>,
63 pool_version: Arc<RwLock<u64>>,
64}
65
66impl SqliteConnectionPool {
67 pub fn build(size: usize, builder: ConnectionBuilder) -> StdResult<Self> {
69 let mut connections: Vec<SqlitePooledConnection> = Vec::with_capacity(size);
70 let initial_version = 0;
71 let pool_version = Arc::new(RwLock::new(initial_version));
72 let builder = Arc::new(builder);
73
74 for _count in 0..size {
75 connections.push(SqlitePooledConnection::new(
76 builder.build_without_migrations()?,
77 initial_version,
78 pool_version.clone(),
79 builder.clone(),
80 ));
81 }
82
83 Ok(Self {
84 connections: ResourcePool::new(connections.len(), connections),
85 pool_version,
86 })
87 }
88
89 pub fn connection(&self) -> StdResult<ResourcePoolItem<'_, SqlitePooledConnection>> {
91 let timeout = Duration::from_millis(1000);
92 let connection = self.connections.acquire_resource(timeout)?;
93
94 Ok(connection)
95 }
96
97 pub fn renew_connections(&self) -> StdResult<()> {
104 self.schedule_renew_for_all_connections()?;
105 self.connections.reset_available_resources()?;
106
107 Ok(())
108 }
109
110 fn schedule_renew_for_all_connections(&self) -> StdResult<()> {
111 let mut pool_version = self.pool_version.write().map_err(|e| {
112 anyhow!(e.to_string()).context("Failed to schedule connections renewal")
113 })?;
114 *pool_version += 1;
115
116 Ok(())
117 }
118
119 #[cfg(test)]
120 fn pool_version(&self) -> u64 {
121 *self.pool_version.read().unwrap()
122 }
123}
124
#[cfg(test)]
mod tests {
    use slog::{Drain, Logger};

    use mithril_common::temp_dir_create;
    use mithril_common::test::logging::MemoryDrainForTest;

    use crate::database::SqlMigration;

    use super::*;

    #[test]
    fn can_build_pool_of_given_size() {
        let pool = SqliteConnectionPool::build(10, ConnectionBuilder::open_memory()).unwrap();

        assert_eq!(pool.connections.size(), 10);
    }

    #[test]
    fn pooled_connection_release_connection_when_drop() {
        let connection_pool =
            SqliteConnectionPool::build(1, ConnectionBuilder::open_memory()).unwrap();

        let borrowed = connection_pool.connection().unwrap();
        assert_eq!(0, connection_pool.connections.count().unwrap());
        drop(borrowed);

        assert_eq!(1, connection_pool.connections.count().unwrap());
    }

    #[test]
    fn renew_connections_rebuilds_pooled_connections_when_they_returns_to_pool() {
        let db_path = temp_dir_create!().join("test.db");
        let pool = Arc::new(
            SqliteConnectionPool::build(2, ConnectionBuilder::open_file(&db_path)).unwrap(),
        );

        let held_connection = pool.connection().unwrap();
        let returned_connection = pool.connection().unwrap();
        assert_eq!(0, pool.pool_version());
        assert_eq!(0, held_connection.actual_version);
        assert_eq!(0, returned_connection.actual_version);

        pool.renew_connections().unwrap();
        assert_eq!(1, pool.pool_version());
        assert_eq!(0, held_connection.actual_version);
        assert_eq!(0, returned_connection.actual_version);

        // Giving the connection back to the pool is what brings it up to the new version.
        drop(returned_connection);
        let reacquired_connection = pool.connection().unwrap();
        assert_eq!(1, pool.pool_version());
        assert_eq!(0, held_connection.actual_version);
        assert_eq!(1, reacquired_connection.actual_version);
    }

    #[test]
    fn renew_connections_immediately_renew_idle_connections() {
        let db_path = temp_dir_create!().join("test.db");
        let pool = Arc::new(
            SqliteConnectionPool::build(2, ConnectionBuilder::open_file(&db_path)).unwrap(),
        );

        let busy_connection = pool.connection().unwrap();
        pool.renew_connections().unwrap();
        let renewed_idle_connection = pool.connection().unwrap();

        assert_eq!(1, pool.pool_version());
        assert_eq!(0, busy_connection.actual_version);
        assert_eq!(1, renewed_idle_connection.actual_version);
    }

    #[test]
    fn multiple_renew_connections_increments_version() {
        let pool = SqliteConnectionPool::build(1, ConnectionBuilder::open_memory()).unwrap();
        assert_eq!(0, pool.pool_version());

        for expected_version in 1..=3 {
            pool.renew_connections().unwrap();
            assert_eq!(expected_version, pool.pool_version());
        }
    }

    #[test]
    fn do_not_apply_migrations_when_pooling_a_connection() {
        let temp_dir = temp_dir_create!();
        let (memory_drain, inspector) = MemoryDrainForTest::new();
        let logger = Logger::root(memory_drain.fuse(), slog::o!());
        let builder = ConnectionBuilder::open_file(&temp_dir.join("db.sqlite"))
            .with_logger(logger)
            .with_migrations(vec![SqlMigration::new(1, "")]);
        let pool = SqliteConnectionPool::build(10, builder).unwrap();

        // Each acquired connection is dropped right away and goes back to the pool.
        for _ in 0..2 {
            pool.connection().unwrap();
        }

        let migration_log_count = inspector
            .search_logs(ConnectionBuilder::APPLY_MIGRATIONS_LOG)
            .len();
        assert_eq!(migration_log_count, 0);
    }
}