mithril_common/
api_version.rs

1//! API Version provider service
2include!(concat!(env!("OUT_DIR"), "/open_api.rs"));
3use anyhow::anyhow;
4use semver::{Version, VersionReq};
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::StdResult;
9
10/// API Version provider
11#[derive(Clone)]
12pub struct APIVersionProvider {
13    alternate_file_discriminator: Arc<dyn ApiVersionDiscriminantSource>,
14    open_api_versions: HashMap<OpenAPIFileName, Version>,
15}
16
/// Trait to get the discriminant that identifies the alternate `openapi` file to use first if found,
/// in place of the default `openapi.yaml` file.
///
/// The alternate file looked up by [`APIVersionProvider`] is named `openapi-{discriminant}.yaml`.
#[cfg_attr(test, mockall::automock)]
pub trait ApiVersionDiscriminantSource: Send + Sync {
    /// Get the discriminant that identifies the alternate `openapi` file
    fn get_discriminant(&self) -> String;
}
24
25impl APIVersionProvider {
26    /// Version provider factory
27    pub fn new(era_checker: Arc<dyn ApiVersionDiscriminantSource>) -> Self {
28        Self {
29            alternate_file_discriminator: era_checker,
30            open_api_versions: get_open_api_versions_mapping(),
31        }
32    }
33
34    /// Compute the current api version
35    pub fn compute_current_version(&self) -> StdResult<Version> {
36        let discriminant = self.alternate_file_discriminator.get_discriminant();
37        let open_api_spec_file_name_default = "openapi.yaml";
38        let open_api_spec_file_name_era = &format!("openapi-{discriminant}.yaml");
39        let open_api_version = self.open_api_versions.get(open_api_spec_file_name_era).unwrap_or(
40            self.open_api_versions
41                .get(open_api_spec_file_name_default)
42                .ok_or_else(|| anyhow!("Missing default API version"))?,
43        );
44
45        Ok(open_api_version.clone())
46    }
47
48    /// Compute the current api version requirement
49    pub fn compute_current_version_requirement(&self) -> StdResult<VersionReq> {
50        let version = &self.compute_current_version()?;
51        let version_req = if version.major > 0 {
52            format!("={}", version.major)
53        } else {
54            format!("={}.{}", version.major, version.minor)
55        };
56
57        Ok(VersionReq::parse(&version_req)?)
58    }
59
60    /// Compute all the sorted list of all versions
61    pub fn compute_all_versions_sorted() -> Vec<Version> {
62        let mut versions: Vec<Version> = get_open_api_versions_mapping().into_values().collect();
63        versions.sort();
64        versions
65    }
66}
67
68impl Default for APIVersionProvider {
69    fn default() -> Self {
70        struct DiscriminantSourceDefault;
71        impl ApiVersionDiscriminantSource for DiscriminantSourceDefault {
72            fn get_discriminant(&self) -> String {
73                // Return nonexistent discriminant to ensure the default 'openapi.yml' file is used
74                "nonexistent-discriminant".to_string()
75            }
76        }
77
78        Self::new(Arc::new(DiscriminantSourceDefault))
79    }
80}
81
82impl crate::test::api_version_extensions::ApiVersionProviderTestExtension for APIVersionProvider {
83    fn update_open_api_versions(&mut self, open_api_versions: HashMap<OpenAPIFileName, Version>) {
84        self.open_api_versions = open_api_versions;
85    }
86
87    fn new_with_default_version(version: Version) -> APIVersionProvider {
88        Self {
89            open_api_versions: HashMap::from([("openapi.yaml".to_string(), version)]),
90            ..Self::default()
91        }
92    }
93
94    fn new_failing() -> APIVersionProvider {
95        Self {
96            // Leverage the error raised if the default api version is missing
97            open_api_versions: HashMap::new(),
98            ..Self::default()
99        }
100    }
101}
102
#[cfg(test)]
mod test {
    use crate::test::api_version_extensions::ApiVersionProviderTestExtension;
    use crate::test::double::DummyApiVersionDiscriminantSource;

    use super::*;

    /// Builds a shared provider whose known OpenAPI versions are exactly `entries`.
    fn provider_with_versions(
        discriminant_source: DummyApiVersionDiscriminantSource,
        entries: &[(&str, Version)],
    ) -> Arc<APIVersionProvider> {
        let mut provider = APIVersionProvider::new(Arc::new(discriminant_source));
        let versions = entries
            .iter()
            .map(|(file_name, version)| (file_name.to_string(), version.clone()))
            .collect();
        provider.update_open_api_versions(versions);

        Arc::new(provider)
    }

    #[test]
    fn test_compute_current_version_default() {
        let provider = provider_with_versions(
            DummyApiVersionDiscriminantSource::default(),
            &[("openapi.yaml", Version::new(1, 2, 3))],
        );

        let version = provider.compute_current_version().unwrap();

        assert_eq!("1.2.3", version.to_string());
    }

    #[test]
    fn test_compute_current_version_era_specific() {
        let provider = provider_with_versions(
            DummyApiVersionDiscriminantSource::new("dummy"),
            &[
                ("openapi.yaml", Version::new(1, 2, 3)),
                ("openapi-dummy.yaml", Version::new(2, 1, 0)),
            ],
        );

        let version = provider.compute_current_version().unwrap();

        assert_eq!("2.1.0", version.to_string());
    }

    #[test]
    fn test_compute_current_version_requirement_beta() {
        let provider = provider_with_versions(
            DummyApiVersionDiscriminantSource::default(),
            &[("openapi.yaml", Version::new(0, 2, 3))],
        );

        let requirement = provider.compute_current_version_requirement().unwrap();

        assert_eq!("=0.2", requirement.to_string());
    }

    #[test]
    fn test_compute_current_version_requirement_stable() {
        let provider = provider_with_versions(
            DummyApiVersionDiscriminantSource::default(),
            &[("openapi.yaml", Version::new(3, 2, 1))],
        );

        let requirement = provider.compute_current_version_requirement().unwrap();

        assert_eq!("=3", requirement.to_string());
    }

    #[test]
    fn test_compute_all_versions_sorted() {
        assert!(!APIVersionProvider::compute_all_versions_sorted().is_empty());
    }

    #[test]
    fn default_provider_returns_default_version() {
        let expected = get_open_api_versions_mapping()
            .get("openapi.yaml")
            .unwrap()
            .clone();

        let version = APIVersionProvider::default()
            .compute_current_version()
            .unwrap();

        assert_eq!(expected, version);
    }

    #[test]
    fn building_provider_with_canned_default_openapi_version() {
        let version = APIVersionProvider::new_with_default_version(Version::new(1, 2, 3))
            .compute_current_version()
            .unwrap();

        assert_eq!(Version::new(1, 2, 3), version);
    }

    #[test]
    fn building_provider_that_fails_compute_current_version() {
        APIVersionProvider::new_failing()
            .compute_current_version()
            .expect_err("Should fail");
    }
}