mithril_persistence/database/
db_version.rs

1use anyhow::anyhow;
2use chrono::{DateTime, Utc};
3use mithril_common::StdResult;
4use sqlite::{Row, Value};
5use std::{
6    cmp::Ordering,
7    fmt::{Debug, Display},
8};
9
10use crate::sqlite::{HydrationError, Projection, Query, SourceAlias, SqLiteEntity, WhereCondition};
11
12use super::DbVersion;
13
/// Kind of application node that owns a database.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ApplicationNodeType {
    /// The aggregator node.
    Aggregator,

    /// The signer node.
    Signer,
}
23
24impl ApplicationNodeType {
25    /// [ApplicationNodeType] constructor.
26    pub fn new(node_type: &str) -> StdResult<Self> {
27        match node_type {
28            "aggregator" => Ok(Self::Aggregator),
29            "signer" => Ok(Self::Signer),
30            _ => Err(anyhow!("unknown node type '{node_type}'")),
31        }
32    }
33}
34
35impl Display for ApplicationNodeType {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        match self {
38            Self::Aggregator => write!(f, "aggregator"),
39            Self::Signer => write!(f, "signer"),
40        }
41    }
42}
43
44/// Entity related to the `db_version` database table.
45#[derive(Debug, PartialEq, Eq, Clone)]
46pub struct DatabaseVersion {
47    /// Version of the database structure.
48    pub version: DbVersion,
49
50    /// Name of the application.
51    pub application_type: ApplicationNodeType,
52
53    /// Date of the last version upgrade
54    pub updated_at: DateTime<Utc>,
55}
56
57impl SqLiteEntity for DatabaseVersion {
58    fn hydrate(row: Row) -> Result<Self, HydrationError> {
59        let version = row.read::<i64, _>(0);
60        let application_type = row.read::<&str, _>(1);
61        let updated_at = row.read::<&str, _>(2);
62
63        Ok(Self {
64            version,
65            application_type: ApplicationNodeType::new(application_type)
66                .map_err(|e| HydrationError::InvalidData(format!("{e}")))?,
67            updated_at: DateTime::parse_from_rfc3339(updated_at)
68                .map_err(|e| {
69                    HydrationError::InvalidData(format!(
70                        "Could not turn string '{updated_at}' to rfc3339 Datetime. Error: {e}"
71                    ))
72                })?
73                .with_timezone(&Utc),
74        })
75    }
76
77    fn get_projection() -> Projection {
78        let mut projection = Projection::default();
79        projection.add_field("version", "{:db_version:}.version", "text");
80        projection.add_field(
81            "application_type",
82            "{:db_version:}.application_type",
83            "text",
84        );
85        projection.add_field("updated_at", "{:db_version:}.updated_at", "timestamp");
86
87        projection
88    }
89}
90
91impl PartialOrd for DatabaseVersion {
92    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
93        if self.application_type != other.application_type {
94            None
95        } else {
96            self.version.partial_cmp(&other.version)
97        }
98    }
99}
100
101/// Query to get [DatabaseVersion] entities.
102pub struct GetDatabaseVersionQuery {
103    condition: WhereCondition,
104}
105
106impl GetDatabaseVersionQuery {
107    /// Query to read the application version from the database.
108    pub fn get_application_version(application_type: &ApplicationNodeType) -> Self {
109        let filters = WhereCondition::new(
110            "application_type = ?*",
111            vec![Value::String(format!("{application_type}"))],
112        );
113        Self { condition: filters }
114    }
115}
116
117impl Query for GetDatabaseVersionQuery {
118    type Entity = DatabaseVersion;
119
120    fn filters(&self) -> WhereCondition {
121        self.condition.clone()
122    }
123
124    fn get_definition(&self, condition: &str) -> String {
125        let aliases = SourceAlias::new(&[("{:db_version:}", "db_version")]);
126        let projection = Self::Entity::get_projection().expand(aliases);
127
128        format!(
129            r#"
130select {projection}
131from db_version
132where {condition}
133"#
134        )
135    }
136}
137
138/// Query to UPSERT [DatabaseVersion] entities.
139pub struct UpdateDatabaseVersionQuery {
140    condition: WhereCondition,
141}
142
143impl UpdateDatabaseVersionQuery {
144    /// Define a query that will UPSERT the given version.
145    pub fn one(version: DatabaseVersion) -> Self {
146        let filters = WhereCondition::new(
147            "",
148            vec![
149                Value::String(format!("{}", version.application_type)),
150                Value::Integer(version.version),
151                Value::String(version.updated_at.to_rfc3339()),
152            ],
153        );
154
155        Self { condition: filters }
156    }
157}
158
159impl Query for UpdateDatabaseVersionQuery {
160    type Entity = DatabaseVersion;
161
162    fn filters(&self) -> WhereCondition {
163        self.condition.clone()
164    }
165
166    fn get_definition(&self, _condition: &str) -> String {
167        let aliases = SourceAlias::new(&[("{:db_version:}", "db_version")]);
168        let projection = Self::Entity::get_projection().expand(aliases);
169
170        format!(
171            r#"
172insert into db_version (application_type, version, updated_at) values (?, ?, ?)
173  on conflict (application_type) do update set version = excluded.version, updated_at = excluded.updated_at
174returning {projection}
175"#
176        )
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_projection() {
        let expanded = DatabaseVersion::get_projection()
            .expand(SourceAlias::new(&[("{:db_version:}", "whatever")]));
        let expected = "whatever.version as version, whatever.application_type as application_type, whatever.updated_at as updated_at";

        assert_eq!(expected, expanded);
    }

    #[test]
    fn test_definition() {
        let query =
            GetDatabaseVersionQuery::get_application_version(&ApplicationNodeType::Aggregator);
        let expected = r#"
select db_version.version as version, db_version.application_type as application_type, db_version.updated_at as updated_at
from db_version
where true
"#;

        assert_eq!(expected, query.get_definition("true"));
    }

    #[test]
    fn test_updated_entity() {
        let entity = DatabaseVersion {
            version: 0,
            application_type: ApplicationNodeType::Aggregator,
            updated_at: Default::default(),
        };
        let query = UpdateDatabaseVersionQuery::one(entity);
        let expected = r#"
insert into db_version (application_type, version, updated_at) values (?, ?, ?)
  on conflict (application_type) do update set version = excluded.version, updated_at = excluded.updated_at
returning db_version.version as version, db_version.application_type as application_type, db_version.updated_at as updated_at
"#;

        assert_eq!(expected, query.get_definition("true"));
    }
}