mithril_common/
api_version.rs

1//! API Version provider service
2include!(concat!(env!("OUT_DIR"), "/open_api.rs"));
3use anyhow::anyhow;
4use semver::{Version, VersionReq};
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::StdResult;
9
10/// API Version provider
11#[derive(Clone)]
12pub struct APIVersionProvider {
13    alternate_file_discriminator: Arc<dyn ApiVersionDiscriminantSource>,
14    open_api_versions: HashMap<OpenAPIFileName, Version>,
15}
16
/// Trait to get the discriminant that identifies the alternate `openapi` file to use first if found,
/// in place of the default `openapi.yaml` file.
///
/// The alternate file name is built as `openapi-{discriminant}.yaml`.
#[cfg_attr(test, mockall::automock)]
pub trait ApiVersionDiscriminantSource: Send + Sync {
    /// Get the discriminant that identifies the alternate `openapi` file
    /// (the `{discriminant}` part of `openapi-{discriminant}.yaml`).
    fn get_discriminant(&self) -> String;
}
24
25impl APIVersionProvider {
26    /// Version provider factory
27    pub fn new(era_checker: Arc<dyn ApiVersionDiscriminantSource>) -> Self {
28        Self {
29            alternate_file_discriminator: era_checker,
30            open_api_versions: get_open_api_versions_mapping(),
31        }
32    }
33
34    /// Compute the current api version
35    pub fn compute_current_version(&self) -> StdResult<Version> {
36        let discriminant = self.alternate_file_discriminator.get_discriminant();
37        let open_api_spec_file_name_default = "openapi.yaml";
38        let open_api_spec_file_name_era = &format!("openapi-{discriminant}.yaml");
39        let open_api_version = self.open_api_versions.get(open_api_spec_file_name_era).unwrap_or(
40            self.open_api_versions
41                .get(open_api_spec_file_name_default)
42                .ok_or_else(|| anyhow!("Missing default API version"))?,
43        );
44
45        Ok(open_api_version.clone())
46    }
47
48    /// Compute the current api version requirement
49    pub fn compute_current_version_requirement(&self) -> StdResult<VersionReq> {
50        let version = &self.compute_current_version()?;
51        let version_req = if version.major > 0 {
52            format!("={}", version.major)
53        } else {
54            format!("={}.{}", version.major, version.minor)
55        };
56
57        Ok(VersionReq::parse(&version_req)?)
58    }
59
60    /// Compute all the sorted list of all versions
61    pub fn compute_all_versions_sorted() -> Vec<Version> {
62        let mut versions: Vec<Version> = get_open_api_versions_mapping().into_values().collect();
63        versions.sort();
64        versions
65    }
66
67    /// Update open api versions. Test only
68    pub fn update_open_api_versions(
69        &mut self,
70        open_api_versions: HashMap<OpenAPIFileName, Version>,
71    ) {
72        self.open_api_versions = open_api_versions;
73    }
74}
75
76cfg_test_tools! {
77    /// A dummy implementation of the `ApiVersionDiscriminantSource` trait for testing purposes.
78    pub struct DummyApiVersionDiscriminantSource {
79        discriminant: String,
80    }
81
82    impl DummyApiVersionDiscriminantSource {
83        /// Create a new instance of `DummyApiVersionDiscriminantSource` with the given discriminant.
84        pub fn new<T: Into<String>>(discrimant: T) -> Self {
85            Self {
86                discriminant: discrimant.into(),
87            }
88        }
89    }
90
91    impl Default for DummyApiVersionDiscriminantSource {
92        fn default() -> Self {
93            Self {
94                discriminant: "dummy".to_string(),
95            }
96        }
97    }
98
99    impl ApiVersionDiscriminantSource for DummyApiVersionDiscriminantSource {
100        fn get_discriminant(&self) -> String {
101            self.discriminant.clone()
102        }
103    }
104}
105
#[cfg(test)]
mod test {
    use super::*;

    /// Build a provider using the given discriminant source and file-to-version mapping.
    fn build_provider(
        discriminant_source: DummyApiVersionDiscriminantSource,
        open_api_versions: &[(&str, Version)],
    ) -> APIVersionProvider {
        let mut version_provider = APIVersionProvider::new(Arc::new(discriminant_source));
        version_provider.update_open_api_versions(
            open_api_versions
                .iter()
                .map(|(file_name, version)| (file_name.to_string(), version.clone()))
                .collect(),
        );
        version_provider
    }

    #[test]
    fn test_compute_current_version_default() {
        let version_provider = build_provider(
            DummyApiVersionDiscriminantSource::default(),
            &[("openapi.yaml", Version::new(1, 2, 3))],
        );

        assert_eq!(
            Version::new(1, 2, 3),
            version_provider.compute_current_version().unwrap()
        );
    }

    #[test]
    fn test_compute_current_version_era_specific() {
        let version_provider = build_provider(
            DummyApiVersionDiscriminantSource::new("dummy"),
            &[
                ("openapi.yaml", Version::new(1, 2, 3)),
                ("openapi-dummy.yaml", Version::new(2, 1, 0)),
            ],
        );

        assert_eq!(
            Version::new(2, 1, 0),
            version_provider.compute_current_version().unwrap()
        );
    }

    #[test]
    fn test_compute_current_version_falls_back_to_default_for_unknown_discriminant() {
        let version_provider = build_provider(
            DummyApiVersionDiscriminantSource::new("other"),
            &[
                ("openapi.yaml", Version::new(1, 2, 3)),
                ("openapi-dummy.yaml", Version::new(2, 1, 0)),
            ],
        );

        assert_eq!(
            Version::new(1, 2, 3),
            version_provider.compute_current_version().unwrap()
        );
    }

    #[test]
    fn test_compute_current_version_fails_without_any_matching_version() {
        let version_provider = build_provider(
            DummyApiVersionDiscriminantSource::new("dummy"),
            &[("openapi-other.yaml", Version::new(2, 1, 0))],
        );

        assert!(version_provider.compute_current_version().is_err());
    }

    #[test]
    fn test_compute_current_version_requirement_beta() {
        let version_provider = build_provider(
            DummyApiVersionDiscriminantSource::default(),
            &[("openapi.yaml", Version::new(0, 2, 3))],
        );

        assert_eq!(
            "=0.2".to_string(),
            version_provider
                .compute_current_version_requirement()
                .unwrap()
                .to_string()
        );
    }

    #[test]
    fn test_compute_current_version_requirement_stable() {
        let version_provider = build_provider(
            DummyApiVersionDiscriminantSource::default(),
            &[("openapi.yaml", Version::new(3, 2, 1))],
        );

        assert_eq!(
            "=3".to_string(),
            version_provider
                .compute_current_version_requirement()
                .unwrap()
                .to_string()
        );
    }

    #[test]
    fn test_compute_all_versions_sorted() {
        let all_versions_sorted = APIVersionProvider::compute_all_versions_sorted();

        assert!(!all_versions_sorted.is_empty());
        assert!(all_versions_sorted
            .windows(2)
            .all(|pair| pair[0] <= pair[1]));
    }
}