mithril_common/
api_version.rs

1//! API Version provider service
2include!(concat!(env!("OUT_DIR"), "/open_api.rs"));
3use anyhow::anyhow;
4use semver::{Version, VersionReq};
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::StdResult;
9
10/// API Version provider
11#[derive(Clone)]
12pub struct APIVersionProvider {
13    alternate_file_discriminator: Arc<dyn ApiVersionDiscriminantSource>,
14    open_api_versions: HashMap<OpenAPIFileName, Version>,
15}
16
/// Source of the discriminant selecting an alternate `openapi` file.
///
/// When a file named `openapi-{discriminant}.yaml` is known, it is looked up before the
/// default `openapi.yaml` file.
#[cfg_attr(test, mockall::automock)]
pub trait ApiVersionDiscriminantSource: Send + Sync {
    /// Return the discriminant naming the alternate `openapi` file to look up first.
    fn get_discriminant(&self) -> String;
}
24
25impl APIVersionProvider {
26    /// Version provider factory
27    pub fn new(era_checker: Arc<dyn ApiVersionDiscriminantSource>) -> Self {
28        Self {
29            alternate_file_discriminator: era_checker,
30            open_api_versions: get_open_api_versions_mapping(),
31        }
32    }
33
34    /// Compute the current api version
35    pub fn compute_current_version(&self) -> StdResult<Version> {
36        let discriminant = self.alternate_file_discriminator.get_discriminant();
37        let open_api_spec_file_name_default = "openapi.yaml";
38        let open_api_spec_file_name_era = &format!("openapi-{discriminant}.yaml");
39        let open_api_version = self
40            .open_api_versions
41            .get(open_api_spec_file_name_era)
42            .unwrap_or(
43                self.open_api_versions
44                    .get(open_api_spec_file_name_default)
45                    .ok_or_else(|| anyhow!("Missing default API version"))?,
46            );
47
48        Ok(open_api_version.clone())
49    }
50
51    /// Compute the current api version requirement
52    pub fn compute_current_version_requirement(&self) -> StdResult<VersionReq> {
53        let version = &self.compute_current_version()?;
54        let version_req = if version.major > 0 {
55            format!("={}", version.major)
56        } else {
57            format!("={}.{}", version.major, version.minor)
58        };
59
60        Ok(VersionReq::parse(&version_req)?)
61    }
62
63    /// Compute all the sorted list of all versions
64    pub fn compute_all_versions_sorted() -> Vec<Version> {
65        let mut versions: Vec<Version> = get_open_api_versions_mapping().into_values().collect();
66        versions.sort();
67        versions
68    }
69
70    /// Update open api versions. Test only
71    pub fn update_open_api_versions(
72        &mut self,
73        open_api_versions: HashMap<OpenAPIFileName, Version>,
74    ) {
75        self.open_api_versions = open_api_versions;
76    }
77}
78
79cfg_test_tools! {
80    /// A dummy implementation of the `ApiVersionDiscriminantSource` trait for testing purposes.
81    pub struct DummyApiVersionDiscriminantSource {
82        discriminant: String,
83    }
84
85    impl DummyApiVersionDiscriminantSource {
86        /// Create a new instance of `DummyApiVersionDiscriminantSource` with the given discriminant.
87        pub fn new<T: Into<String>>(discrimant: T) -> Self {
88            Self {
89                discriminant: discrimant.into(),
90            }
91        }
92    }
93
94    impl Default for DummyApiVersionDiscriminantSource {
95        fn default() -> Self {
96            Self {
97                discriminant: "dummy".to_string(),
98            }
99        }
100    }
101
102    impl ApiVersionDiscriminantSource for DummyApiVersionDiscriminantSource {
103        fn get_discriminant(&self) -> String {
104            self.discriminant.clone()
105        }
106    }
107}
108
#[cfg(test)]
mod test {
    use super::*;

    /// Build a shareable provider using `discriminant_source`, with its version mapping
    /// replaced by the given `(file name, version)` pairs.
    fn build_provider(
        discriminant_source: DummyApiVersionDiscriminantSource,
        versions: &[(&str, Version)],
    ) -> Arc<APIVersionProvider> {
        let mut provider = APIVersionProvider::new(Arc::new(discriminant_source));
        provider.update_open_api_versions(
            versions
                .iter()
                .map(|(file_name, version)| (file_name.to_string(), version.clone()))
                .collect(),
        );
        Arc::new(provider)
    }

    #[test]
    fn test_compute_current_version_default() {
        let provider = build_provider(
            DummyApiVersionDiscriminantSource::default(),
            &[("openapi.yaml", Version::new(1, 2, 3))],
        );

        let version = provider.compute_current_version().unwrap();

        assert_eq!("1.2.3".to_string(), version.to_string())
    }

    #[test]
    fn test_compute_current_version_era_specific() {
        let provider = build_provider(
            DummyApiVersionDiscriminantSource::new("dummy"),
            &[
                ("openapi.yaml", Version::new(1, 2, 3)),
                ("openapi-dummy.yaml", Version::new(2, 1, 0)),
            ],
        );

        let version = provider.compute_current_version().unwrap();

        assert_eq!("2.1.0".to_string(), version.to_string())
    }

    #[test]
    fn test_compute_current_version_requirement_beta() {
        let provider = build_provider(
            DummyApiVersionDiscriminantSource::default(),
            &[("openapi.yaml", Version::new(0, 2, 3))],
        );

        let requirement = provider.compute_current_version_requirement().unwrap();

        assert_eq!("=0.2".to_string(), requirement.to_string())
    }

    #[test]
    fn test_compute_current_version_requirement_stable() {
        let provider = build_provider(
            DummyApiVersionDiscriminantSource::default(),
            &[("openapi.yaml", Version::new(3, 2, 1))],
        );

        let requirement = provider.compute_current_version_requirement().unwrap();

        assert_eq!("=3".to_string(), requirement.to_string())
    }

    #[test]
    fn test_compute_all_versions_sorted() {
        let sorted_versions = APIVersionProvider::compute_all_versions_sorted();

        assert!(!sorted_versions.is_empty());
    }
}