1use anyhow::{Context, anyhow};
2use reqwest::{IntoUrl, Response, Url, header::HeaderMap};
3use semver::Version;
4use slog::{Logger, debug, error, warn};
5use std::sync::Arc;
6use std::time::Duration;
7
8use mithril_common::MITHRIL_API_VERSION_HEADER;
9use mithril_common::api_version::APIVersionProvider;
10
11use crate::AggregatorHttpClientResult;
12use crate::builder::AggregatorClientBuilder;
13use crate::error::AggregatorHttpClientError;
14use crate::query::{AggregatorQuery, QueryContext, QueryMethod};
15
/// Warning logged when the aggregator advertises a newer API version than the one of this client.
const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = "OpenAPI version may be incompatible, please update Mithril client library to the latest version.";
/// Error logged when this client cannot compute its own API version (the request is still sent).
const API_VERSION_COMPUTE_FAILURE_MESSAGE: &str = "Failed to compute the current API version";
18
19pub struct AggregatorHttpClient {
21 pub(super) aggregator_endpoint: Url,
22 pub(super) api_version_provider: Arc<APIVersionProvider>,
23 pub(super) additional_headers: HeaderMap,
24 pub(super) timeout_duration: Option<Duration>,
25 pub(super) client: reqwest::Client,
26 pub(super) logger: Logger,
27}
28
29impl AggregatorHttpClient {
30 pub fn builder<U: IntoUrl>(aggregator_url: U) -> AggregatorClientBuilder {
34 AggregatorClientBuilder::new(aggregator_url)
35 }
36
37 pub async fn send<Q: AggregatorQuery>(
39 &self,
40 query: Q,
41 ) -> AggregatorHttpClientResult<Q::Response> {
42 let route = query.route();
43 debug!(
44 self.logger, "{} /{route}", Q::method();
45 "aggregator" => %self.aggregator_endpoint, query.entry_log_additional_fields(),
46 );
47
48 let current_api_version = self
49 .api_version_provider
50 .compute_current_version()
51 .inspect_err(
52 |err| error!(self.logger, "{API_VERSION_COMPUTE_FAILURE_MESSAGE}"; "error" => ?err),
53 )
54 .ok();
55
56 let mut request_builder = match Q::method() {
57 QueryMethod::Get => self.client.get(self.join_aggregator_endpoint(&route)?),
58 QueryMethod::Post => self.client.post(self.join_aggregator_endpoint(&route)?),
59 }
60 .headers(self.additional_headers.clone());
61
62 if let Some(version) = ¤t_api_version {
63 request_builder =
64 request_builder.header(MITHRIL_API_VERSION_HEADER, version.to_string());
65 }
66
67 if let Some(body) = query.body() {
68 request_builder = request_builder.json(&body);
69 }
70
71 if let Some(timeout) = self.timeout_duration {
72 request_builder = request_builder.timeout(timeout);
73 }
74
75 let response = request_builder
76 .send()
77 .await
78 .map_err(|e| AggregatorHttpClientError::RemoteServerUnreachable(anyhow!(e)))?;
79
80 if let Some(version) = ¤t_api_version {
81 self.warn_if_api_version_mismatch(&response, version);
82 }
83
84 let context = QueryContext {
85 response,
86 logger: self.logger.clone(),
87 };
88 query.handle_response(context).await
89 }
90
91 fn join_aggregator_endpoint(&self, endpoint: &str) -> AggregatorHttpClientResult<Url> {
92 self.aggregator_endpoint
93 .join(endpoint)
94 .with_context(|| {
95 format!(
96 "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
97 self.aggregator_endpoint
98 )
99 })
100 .map_err(AggregatorHttpClientError::InvalidEndpoint)
101 }
102
103 fn warn_if_api_version_mismatch(&self, response: &Response, client_version: &Version) {
105 let remote_aggregator_version = response
106 .headers()
107 .get(MITHRIL_API_VERSION_HEADER)
108 .and_then(|v| v.to_str().ok())
109 .and_then(|s| Version::parse(s).ok());
110
111 if let Some(aggregator) = remote_aggregator_version
112 && client_version < &aggregator
113 {
114 warn!(self.logger, "{API_VERSION_MISMATCH_WARNING_MESSAGE}";
115 "remote_aggregator_version" => %aggregator,
116 "caller_version" => %client_version,
117 );
118 }
119 }
120}
121
#[cfg(test)]
mod tests {
    use http::StatusCode;

    use mithril_common::test::api_version_extensions::ApiVersionProviderTestExtension;

    use crate::query::QueryLogFields;
    use crate::test::{TestLogger, setup_server_and_client};

    use super::*;

    /// JSON payload returned by the mocked GET route.
    #[derive(Debug, Eq, PartialEq, serde::Deserialize)]
    struct TestResponse {
        foo: String,
        bar: i32,
    }

    /// GET query without body, expecting a `200` with a [`TestResponse`] JSON body.
    struct TestGetQuery;

    #[async_trait::async_trait]
    impl AggregatorQuery for TestGetQuery {
        type Response = TestResponse;
        type Body = ();

        fn method() -> QueryMethod {
            QueryMethod::Get
        }

        fn route(&self) -> String {
            "dummy-get-route".to_string()
        }

        async fn handle_response(
            &self,
            context: QueryContext,
        ) -> AggregatorHttpClientResult<Self::Response> {
            match context.response.status() {
                StatusCode::OK => context
                    .response
                    .json::<TestResponse>()
                    .await
                    .map_err(|err| AggregatorHttpClientError::JsonParseFailed(anyhow!(err))),
                _ => Err(context.unhandled_status_code().await),
            }
        }
    }

    /// JSON body sent by [`TestPostQuery`].
    #[derive(Debug, Clone, Eq, PartialEq, serde::Serialize)]
    struct TestBody {
        pika: String,
        chu: u8,
    }

    impl TestBody {
        fn new<P: Into<String>>(pika: P, chu: u8) -> Self {
            Self {
                pika: pika.into(),
                chu,
            }
        }
    }

    /// POST query sending a [`TestBody`], adding log fields, and expecting a `201`.
    struct TestPostQuery {
        body: TestBody,
    }

    #[async_trait::async_trait]
    impl AggregatorQuery for TestPostQuery {
        type Response = ();
        type Body = TestBody;

        fn method() -> QueryMethod {
            QueryMethod::Post
        }

        fn route(&self) -> String {
            "dummy-post-route".to_string()
        }

        fn body(&self) -> Option<Self::Body> {
            Some(self.body.clone())
        }

        fn entry_log_additional_fields(&self) -> QueryLogFields {
            QueryLogFields::from([
                ("pika", self.body.pika.clone()),
                ("chuu", format!("{:04}", self.body.chu)),
            ])
        }

        async fn handle_response(
            &self,
            context: QueryContext,
        ) -> AggregatorHttpClientResult<Self::Response> {
            match context.response.status() {
                StatusCode::CREATED => Ok(()),
                _ => Err(context.unhandled_status_code().await),
            }
        }
    }

    #[tokio::test]
    async fn test_minimal_get_query() {
        let (server, client) = setup_server_and_client();
        server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/dummy-get-route");
            then.status(200).body(r#"{"foo": "bar", "bar": 123}"#);
        });

        let response = client.send(TestGetQuery).await.unwrap();

        assert_eq!(
            response,
            TestResponse {
                foo: "bar".to_string(),
                bar: 123,
            }
        )
    }

    #[tokio::test]
    async fn test_minimal_post_query() {
        let (server, client) = setup_server_and_client();
        server.mock(|when, then| {
            when.method(httpmock::Method::POST)
                .path("/dummy-post-route")
                .header("content-type", "application/json")
                .body(serde_json::to_string(&TestBody::new("miaouss", 5)).unwrap());
            then.status(201);
        });

        client
            .send(TestPostQuery {
                body: TestBody::new("miaouss", 5),
            })
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_query_send_mithril_api_version_header() {
        let (server, mut client) = setup_server_and_client();
        client.api_version_provider = Arc::new(APIVersionProvider::new_with_default_version(
            Version::parse("1.2.9").unwrap(),
        ));
        server.mock(|when, then| {
            when.method(httpmock::Method::GET)
                .header(MITHRIL_API_VERSION_HEADER, "1.2.9");
            then.status(200).body(r#"{"foo": "a", "bar": 1}"#);
        });

        client.send(TestGetQuery).await.expect("should not fail");
    }

    #[tokio::test]
    async fn test_dont_fail_and_logs_error_when_mithril_api_version_cannot_be_computed() {
        let (logger, log_inspector) = TestLogger::memory();
        let (server, mut client) = setup_server_and_client();
        client.api_version_provider = Arc::new(APIVersionProvider::new_failing());
        client.logger = logger;
        server.mock(|when, then| {
            when.method(httpmock::Method::GET);
            then.status(200).body(r#"{"foo": "a", "bar": 1}"#);
        });

        client.send(TestGetQuery).await.expect("should not fail");

        assert!(log_inspector.contains_log(API_VERSION_COMPUTE_FAILURE_MESSAGE));
    }

    #[tokio::test]
    async fn test_log_before_query_execution() {
        let (logger, log_inspector) = TestLogger::memory();
        let (server, mut client) = setup_server_and_client();
        client.logger = logger;
        server.mock(|when, then| {
            when.method(httpmock::Method::GET);
            then.status(200).body(r#"{"foo": "a", "bar": 1}"#);
        });
        server.mock(|when, then| {
            when.method(httpmock::Method::POST);
            then.status(201);
        });

        client.send(TestGetQuery).await.expect("should not fail");
        assert!(log_inspector.contains_log(&format!(
            "DEBUG GET /dummy-get-route; aggregator={}/",
            server.base_url()
        )));

        client
            .send(TestPostQuery {
                body: TestBody::new("miaouss", 4),
            })
            .await
            .unwrap();
        assert!(log_inspector.contains_log(&format!(
            "DEBUG POST /dummy-post-route; chuu=0004, pika=miaouss, aggregator={}/",
            server.base_url()
        )));
    }

    #[tokio::test]
    async fn test_query_send_additional_header_and_dont_override_mithril_api_version_header() {
        let (server, mut client) = setup_server_and_client();
        client.api_version_provider = Arc::new(APIVersionProvider::new_with_default_version(
            Version::parse("1.2.9").unwrap(),
        ));
        client.additional_headers = {
            let mut headers = HeaderMap::new();
            headers.insert(MITHRIL_API_VERSION_HEADER, "9.4.5".parse().unwrap());
            headers.insert("foo", "bar".parse().unwrap());
            headers
        };

        server.mock(|when, then| {
            when.method(httpmock::Method::POST)
                .header(MITHRIL_API_VERSION_HEADER, "1.2.9")
                .header("foo", "bar");
            then.status(201).body(r#"{"foo": "a", "bar": 1}"#);
        });

        client
            .send(TestPostQuery {
                body: TestBody::new("miaouss", 3),
            })
            .await
            .expect("should not fail");
    }

    #[tokio::test]
    async fn test_query_timeout() {
        let (server, mut client) = setup_server_and_client();
        client.timeout_duration = Some(Duration::from_millis(10));
        let _server_mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET);
            then.delay(Duration::from_millis(100));
        });

        let error = client
            .send(TestGetQuery)
            .await
            .expect_err("should fail since the server answers after the client timeout");

        assert!(
            matches!(error, AggregatorHttpClientError::RemoteServerUnreachable(_)),
            "unexpected error type: {error:?}"
        );
    }

    mod warn_if_api_version_mismatch {
        use http::response::Builder as HttpResponseBuilder;
        use reqwest::Response;
        use std::fmt::Display;

        use mithril_common::test::logging::MemoryDrainForTestInspector;

        use super::*;

        /// Build a client targeting a dummy url that logs into the given logger.
        fn build_client_with_logger(logger: Logger) -> AggregatorHttpClient {
            AggregatorHttpClient::builder("http://whatever")
                .with_logger(logger)
                .build()
                .unwrap()
        }

        fn build_fake_response_with_header<K: Display, V: Display>(key: K, value: V) -> Response {
            HttpResponseBuilder::new()
                .header(key.to_string(), value.to_string())
                .body("whatever")
                .unwrap()
                .into()
        }

        fn assert_api_version_warning_logged<A: Display, S: Display>(
            log_inspector: &MemoryDrainForTestInspector,
            aggregator_version: A,
            client_version: S,
        ) {
            assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
            assert!(
                log_inspector
                    .contains_log(&format!("remote_aggregator_version={aggregator_version}")),
                "remote_aggregator_version: '{aggregator_version}'"
            );
            assert!(
                log_inspector.contains_log(&format!("caller_version={client_version}")),
                "caller_version: '{client_version}'"
            );
        }

        #[test]
        fn test_logs_warning_when_aggregator_api_version_is_newer() {
            let aggregator_version = Version::new(2, 0, 0);
            let client_version = Version::new(1, 0, 0);
            let (logger, log_inspector) = TestLogger::memory();
            let client = build_client_with_logger(logger);
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, &aggregator_version);

            assert!(aggregator_version > client_version);

            client.warn_if_api_version_mismatch(&response, &client_version);

            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
        }

        #[test]
        fn test_no_warning_logged_when_versions_match() {
            let client_version = Version::new(1, 0, 0);
            let (logger, log_inspector) = TestLogger::memory();
            let client = build_client_with_logger(logger);
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, &client_version);

            client.warn_if_api_version_mismatch(&response, &client_version);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_no_warning_logged_when_aggregator_api_version_is_older() {
            let aggregator_version = Version::new(1, 0, 0);
            let client_version = Version::new(2, 0, 0);
            let (logger, log_inspector) = TestLogger::memory();
            let client = build_client_with_logger(logger);
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, &aggregator_version);

            assert!(aggregator_version < client_version);

            client.warn_if_api_version_mismatch(&response, &client_version);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_does_not_log_or_fail_when_header_is_missing() {
            let client_version = Version::new(1, 0, 0);
            let (logger, log_inspector) = TestLogger::memory();
            let client = build_client_with_logger(logger);
            let response =
                build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");

            client.warn_if_api_version_mismatch(&response, &client_version);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_does_not_log_or_fail_when_header_is_not_a_version() {
            let client_version = Version::new(1, 0, 0);
            let (logger, log_inspector) = TestLogger::memory();
            let client = AggregatorHttpClient::builder("http://whatever")
                .with_logger(logger)
                .with_api_version_provider(Arc::new(APIVersionProvider::default()))
                .build()
                .unwrap();
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");

            client.warn_if_api_version_mismatch(&response, &client_version);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[tokio::test]
        async fn test_client_log_warning_if_api_version_mismatch() {
            let aggregator_version = Version::new(2, 0, 0);
            let client_version = Version::new(1, 0, 0);
            let (server, mut client) = setup_server_and_client();
            let (logger, log_inspector) = TestLogger::memory();
            client.api_version_provider = Arc::new(APIVersionProvider::new_with_default_version(
                client_version.clone(),
            ));
            client.logger = logger;
            server.mock(|_, then| {
                then.status(StatusCode::CREATED.as_u16())
                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version.to_string());
            });

            assert!(aggregator_version > client_version);

            client
                .send(TestPostQuery {
                    body: TestBody::new("miaouss", 3),
                })
                .await
                .unwrap();

            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
        }
    }
}