mithril_common/
api_version.rs1include!(concat!(env!("OUT_DIR"), "/open_api.rs"));
3use anyhow::anyhow;
4use semver::{Version, VersionReq};
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::StdResult;
9
10#[derive(Clone)]
12pub struct APIVersionProvider {
13 alternate_file_discriminator: Arc<dyn ApiVersionDiscriminantSource>,
14 open_api_versions: HashMap<OpenAPIFileName, Version>,
15}
16
/// Supplies the discriminant that selects an alternate OpenAPI spec file.
///
/// The returned value is interpolated into `openapi-{discriminant}.yaml` when the
/// current API version is computed. Implementors must be shareable across threads.
#[cfg_attr(test, mockall::automock)]
pub trait ApiVersionDiscriminantSource: Sync + Send {
    /// Returns the current discriminant (presumably an era name — the provider's
    /// constructor calls its source an `era_checker`; confirm with implementors).
    fn get_discriminant(&self) -> String;
}
24
25impl APIVersionProvider {
26 pub fn new(era_checker: Arc<dyn ApiVersionDiscriminantSource>) -> Self {
28 Self {
29 alternate_file_discriminator: era_checker,
30 open_api_versions: get_open_api_versions_mapping(),
31 }
32 }
33
34 pub fn compute_current_version(&self) -> StdResult<Version> {
36 let discriminant = self.alternate_file_discriminator.get_discriminant();
37 let open_api_spec_file_name_default = "openapi.yaml";
38 let open_api_spec_file_name_era = &format!("openapi-{discriminant}.yaml");
39 let open_api_version = self.open_api_versions.get(open_api_spec_file_name_era).unwrap_or(
40 self.open_api_versions
41 .get(open_api_spec_file_name_default)
42 .ok_or_else(|| anyhow!("Missing default API version"))?,
43 );
44
45 Ok(open_api_version.clone())
46 }
47
48 pub fn compute_current_version_requirement(&self) -> StdResult<VersionReq> {
50 let version = &self.compute_current_version()?;
51 let version_req = if version.major > 0 {
52 format!("={}", version.major)
53 } else {
54 format!("={}.{}", version.major, version.minor)
55 };
56
57 Ok(VersionReq::parse(&version_req)?)
58 }
59
60 pub fn compute_all_versions_sorted() -> Vec<Version> {
62 let mut versions: Vec<Version> = get_open_api_versions_mapping().into_values().collect();
63 versions.sort();
64 versions
65 }
66
67 pub fn update_open_api_versions(
69 &mut self,
70 open_api_versions: HashMap<OpenAPIFileName, Version>,
71 ) {
72 self.open_api_versions = open_api_versions;
73 }
74}
75
76cfg_test_tools! {
77 pub struct DummyApiVersionDiscriminantSource {
79 discriminant: String,
80 }
81
82 impl DummyApiVersionDiscriminantSource {
83 pub fn new<T: Into<String>>(discrimant: T) -> Self {
85 Self {
86 discriminant: discrimant.into(),
87 }
88 }
89 }
90
91 impl Default for DummyApiVersionDiscriminantSource {
92 fn default() -> Self {
93 Self {
94 discriminant: "dummy".to_string(),
95 }
96 }
97 }
98
99 impl ApiVersionDiscriminantSource for DummyApiVersionDiscriminantSource {
100 fn get_discriminant(&self) -> String {
101 self.discriminant.clone()
102 }
103 }
104}
105
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a provider that reads from `discriminant_source`, with its version table
    /// replaced by the given `(spec file name, version)` pairs.
    fn build_provider(
        discriminant_source: DummyApiVersionDiscriminantSource,
        versions: &[(&str, Version)],
    ) -> APIVersionProvider {
        let mut version_provider = APIVersionProvider::new(Arc::new(discriminant_source));
        version_provider.update_open_api_versions(
            versions
                .iter()
                .map(|(file_name, version)| (file_name.to_string(), version.clone()))
                .collect(),
        );
        version_provider
    }

    #[test]
    fn test_compute_current_version_default() {
        let api_version_provider = build_provider(
            DummyApiVersionDiscriminantSource::default(),
            &[("openapi.yaml", Version::new(1, 2, 3))],
        );

        assert_eq!(
            "1.2.3".to_string(),
            api_version_provider.compute_current_version().unwrap().to_string()
        )
    }

    #[test]
    fn test_compute_current_version_era_specific() {
        let api_version_provider = build_provider(
            DummyApiVersionDiscriminantSource::new("dummy"),
            &[
                ("openapi.yaml", Version::new(1, 2, 3)),
                ("openapi-dummy.yaml", Version::new(2, 1, 0)),
            ],
        );

        assert_eq!(
            "2.1.0".to_string(),
            api_version_provider.compute_current_version().unwrap().to_string()
        )
    }

    #[test]
    fn test_compute_current_version_fails_when_no_version_matches() {
        let api_version_provider =
            build_provider(DummyApiVersionDiscriminantSource::default(), &[]);

        assert!(api_version_provider.compute_current_version().is_err());
        assert!(api_version_provider
            .compute_current_version_requirement()
            .is_err());
    }

    #[test]
    fn test_compute_current_version_requirement_beta() {
        let api_version_provider = build_provider(
            DummyApiVersionDiscriminantSource::default(),
            &[("openapi.yaml", Version::new(0, 2, 3))],
        );

        assert_eq!(
            "=0.2".to_string(),
            api_version_provider
                .compute_current_version_requirement()
                .unwrap()
                .to_string()
        )
    }

    #[test]
    fn test_compute_current_version_requirement_stable() {
        let api_version_provider = build_provider(
            DummyApiVersionDiscriminantSource::default(),
            &[("openapi.yaml", Version::new(3, 2, 1))],
        );

        assert_eq!(
            "=3".to_string(),
            api_version_provider
                .compute_current_version_requirement()
                .unwrap()
                .to_string()
        )
    }

    #[test]
    fn test_compute_all_versions_sorted() {
        let all_versions_sorted = APIVersionProvider::compute_all_versions_sorted();

        assert!(!all_versions_sorted.is_empty());
        assert!(all_versions_sorted
            .windows(2)
            .all(|pair| pair[0] <= pair[1]));
    }
}