mithril_persistence/database/
db_version.rs1use anyhow::anyhow;
2use chrono::{DateTime, Utc};
3use mithril_common::StdResult;
4use sqlite::{Row, Value};
5use std::{
6 cmp::Ordering,
7 fmt::{Debug, Display},
8};
9
10use crate::sqlite::{HydrationError, Projection, Query, SourceAlias, SqLiteEntity, WhereCondition};
11
12use super::DbVersion;
13
/// Node type of the application a database version record belongs to.
///
/// The textual form used when parsing ([`ApplicationNodeType::new`]) and
/// when displaying is the lowercase variant name.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ApplicationNodeType {
    /// Aggregator node type, spelled `"aggregator"` in text form.
    Aggregator,

    /// Signer node type, spelled `"signer"` in text form.
    Signer,
}
23
24impl ApplicationNodeType {
25 pub fn new(node_type: &str) -> StdResult<Self> {
27 match node_type {
28 "aggregator" => Ok(Self::Aggregator),
29 "signer" => Ok(Self::Signer),
30 _ => Err(anyhow!("unknown node type '{node_type}'")),
31 }
32 }
33}
34
35impl Display for ApplicationNodeType {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 match self {
38 Self::Aggregator => write!(f, "aggregator"),
39 Self::Signer => write!(f, "signer"),
40 }
41 }
42}
43
44#[derive(Debug, PartialEq, Eq, Clone)]
46pub struct DatabaseVersion {
47 pub version: DbVersion,
49
50 pub application_type: ApplicationNodeType,
52
53 pub updated_at: DateTime<Utc>,
55}
56
57impl SqLiteEntity for DatabaseVersion {
58 fn hydrate(row: Row) -> Result<Self, HydrationError> {
59 let version = row.read::<i64, _>(0);
60 let application_type = row.read::<&str, _>(1);
61 let updated_at = row.read::<&str, _>(2);
62
63 Ok(Self {
64 version,
65 application_type: ApplicationNodeType::new(application_type)
66 .map_err(|e| HydrationError::InvalidData(format!("{e}")))?,
67 updated_at: DateTime::parse_from_rfc3339(updated_at)
68 .map_err(|e| {
69 HydrationError::InvalidData(format!(
70 "Could not turn string '{updated_at}' to rfc3339 Datetime. Error: {e}"
71 ))
72 })?
73 .with_timezone(&Utc),
74 })
75 }
76
77 fn get_projection() -> Projection {
78 let mut projection = Projection::default();
79 projection.add_field("version", "{:db_version:}.version", "text");
80 projection.add_field(
81 "application_type",
82 "{:db_version:}.application_type",
83 "text",
84 );
85 projection.add_field("updated_at", "{:db_version:}.updated_at", "timestamp");
86
87 projection
88 }
89}
90
91impl PartialOrd for DatabaseVersion {
92 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
93 if self.application_type != other.application_type {
94 None
95 } else {
96 self.version.partial_cmp(&other.version)
97 }
98 }
99}
100
101pub struct GetDatabaseVersionQuery {
103 condition: WhereCondition,
104}
105
106impl GetDatabaseVersionQuery {
107 pub fn get_application_version(application_type: &ApplicationNodeType) -> Self {
109 let filters = WhereCondition::new(
110 "application_type = ?*",
111 vec![Value::String(format!("{application_type}"))],
112 );
113 Self { condition: filters }
114 }
115}
116
117impl Query for GetDatabaseVersionQuery {
118 type Entity = DatabaseVersion;
119
120 fn filters(&self) -> WhereCondition {
121 self.condition.clone()
122 }
123
124 fn get_definition(&self, condition: &str) -> String {
125 let aliases = SourceAlias::new(&[("{:db_version:}", "db_version")]);
126 let projection = Self::Entity::get_projection().expand(aliases);
127
128 format!(
129 r#"
130select {projection}
131from db_version
132where {condition}
133"#
134 )
135 }
136}
137
138pub struct UpdateDatabaseVersionQuery {
140 condition: WhereCondition,
141}
142
143impl UpdateDatabaseVersionQuery {
144 pub fn one(version: DatabaseVersion) -> Self {
146 let filters = WhereCondition::new(
147 "",
148 vec![
149 Value::String(format!("{}", version.application_type)),
150 Value::Integer(version.version),
151 Value::String(version.updated_at.to_rfc3339()),
152 ],
153 );
154
155 Self { condition: filters }
156 }
157}
158
159impl Query for UpdateDatabaseVersionQuery {
160 type Entity = DatabaseVersion;
161
162 fn filters(&self) -> WhereCondition {
163 self.condition.clone()
164 }
165
166 fn get_definition(&self, _condition: &str) -> String {
167 let aliases = SourceAlias::new(&[("{:db_version:}", "db_version")]);
168 let projection = Self::Entity::get_projection().expand(aliases);
169
170 format!(
171 r#"
172insert into db_version (application_type, version, updated_at) values (?, ?, ?)
173 on conflict (application_type) do update set version = excluded.version, updated_at = excluded.updated_at
174returning {projection}
175"#
176 )
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// The projection expands every field against the provided alias.
    #[test]
    fn test_projection() {
        let aliases = SourceAlias::new(&[("{:db_version:}", "whatever")]);
        let expanded = DatabaseVersion::get_projection().expand(aliases);

        assert_eq!(
            "whatever.version as version, whatever.application_type as application_type, whatever.updated_at as updated_at"
                .to_string(),
            expanded
        );
    }

    /// The read query renders a `select` using the given condition text.
    #[test]
    fn test_definition() {
        let expected = r#"
select db_version.version as version, db_version.application_type as application_type, db_version.updated_at as updated_at
from db_version
where true
"#;
        let query =
            GetDatabaseVersionQuery::get_application_version(&ApplicationNodeType::Aggregator);

        assert_eq!(expected, query.get_definition("true"));
    }

    /// The save query renders an upsert and ignores the condition text.
    #[test]
    fn test_updated_entity() {
        let expected = r#"
insert into db_version (application_type, version, updated_at) values (?, ?, ?)
 on conflict (application_type) do update set version = excluded.version, updated_at = excluded.updated_at
returning db_version.version as version, db_version.application_type as application_type, db_version.updated_at as updated_at
"#;
        let record = DatabaseVersion {
            version: 0,
            application_type: ApplicationNodeType::Aggregator,
            updated_at: Default::default(),
        };
        let query = UpdateDatabaseVersionQuery::one(record);

        assert_eq!(expected, query.get_definition("true"));
    }
}