mithril_aggregator/services/
aggregator_client.rs

1use anyhow::anyhow;
2use async_trait::async_trait;
3use mithril_common::messages::TryFromMessageAdapter;
4use reqwest::header::{self, HeaderValue};
5use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode};
6use slog::{debug, error, Logger};
7use std::{io, sync::Arc, time::Duration};
8use thiserror::Error;
9
10use mithril_common::{
11    api_version::APIVersionProvider,
12    entities::{ClientError, ServerError},
13    logging::LoggerExtensions,
14    messages::EpochSettingsMessage,
15    StdError, MITHRIL_AGGREGATOR_VERSION_HEADER, MITHRIL_API_VERSION_HEADER,
16};
17
18use crate::entities::LeaderAggregatorEpochSettings;
19use crate::message_adapters::FromEpochSettingsAdapter;
20
21const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");
22
23/// Error structure for the Aggregator Client.
24#[derive(Error, Debug)]
25pub enum AggregatorClientError {
26    /// The aggregator host has returned a technical error.
27    #[error("remote server technical error")]
28    RemoteServerTechnical(#[source] StdError),
29
30    /// The aggregator host responded it cannot fulfill our request.
31    #[error("remote server logical error")]
32    RemoteServerLogical(#[source] StdError),
33
34    /// Could not reach aggregator.
35    #[error("remote server unreachable")]
36    RemoteServerUnreachable(#[source] StdError),
37
38    /// Unhandled status code
39    #[error("unhandled status code: {0}, response text: {1}")]
40    UnhandledStatusCode(StatusCode, String),
41
42    /// Could not parse response.
43    #[error("json parsing failed")]
44    JsonParseFailed(#[source] StdError),
45
46    /// Mostly network errors.
47    #[error("Input/Output error")]
48    IOError(#[from] io::Error),
49
50    /// Incompatible API version error
51    #[error("HTTP API version mismatch")]
52    ApiVersionMismatch(#[source] StdError),
53
54    /// HTTP client creation error
55    #[error("HTTP client creation failed")]
56    HTTPClientCreation(#[source] StdError),
57
58    /// Proxy creation error
59    #[error("proxy creation failed")]
60    ProxyCreation(#[source] StdError),
61
62    /// Adapter error
63    #[error("adapter failed")]
64    Adapter(#[source] StdError),
65}
66
#[cfg(test)]
/// Test-only helpers on [`AggregatorClientError`].
impl AggregatorClientError {
    /// Tells whether this error is an [`AggregatorClientError::ApiVersionMismatch`].
    pub(crate) fn is_api_version_mismatch(&self) -> bool {
        matches!(*self, AggregatorClientError::ApiVersionMismatch(..))
    }
}
74
75impl AggregatorClientError {
76    /// Create an `AggregatorClientError` from a response.
77    ///
78    /// This method is meant to be used after handling domain-specific cases leaving only
79    /// 4xx or 5xx status codes.
80    /// Otherwise, it will return an `UnhandledStatusCode` error.
81    pub async fn from_response(response: Response) -> Self {
82        let error_code = response.status();
83
84        if error_code.is_client_error() {
85            let root_cause = Self::get_root_cause(response).await;
86            Self::RemoteServerLogical(anyhow!(root_cause))
87        } else if error_code.is_server_error() {
88            let root_cause = Self::get_root_cause(response).await;
89            Self::RemoteServerTechnical(anyhow!(root_cause))
90        } else {
91            let response_text = response.text().await.unwrap_or_default();
92            Self::UnhandledStatusCode(error_code, response_text)
93        }
94    }
95
96    async fn get_root_cause(response: Response) -> String {
97        let error_code = response.status();
98        let canonical_reason = error_code
99            .canonical_reason()
100            .unwrap_or_default()
101            .to_lowercase();
102        let is_json = response
103            .headers()
104            .get(header::CONTENT_TYPE)
105            .is_some_and(|ct| JSON_CONTENT_TYPE == ct);
106
107        if is_json {
108            let json_value: serde_json::Value = response.json().await.unwrap_or_default();
109
110            if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
111                format!(
112                    "{}: {}: {}",
113                    canonical_reason, client_error.label, client_error.message
114                )
115            } else if let Ok(server_error) =
116                serde_json::from_value::<ServerError>(json_value.clone())
117            {
118                format!("{}: {}", canonical_reason, server_error.message)
119            } else if json_value.is_null() {
120                canonical_reason.to_string()
121            } else {
122                format!("{}: {}", canonical_reason, json_value)
123            }
124        } else {
125            let response_text = response.text().await.unwrap_or_default();
126            format!("{}: {}", canonical_reason, response_text)
127        }
128    }
129}
130
131/// Trait for mocking and testing a `AggregatorClient`
132#[cfg_attr(test, mockall::automock)]
133#[async_trait]
134pub trait AggregatorClient: Sync + Send {
135    /// Retrieves epoch settings from the aggregator
136    async fn retrieve_epoch_settings(
137        &self,
138    ) -> Result<Option<LeaderAggregatorEpochSettings>, AggregatorClientError>;
139}
140
141/// AggregatorHTTPClient is a http client for an aggregator
142pub struct AggregatorHTTPClient {
143    aggregator_endpoint: String,
144    relay_endpoint: Option<String>,
145    api_version_provider: Arc<APIVersionProvider>,
146    timeout_duration: Option<Duration>,
147    logger: Logger,
148}
149
150impl AggregatorHTTPClient {
151    /// AggregatorHTTPClient factory
152    pub fn new(
153        aggregator_endpoint: String,
154        relay_endpoint: Option<String>,
155        api_version_provider: Arc<APIVersionProvider>,
156        timeout_duration: Option<Duration>,
157        logger: Logger,
158    ) -> Self {
159        let logger = logger.new_with_component_name::<Self>();
160        debug!(logger, "New AggregatorHTTPClient created");
161        Self {
162            aggregator_endpoint,
163            relay_endpoint,
164            api_version_provider,
165            timeout_duration,
166            logger,
167        }
168    }
169
170    fn prepare_http_client(&self) -> Result<Client, AggregatorClientError> {
171        let client = match &self.relay_endpoint {
172            Some(relay_endpoint) => Client::builder()
173                .proxy(
174                    Proxy::all(relay_endpoint)
175                        .map_err(|e| AggregatorClientError::ProxyCreation(anyhow!(e)))?,
176                )
177                .build()
178                .map_err(|e| AggregatorClientError::HTTPClientCreation(anyhow!(e)))?,
179            None => Client::new(),
180        };
181
182        Ok(client)
183    }
184
185    /// Forge a client request adding protocol version in the headers.
186    pub fn prepare_request_builder(&self, request_builder: RequestBuilder) -> RequestBuilder {
187        let request_builder = request_builder
188            .header(
189                MITHRIL_API_VERSION_HEADER,
190                self.api_version_provider
191                    .compute_current_version()
192                    .unwrap()
193                    .to_string(),
194            )
195            .header(MITHRIL_AGGREGATOR_VERSION_HEADER, env!("CARGO_PKG_VERSION"));
196
197        if let Some(duration) = self.timeout_duration {
198            request_builder.timeout(duration)
199        } else {
200            request_builder
201        }
202    }
203
204    /// API version error handling
205    fn handle_api_error(&self, response: &Response) -> AggregatorClientError {
206        if let Some(version) = response.headers().get(MITHRIL_API_VERSION_HEADER) {
207            AggregatorClientError::ApiVersionMismatch(anyhow!(
208                "server version: '{}', signer version: '{}'",
209                version.to_str().unwrap(),
210                self.api_version_provider.compute_current_version().unwrap()
211            ))
212        } else {
213            AggregatorClientError::ApiVersionMismatch(anyhow!(
214                "version precondition failed, sent version '{}'.",
215                self.api_version_provider.compute_current_version().unwrap()
216            ))
217        }
218    }
219}
220
221#[async_trait]
222impl AggregatorClient for AggregatorHTTPClient {
223    async fn retrieve_epoch_settings(
224        &self,
225    ) -> Result<Option<LeaderAggregatorEpochSettings>, AggregatorClientError> {
226        debug!(self.logger, "Retrieve epoch settings");
227        let url = format!("{}/epoch-settings", self.aggregator_endpoint);
228        let response = self
229            .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
230            .send()
231            .await;
232
233        match response {
234            Ok(response) => match response.status() {
235                StatusCode::OK => match response.json::<EpochSettingsMessage>().await {
236                    Ok(message) => {
237                        let epoch_settings = FromEpochSettingsAdapter::try_adapt(message)
238                            .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
239                        Ok(Some(epoch_settings))
240                    }
241                    Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
242                },
243                StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
244                _ => Err(AggregatorClientError::from_response(response).await),
245            },
246            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
247        }
248    }
249}
250
#[cfg(test)]
pub(crate) mod dumb {
    use tokio::sync::RwLock;

    use super::*;

    /// In-memory [`AggregatorClient`] for test services.
    ///
    /// It never contacts an aggregator host. It serves whatever epoch settings it currently
    /// holds, which default to `LeaderAggregatorEpochSettings::dummy()`.
    pub struct DumbAggregatorClient {
        epoch_settings: RwLock<Option<LeaderAggregatorEpochSettings>>,
    }

    impl Default for DumbAggregatorClient {
        fn default() -> Self {
            let initial_settings = LeaderAggregatorEpochSettings::dummy();

            Self {
                epoch_settings: RwLock::new(Some(initial_settings)),
            }
        }
    }

    #[async_trait]
    impl AggregatorClient for DumbAggregatorClient {
        /// Return a copy of the currently stored epoch settings.
        async fn retrieve_epoch_settings(
            &self,
        ) -> Result<Option<LeaderAggregatorEpochSettings>, AggregatorClientError> {
            Ok(self.epoch_settings.read().await.clone())
        }
    }
}
283
#[cfg(test)]
mod tests {
    use http::response::Builder as HttpResponseBuilder;
    use httpmock::prelude::*;
    use serde_json::json;

    use mithril_common::entities::Epoch;
    use mithril_common::era::{EraChecker, SupportedEra};

    use crate::test_tools::TestLogger;

    use super::*;

    /// Start a mock aggregator server and build a client targeting it (no relay, no timeout).
    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
        let server = MockServer::start();
        let aggregator_endpoint = server.url("");
        let relay_endpoint = None;
        let era_checker = EraChecker::new(SupportedEra::dummy(), Epoch(1));
        let api_version_provider = APIVersionProvider::new(Arc::new(era_checker));

        (
            server,
            AggregatorHTTPClient::new(
                aggregator_endpoint,
                relay_endpoint,
                Arc::new(api_version_provider),
                None,
                TestLogger::stdout(),
            ),
        )
    }

    /// Build a response with a plain text body and no `Content-Type` header.
    fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
        HttpResponseBuilder::new()
            .status(status_code)
            .body(body.into())
            .unwrap()
            .into()
    }

    /// Build a response whose body is `body` serialized as JSON, with an `application/json`
    /// content type.
    fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
        HttpResponseBuilder::new()
            .status(status_code)
            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
            .body(serde_json::to_string(body).unwrap())
            .unwrap()
            .into()
    }

    macro_rules! assert_error_text_contains {
        ($error: expr, $expect_contains: expr) => {
            let error = &$error;
            assert!(
                error.contains($expect_contains),
                "Expected error message to contain '{}'\ngot '{error:?}'",
                $expect_contains,
            );
        };
    }

    #[tokio::test]
    async fn test_epoch_settings_ok_200() {
        let (server, client) = setup_server_and_client();
        let epoch_settings_expected = EpochSettingsMessage::dummy();
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.status(200)
                .body(json!(epoch_settings_expected).to_string());
        });

        let epoch_settings = client
            .retrieve_epoch_settings()
            .await
            .expect("unexpected error");

        assert_eq!(
            Some(FromEpochSettingsAdapter::try_adapt(epoch_settings_expected).unwrap()),
            epoch_settings
        );
    }

    #[tokio::test]
    async fn test_epoch_settings_ko_412() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.status(412)
                .header(MITHRIL_API_VERSION_HEADER, "0.0.999");
        });

        let error = client.retrieve_epoch_settings().await.unwrap_err();

        assert!(error.is_api_version_mismatch());
    }

    #[tokio::test]
    async fn test_epoch_settings_ko_500() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.status(500).body("an error occurred");
        });

        match client.retrieve_epoch_settings().await.unwrap_err() {
            AggregatorClientError::RemoteServerTechnical(_) => (),
            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
        };
    }

    #[tokio::test]
    async fn test_epoch_settings_timeout() {
        let (server, mut client) = setup_server_and_client();
        client.timeout_duration = Some(Duration::from_millis(10));
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.delay(Duration::from_millis(100));
        });

        let error = client
            .retrieve_epoch_settings()
            .await
            .expect_err("retrieve_epoch_settings should fail");

        assert!(
            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
            "unexpected error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn test_4xx_errors_are_handled_as_remote_server_logical() {
        let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::RemoteServerLogical(..)
            ),
            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_5xx_errors_are_handled_as_remote_server_technical() {
        let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::RemoteServerTechnical(..)
            ),
            "Expected error to be RemoteServerTechnical\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text(
    ) {
        let response = build_text_response(StatusCode::OK, "ok text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
            ),
            "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
        let error_text = "An error occurred; please try again later.";
        let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "expectation failed: An error occurred; please try again later."
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message(
    ) {
        let client_error = ClientError::new("label", "message");
        let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "bad request: label: message"
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_json_formatted_server_error_response_contains_error_message() {
        let server_error = ServerError::new("message");
        let response = build_json_response(StatusCode::BAD_REQUEST, &server_error);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "bad request: message"
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
        let response = build_json_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            &json!({ "second": "unknown", "first": "foreign" }),
        );

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            r#"internal server error: {"first":"foreign","second":"unknown"}"#
        );
    }

    #[tokio::test]
    async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
        let response = HttpResponseBuilder::new()
            .status(StatusCode::BAD_REQUEST)
            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
            .body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
            .unwrap()
            .into();

        let root_cause = AggregatorClientError::get_root_cause(response).await;

        assert_error_text_contains!(root_cause, "bad request");
        assert!(
            !root_cause.contains("bad request: "),
            "Expected error message should not contain additional information \ngot '{root_cause:?}'"
        );
    }
}