mithril_aggregator_client/
client.rs

1use anyhow::{Context, anyhow};
2use reqwest::{IntoUrl, Response, Url, header::HeaderMap};
3use semver::Version;
4use slog::{Logger, debug, error, warn};
5use std::sync::Arc;
6use std::time::Duration;
7
8use mithril_common::MITHRIL_API_VERSION_HEADER;
9use mithril_common::api_version::APIVersionProvider;
10
11use crate::AggregatorHttpClientResult;
12use crate::builder::AggregatorClientBuilder;
13use crate::error::AggregatorHttpClientError;
14use crate::query::{AggregatorQuery, QueryContext, QueryMethod};
15
/// Message logged at error level when the API version provider fails to compute the current
/// Mithril API version.
const API_VERSION_COMPUTE_FAILURE_MESSAGE: &str = "Failed to compute the current API version";
/// Message logged as a warning when the aggregator answers with a Mithril API version more
/// recent than the one computed by this client.
const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = "OpenAPI version may be incompatible, please update Mithril client library to the latest version.";
18
19/// A client to send HTTP requests to a Mithril Aggregator
20pub struct AggregatorHttpClient {
21    pub(super) aggregator_endpoint: Url,
22    pub(super) api_version_provider: Arc<APIVersionProvider>,
23    pub(super) additional_headers: HeaderMap,
24    pub(super) timeout_duration: Option<Duration>,
25    pub(super) client: reqwest::Client,
26    pub(super) logger: Logger,
27}
28
29impl AggregatorHttpClient {
30    /// Creates a [AggregatorClientBuilder] to configure a `AggregatorClient`.
31    //
32    // This is the same as `AggregatorClient::builder()`.
33    pub fn builder<U: IntoUrl>(aggregator_url: U) -> AggregatorClientBuilder {
34        AggregatorClientBuilder::new(aggregator_url)
35    }
36
37    /// Send the given query to the Mithril Aggregator
38    pub async fn send<Q: AggregatorQuery>(
39        &self,
40        query: Q,
41    ) -> AggregatorHttpClientResult<Q::Response> {
42        let route = query.route();
43        debug!(
44            self.logger, "{} /{route}", Q::method();
45            "aggregator" => %self.aggregator_endpoint, query.entry_log_additional_fields(),
46        );
47
48        let current_api_version = self
49            .api_version_provider
50            .compute_current_version()
51            .inspect_err(
52                |err| error!(self.logger, "{API_VERSION_COMPUTE_FAILURE_MESSAGE}"; "error" => ?err),
53            )
54            .ok();
55
56        let mut request_builder = match Q::method() {
57            QueryMethod::Get => self.client.get(self.join_aggregator_endpoint(&route)?),
58            QueryMethod::Post => self.client.post(self.join_aggregator_endpoint(&route)?),
59        }
60        .headers(self.additional_headers.clone());
61
62        if let Some(version) = &current_api_version {
63            request_builder =
64                request_builder.header(MITHRIL_API_VERSION_HEADER, version.to_string());
65        }
66
67        if let Some(body) = query.body() {
68            request_builder = request_builder.json(&body);
69        }
70
71        if let Some(timeout) = self.timeout_duration {
72            request_builder = request_builder.timeout(timeout);
73        }
74
75        let response = request_builder
76            .send()
77            .await
78            .map_err(|e| AggregatorHttpClientError::RemoteServerUnreachable(anyhow!(e)))?;
79
80        if let Some(version) = &current_api_version {
81            self.warn_if_api_version_mismatch(&response, version);
82        }
83
84        let context = QueryContext {
85            response,
86            logger: self.logger.clone(),
87        };
88        query.handle_response(context).await
89    }
90
91    fn join_aggregator_endpoint(&self, endpoint: &str) -> AggregatorHttpClientResult<Url> {
92        self.aggregator_endpoint
93            .join(endpoint)
94            .with_context(|| {
95                format!(
96                    "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
97                    self.aggregator_endpoint
98                )
99            })
100            .map_err(AggregatorHttpClientError::InvalidEndpoint)
101    }
102
103    /// Check API version mismatch and log a warning if the aggregator's version is more recent.
104    fn warn_if_api_version_mismatch(&self, response: &Response, client_version: &Version) {
105        let remote_aggregator_version = response
106            .headers()
107            .get(MITHRIL_API_VERSION_HEADER)
108            .and_then(|v| v.to_str().ok())
109            .and_then(|s| Version::parse(s).ok());
110
111        if let Some(aggregator) = remote_aggregator_version
112            && client_version < &aggregator
113        {
114            warn!(self.logger, "{API_VERSION_MISMATCH_WARNING_MESSAGE}";
115                "remote_aggregator_version" => %aggregator,
116                "caller_version" => %client_version,
117            );
118        }
119    }
120}
121
#[cfg(test)]
mod tests {
    use http::StatusCode;

    use mithril_common::test::api_version_extensions::ApiVersionProviderTestExtension;

    use crate::query::QueryLogFields;
    use crate::test::{TestLogger, setup_server_and_client};

    use super::*;

    /// JSON payload returned by the dummy GET route.
    #[derive(Debug, Eq, PartialEq, serde::Deserialize)]
    struct TestResponse {
        foo: String,
        bar: i32,
    }

    /// GET query on `dummy-get-route` expecting a `200 OK` with a [TestResponse] body.
    struct TestGetQuery;

    #[async_trait::async_trait]
    impl AggregatorQuery for TestGetQuery {
        type Response = TestResponse;
        type Body = ();

        fn method() -> QueryMethod {
            QueryMethod::Get
        }

        fn route(&self) -> String {
            "dummy-get-route".to_string()
        }

        async fn handle_response(
            &self,
            context: QueryContext,
        ) -> AggregatorHttpClientResult<Self::Response> {
            match context.response.status() {
                StatusCode::OK => context
                    .response
                    .json::<TestResponse>()
                    .await
                    .map_err(|err| AggregatorHttpClientError::JsonParseFailed(anyhow!(err))),
                _ => Err(context.unhandled_status_code().await),
            }
        }
    }

    /// JSON body sent by [TestPostQuery].
    #[derive(Debug, Clone, Eq, PartialEq, serde::Serialize)]
    struct TestBody {
        pika: String,
        chu: u8,
    }

    impl TestBody {
        fn new<P: Into<String>>(pika: P, chu: u8) -> Self {
            Self {
                pika: pika.into(),
                chu,
            }
        }
    }

    /// POST query on `dummy-post-route` sending a [TestBody] and expecting a `201 Created`.
    ///
    /// It also provides additional log fields to check they are part of the request log entry.
    struct TestPostQuery {
        body: TestBody,
    }

    #[async_trait::async_trait]
    impl AggregatorQuery for TestPostQuery {
        type Response = ();
        type Body = TestBody;

        fn method() -> QueryMethod {
            QueryMethod::Post
        }

        fn route(&self) -> String {
            "dummy-post-route".to_string()
        }

        fn body(&self) -> Option<Self::Body> {
            Some(self.body.clone())
        }

        fn entry_log_additional_fields(&self) -> QueryLogFields {
            QueryLogFields::from([
                ("pika", self.body.pika.clone()),
                ("chuu", format!("{:04}", self.body.chu)),
            ])
        }

        async fn handle_response(
            &self,
            context: QueryContext,
        ) -> AggregatorHttpClientResult<Self::Response> {
            match context.response.status() {
                StatusCode::CREATED => Ok(()),
                _ => Err(context.unhandled_status_code().await),
            }
        }
    }

    #[tokio::test]
    async fn test_minimal_get_query() {
        let (server, client) = setup_server_and_client();
        server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/dummy-get-route");
            then.status(200).body(r#"{"foo": "bar", "bar": 123}"#);
        });

        let response = client.send(TestGetQuery).await.unwrap();

        assert_eq!(
            response,
            TestResponse {
                foo: "bar".to_string(),
                bar: 123,
            }
        )
    }

    #[tokio::test]
    async fn test_minimal_post_query() {
        let (server, client) = setup_server_and_client();
        server.mock(|when, then| {
            when.method(httpmock::Method::POST)
                .path("/dummy-post-route")
                .header("content-type", "application/json")
                .body(serde_json::to_string(&TestBody::new("miaouss", 5)).unwrap());
            then.status(201);
        });

        client
            .send(TestPostQuery {
                body: TestBody::new("miaouss", 5),
            })
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_query_send_mithril_api_version_header() {
        let (server, mut client) = setup_server_and_client();
        client.api_version_provider = Arc::new(APIVersionProvider::new_with_default_version(
            Version::parse("1.2.9").unwrap(),
        ));
        server.mock(|when, then| {
            when.method(httpmock::Method::GET)
                .header(MITHRIL_API_VERSION_HEADER, "1.2.9");
            then.status(200).body(r#"{"foo": "a", "bar": 1}"#);
        });

        client.send(TestGetQuery).await.expect("should not fail");
    }

    #[tokio::test]
    async fn test_dont_fail_and_logs_error_when_mithril_api_version_cannot_be_computed() {
        let (logger, log_inspector) = TestLogger::memory();
        let (server, mut client) = setup_server_and_client();
        client.api_version_provider = Arc::new(APIVersionProvider::new_failing());
        client.logger = logger;
        server.mock(|when, then| {
            when.method(httpmock::Method::GET);
            then.status(200).body(r#"{"foo": "a", "bar": 1}"#);
        });

        client.send(TestGetQuery).await.expect("should not fail");

        assert!(log_inspector.contains_log(API_VERSION_COMPUTE_FAILURE_MESSAGE));
    }

    #[tokio::test]
    async fn test_log_before_query_execution() {
        let (logger, log_inspector) = TestLogger::memory();
        let (server, mut client) = setup_server_and_client();
        client.logger = logger;
        server.mock(|when, then| {
            when.method(httpmock::Method::GET);
            then.status(200).body(r#"{"foo": "a", "bar": 1}"#);
        });
        server.mock(|when, then| {
            when.method(httpmock::Method::POST);
            then.status(201);
        });

        client.send(TestGetQuery).await.expect("should not fail");
        assert!(log_inspector.contains_log(&format!(
            "DEBUG GET /dummy-get-route; aggregator={}/",
            server.base_url()
        )));

        client
            .send(TestPostQuery {
                body: TestBody::new("miaouss", 4),
            })
            .await
            .unwrap();
        assert!(log_inspector.contains_log(&format!(
            "DEBUG POST /dummy-post-route; chuu=0004, pika=miaouss, aggregator={}/",
            server.base_url()
        )));
    }

    #[tokio::test]
    async fn test_query_send_additional_header_and_dont_override_mithril_api_version_header() {
        let (server, mut client) = setup_server_and_client();
        client.api_version_provider = Arc::new(APIVersionProvider::new_with_default_version(
            Version::parse("1.2.9").unwrap(),
        ));
        client.additional_headers = {
            let mut headers = HeaderMap::new();
            headers.insert(MITHRIL_API_VERSION_HEADER, "9.4.5".parse().unwrap());
            headers.insert("foo", "bar".parse().unwrap());
            headers
        };

        server.mock(|when, then| {
            when.method(httpmock::Method::POST)
                .header(MITHRIL_API_VERSION_HEADER, "1.2.9")
                .header("foo", "bar");
            then.status(201).body(r#"{"foo": "a", "bar": 1}"#);
        });

        client
            .send(TestPostQuery {
                body: TestBody::new("miaouss", 3),
            })
            .await
            .expect("should not fail");
    }

    #[tokio::test]
    async fn test_query_timeout() {
        let (server, mut client) = setup_server_and_client();
        client.timeout_duration = Some(Duration::from_millis(10));
        let _server_mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET);
            then.delay(Duration::from_millis(100));
        });

        // The server answers after the client timeout, so the query must fail.
        let error = client
            .send(TestGetQuery)
            .await
            .expect_err("should fail because the server answers after the client timeout");

        assert!(
            matches!(error, AggregatorHttpClientError::RemoteServerUnreachable(_)),
            "unexpected error type: {error:?}"
        );
    }

    mod warn_if_api_version_mismatch {
        use http::response::Builder as HttpResponseBuilder;
        use reqwest::Response;
        use std::fmt::Display;

        use mithril_common::test::logging::MemoryDrainForTestInspector;

        use super::*;

        /// Build a fake response with a single header and a dummy body.
        fn build_fake_response_with_header<K: Display, V: Display>(key: K, value: V) -> Response {
            HttpResponseBuilder::new()
                .header(key.to_string(), value.to_string())
                .body("whatever")
                .unwrap()
                .into()
        }

        /// Assert that the API version mismatch warning, with both versions, has been logged.
        fn assert_api_version_warning_logged<A: Display, S: Display>(
            log_inspector: &MemoryDrainForTestInspector,
            aggregator_version: A,
            client_version: S,
        ) {
            assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
            assert!(
                log_inspector
                    .contains_log(&format!("remote_aggregator_version={aggregator_version}")),
                "remote_aggregator_version: '{aggregator_version}'"
            );
            assert!(
                log_inspector.contains_log(&format!("caller_version={client_version}")),
                "caller_version: '{client_version}'"
            );
        }

        #[test]
        fn test_logs_warning_when_aggregator_api_version_is_newer() {
            let aggregator_version = Version::new(2, 0, 0);
            let client_version = Version::new(1, 0, 0);
            let (logger, log_inspector) = TestLogger::memory();
            let client = AggregatorHttpClient::builder("http://whatever")
                .with_logger(logger)
                .build()
                .unwrap();
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, &aggregator_version);

            assert!(aggregator_version > client_version);

            client.warn_if_api_version_mismatch(&response, &client_version);

            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
        }

        #[test]
        fn test_no_warning_logged_when_versions_match() {
            let client_version = Version::new(1, 0, 0);
            let (logger, log_inspector) = TestLogger::memory();
            let client = AggregatorHttpClient::builder("http://whatever")
                .with_logger(logger)
                .build()
                .unwrap();
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, &client_version);

            client.warn_if_api_version_mismatch(&response, &client_version);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_no_warning_logged_when_aggregator_api_version_is_older() {
            let aggregator_version = Version::new(1, 0, 0);
            let client_version = Version::new(2, 0, 0);
            let (logger, log_inspector) = TestLogger::memory();
            let client = AggregatorHttpClient::builder("http://whatever")
                .with_logger(logger)
                .build()
                .unwrap();
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, &aggregator_version);

            assert!(aggregator_version < client_version);

            client.warn_if_api_version_mismatch(&response, &client_version);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_does_not_log_or_fail_when_header_is_missing() {
            let client_version = Version::new(1, 0, 0);
            let (logger, log_inspector) = TestLogger::memory();
            let client = AggregatorHttpClient::builder("http://whatever")
                .with_logger(logger)
                .build()
                .unwrap();
            let response =
                build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");

            client.warn_if_api_version_mismatch(&response, &client_version);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_does_not_log_or_fail_when_header_is_not_a_version() {
            let client_version = Version::new(1, 0, 0);
            let (logger, log_inspector) = TestLogger::memory();
            let client = AggregatorHttpClient::builder("http://whatever")
                .with_logger(logger)
                .with_api_version_provider(Arc::new(APIVersionProvider::default()))
                .build()
                .unwrap();
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");

            client.warn_if_api_version_mismatch(&response, &client_version);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[tokio::test]
        async fn test_client_log_warning_if_api_version_mismatch() {
            let aggregator_version = Version::new(2, 0, 0);
            let client_version = Version::new(1, 0, 0);
            let (server, mut client) = setup_server_and_client();
            let (logger, log_inspector) = TestLogger::memory();
            client.api_version_provider = Arc::new(APIVersionProvider::new_with_default_version(
                client_version.clone(),
            ));
            client.logger = logger;
            server.mock(|_, then| {
                then.status(StatusCode::CREATED.as_u16())
                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version.to_string());
            });

            assert!(aggregator_version > client_version);

            client
                .send(TestPostQuery {
                    body: TestBody::new("miaouss", 3),
                })
                .await
                .unwrap();

            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
        }
    }
}