mithril_aggregator/services/aggregator_client.rs

1use anyhow::{Context, anyhow};
2use async_trait::async_trait;
3use reqwest::header::{self, HeaderValue};
4use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode, Url};
5
6use semver::Version;
7use slog::{Logger, debug, error, warn};
8use std::{io, sync::Arc, time::Duration};
9use thiserror::Error;
10
11use mithril_common::{
12    MITHRIL_AGGREGATOR_VERSION_HEADER, MITHRIL_API_VERSION_HEADER, StdError, StdResult,
13    api_version::APIVersionProvider,
14    certificate_chain::{CertificateRetriever, CertificateRetrieverError},
15    entities::{Certificate, ClientError, ServerError},
16    logging::LoggerExtensions,
17    messages::{
18        CertificateListMessage, CertificateMessage, EpochSettingsMessage, TryFromMessageAdapter,
19    },
20};
21
22use crate::entities::LeaderAggregatorEpochSettings;
23use crate::message_adapters::FromEpochSettingsAdapter;
24use crate::services::{LeaderAggregatorClient, RemoteCertificateRetriever};
25
/// Media type used to recognize JSON-encoded bodies: compared against the `Content-Type`
/// header of error responses before trying to decode them as JSON.
const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");

/// Warning logged when the leader aggregator advertises a more recent API version than the
/// one computed locally.
const API_VERSION_MISMATCH_WARNING_MESSAGE: &str =
    "OpenAPI version may be incompatible, please update your Mithril node to the latest version.";
30
/// Error structure for the Aggregator Client.
///
/// Error responses are classified by [`AggregatorClientError::from_response`]: 4xx statuses
/// give [`AggregatorClientError::RemoteServerLogical`], 5xx statuses give
/// [`AggregatorClientError::RemoteServerTechnical`] and any other status gives
/// [`AggregatorClientError::UnhandledStatusCode`].
#[derive(Error, Debug)]
pub enum AggregatorClientError {
    /// The aggregator host has returned a technical error (5xx status code).
    #[error("remote server technical error")]
    RemoteServerTechnical(#[source] StdError),

    /// The aggregator host responded it cannot fulfill our request (4xx status code).
    #[error("remote server logical error")]
    RemoteServerLogical(#[source] StdError),

    /// Could not reach aggregator: the request failed before a response was received
    /// (this includes request timeouts).
    #[error("remote server unreachable")]
    RemoteServerUnreachable(#[source] StdError),

    /// Unhandled status code, with the text of the response body.
    #[error("unhandled status code: {0}, response text: {1}")]
    UnhandledStatusCode(StatusCode, String),

    /// Could not parse response.
    #[error("json parsing failed")]
    JsonParseFailed(#[source] StdError),

    /// Mostly network errors, built from an [`io::Error`] through `From`.
    #[error("Input/Output error")]
    IOError(#[from] io::Error),

    /// HTTP client creation error; also used when a route cannot be joined to the
    /// aggregator url.
    #[error("HTTP client creation failed")]
    HTTPClientCreation(#[source] StdError),

    /// Proxy creation error (invalid relay endpoint).
    #[error("proxy creation failed")]
    ProxyCreation(#[source] StdError),

    /// Adapter error, raised when a received message cannot be converted to its entity.
    #[error("adapter failed")]
    Adapter(#[source] StdError),
}
70
71impl AggregatorClientError {
72    /// Create an `AggregatorClientError` from a response.
73    ///
74    /// This method is meant to be used after handling domain-specific cases leaving only
75    /// 4xx or 5xx status codes.
76    /// Otherwise, it will return an `UnhandledStatusCode` error.
77    pub async fn from_response(response: Response) -> Self {
78        let error_code = response.status();
79
80        if error_code.is_client_error() {
81            let root_cause = Self::get_root_cause(response).await;
82            Self::RemoteServerLogical(anyhow!(root_cause))
83        } else if error_code.is_server_error() {
84            let root_cause = Self::get_root_cause(response).await;
85            Self::RemoteServerTechnical(anyhow!(root_cause))
86        } else {
87            let response_text = response.text().await.unwrap_or_default();
88            Self::UnhandledStatusCode(error_code, response_text)
89        }
90    }
91
92    async fn get_root_cause(response: Response) -> String {
93        let error_code = response.status();
94        let canonical_reason = error_code.canonical_reason().unwrap_or_default().to_lowercase();
95        let is_json = response
96            .headers()
97            .get(header::CONTENT_TYPE)
98            .is_some_and(|ct| JSON_CONTENT_TYPE == ct);
99
100        if is_json {
101            let json_value: serde_json::Value = response.json().await.unwrap_or_default();
102
103            if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
104                format!(
105                    "{}: {}: {}",
106                    canonical_reason, client_error.label, client_error.message
107                )
108            } else if let Ok(server_error) =
109                serde_json::from_value::<ServerError>(json_value.clone())
110            {
111                format!("{}: {}", canonical_reason, server_error.message)
112            } else if json_value.is_null() {
113                canonical_reason.to_string()
114            } else {
115                format!("{canonical_reason}: {json_value}")
116            }
117        } else {
118            let response_text = response.text().await.unwrap_or_default();
119            format!("{canonical_reason}: {response_text}")
120        }
121    }
122}
123
/// AggregatorHTTPClient is a http client for an aggregator
pub struct AggregatorHTTPClient {
    /// Base url of the aggregator, normalized by [`AggregatorHTTPClient::new`] to end with a
    /// trailing slash so that routes can be joined to it.
    aggregator_endpoint: Url,
    /// Optional relay used as a proxy for every request (see `Proxy::all`).
    relay_endpoint: Option<String>,
    /// Provides the API version sent in request headers and compared with the leader's one.
    api_version_provider: Arc<APIVersionProvider>,
    /// Optional per-request timeout; when `None` no timeout is set on the request.
    timeout_duration: Option<Duration>,
    /// Logger scoped to this component.
    logger: Logger,
}
132
133impl AggregatorHTTPClient {
134    /// AggregatorHTTPClient factory
135    pub fn new(
136        aggregator_endpoint: Url,
137        relay_endpoint: Option<String>,
138        api_version_provider: Arc<APIVersionProvider>,
139        timeout_duration: Option<Duration>,
140        logger: Logger,
141    ) -> Self {
142        let logger = logger.new_with_component_name::<Self>();
143        debug!(logger, "New AggregatorHTTPClient created");
144
145        // Trailing slash is significant because url::join
146        // (https://docs.rs/url/latest/url/struct.Url.html#method.join) will remove
147        // the 'path' part of the url if it doesn't end with a trailing slash.
148        let aggregator_endpoint = if aggregator_endpoint.as_str().ends_with('/') {
149            aggregator_endpoint
150        } else {
151            let mut url = aggregator_endpoint.clone();
152            url.set_path(&format!("{}/", aggregator_endpoint.path()));
153            url
154        };
155
156        Self {
157            aggregator_endpoint,
158            relay_endpoint,
159            api_version_provider,
160            timeout_duration,
161            logger,
162        }
163    }
164
165    fn join_aggregator_endpoint(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
166        self.aggregator_endpoint
167            .join(endpoint)
168            .with_context(|| {
169                format!(
170                    "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
171                    self.aggregator_endpoint
172                )
173            })
174            .map_err(AggregatorClientError::HTTPClientCreation)
175    }
176
177    fn prepare_http_client(&self) -> Result<Client, AggregatorClientError> {
178        let client = match &self.relay_endpoint {
179            Some(relay_endpoint) => Client::builder()
180                .proxy(
181                    Proxy::all(relay_endpoint)
182                        .map_err(|e| AggregatorClientError::ProxyCreation(anyhow!(e)))?,
183                )
184                .build()
185                .map_err(|e| AggregatorClientError::HTTPClientCreation(anyhow!(e)))?,
186            None => Client::new(),
187        };
188
189        Ok(client)
190    }
191
192    /// Forge a client request adding protocol version in the headers.
193    pub fn prepare_request_builder(&self, request_builder: RequestBuilder) -> RequestBuilder {
194        let request_builder = request_builder
195            .header(
196                MITHRIL_API_VERSION_HEADER,
197                self.api_version_provider
198                    .compute_current_version()
199                    .unwrap()
200                    .to_string(),
201            )
202            .header(MITHRIL_AGGREGATOR_VERSION_HEADER, env!("CARGO_PKG_VERSION"));
203
204        if let Some(duration) = self.timeout_duration {
205            request_builder.timeout(duration)
206        } else {
207            request_builder
208        }
209    }
210
211    /// Check API version mismatch and log a warning if the leader aggregator's version is more recent.
212    fn warn_if_api_version_mismatch(&self, response: &Response) {
213        let leader_version = response
214            .headers()
215            .get(MITHRIL_API_VERSION_HEADER)
216            .and_then(|v| v.to_str().ok())
217            .and_then(|s| Version::parse(s).ok());
218
219        let follower_version = self.api_version_provider.compute_current_version();
220
221        match (leader_version, follower_version) {
222            (Some(leader), Ok(follower)) if follower < leader => {
223                warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
224                    "leader_aggregator_version" => %leader,
225                    "aggregator_version" => %follower,
226                );
227            }
228            (Some(_), Err(error)) => {
229                error!(
230                    self.logger,
231                    "Failed to compute the current aggregator API version";
232                    "error" => error.to_string()
233                );
234            }
235            _ => {}
236        }
237    }
238}
239
// Route-specific methods
241impl AggregatorHTTPClient {
242    async fn epoch_settings(
243        &self,
244    ) -> Result<Option<LeaderAggregatorEpochSettings>, AggregatorClientError> {
245        debug!(self.logger, "Retrieve epoch settings");
246        let url = self.join_aggregator_endpoint("epoch-settings")?;
247        let response = self
248            .prepare_request_builder(self.prepare_http_client()?.get(url))
249            .send()
250            .await;
251
252        match response {
253            Ok(response) => match response.status() {
254                StatusCode::OK => {
255                    self.warn_if_api_version_mismatch(&response);
256                    match response.json::<EpochSettingsMessage>().await {
257                        Ok(message) => {
258                            let epoch_settings = FromEpochSettingsAdapter::try_adapt(message)
259                                .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
260                            Ok(Some(epoch_settings))
261                        }
262                        Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
263                    }
264                }
265                _ => Err(AggregatorClientError::from_response(response).await),
266            },
267            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
268        }
269    }
270
271    async fn latest_certificates_list(
272        &self,
273    ) -> Result<CertificateListMessage, AggregatorClientError> {
274        debug!(self.logger, "Retrieve latest certificates list");
275        let url = self.join_aggregator_endpoint("certificates")?;
276        let response = self
277            .prepare_request_builder(self.prepare_http_client()?.get(url))
278            .send()
279            .await;
280
281        match response {
282            Ok(response) => match response.status() {
283                StatusCode::OK => {
284                    self.warn_if_api_version_mismatch(&response);
285                    match response.json::<CertificateListMessage>().await {
286                        Ok(message) => Ok(message),
287                        Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
288                    }
289                }
290                _ => Err(AggregatorClientError::from_response(response).await),
291            },
292            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
293        }
294    }
295
296    async fn certificate_details(
297        &self,
298        certificate_hash: &str,
299    ) -> Result<Option<CertificateMessage>, AggregatorClientError> {
300        debug!(self.logger, "Retrieve certificate details"; "certificate_hash" => %certificate_hash);
301        let url = self.join_aggregator_endpoint(&format!("certificate/{certificate_hash}"))?;
302        let response = self
303            .prepare_request_builder(self.prepare_http_client()?.get(url))
304            .send()
305            .await;
306
307        match response {
308            Ok(response) => match response.status() {
309                StatusCode::OK => {
310                    self.warn_if_api_version_mismatch(&response);
311                    match response.json::<CertificateMessage>().await {
312                        Ok(message) => Ok(Some(message)),
313                        Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
314                    }
315                }
316                StatusCode::NOT_FOUND => Ok(None),
317                _ => Err(AggregatorClientError::from_response(response).await),
318            },
319            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
320        }
321    }
322
323    async fn latest_genesis_certificate(
324        &self,
325    ) -> Result<Option<CertificateMessage>, AggregatorClientError> {
326        self.certificate_details("genesis").await
327    }
328}
329
330#[async_trait]
331impl LeaderAggregatorClient for AggregatorHTTPClient {
332    async fn retrieve_epoch_settings(&self) -> StdResult<Option<LeaderAggregatorEpochSettings>> {
333        let epoch_settings = self.epoch_settings().await?;
334        Ok(epoch_settings)
335    }
336}
337
338#[async_trait]
339impl CertificateRetriever for AggregatorHTTPClient {
340    async fn get_certificate_details(
341        &self,
342        certificate_hash: &str,
343    ) -> Result<Certificate, CertificateRetrieverError> {
344        let message = self
345            .certificate_details(certificate_hash)
346            .await
347            .with_context(|| {
348                format!("Failed to retrieve certificate with hash: '{certificate_hash}'")
349            })
350            .map_err(CertificateRetrieverError)?
351            .ok_or(CertificateRetrieverError(anyhow!(
352                "Certificate does not exist: '{certificate_hash}'"
353            )))?;
354
355        message.try_into().map_err(CertificateRetrieverError)
356    }
357}
358
359#[async_trait]
360impl RemoteCertificateRetriever for AggregatorHTTPClient {
361    async fn get_latest_certificate_details(&self) -> StdResult<Option<Certificate>> {
362        let latest_certificates_list = self.latest_certificates_list().await?;
363
364        match latest_certificates_list.first() {
365            None => Ok(None),
366            Some(latest_certificate_list_item) => {
367                let latest_certificate_message =
368                    self.certificate_details(&latest_certificate_list_item.hash).await?;
369                latest_certificate_message.map(TryInto::try_into).transpose()
370            }
371        }
372    }
373
374    async fn get_genesis_certificate_details(&self) -> StdResult<Option<Certificate>> {
375        match self.latest_genesis_certificate().await? {
376            Some(message) => Ok(Some(message.try_into()?)),
377            None => Ok(None),
378        }
379    }
380}
381
#[cfg(test)]
pub(crate) mod dumb {
    use tokio::sync::RwLock;

    use mithril_common::test::double::Dummy;

    use super::*;

    /// Test double of a [`LeaderAggregatorClient`] that never communicates with an aggregator
    /// host.
    ///
    /// [`LeaderAggregatorClient::retrieve_epoch_settings`] returns the epoch settings it holds,
    /// which are [`LeaderAggregatorEpochSettings::dummy`] when built with `Default`.
    pub struct DumbAggregatorClient {
        epoch_settings: RwLock<Option<LeaderAggregatorEpochSettings>>,
    }

    impl Default for DumbAggregatorClient {
        fn default() -> Self {
            let dummy_epoch_settings = LeaderAggregatorEpochSettings::dummy();

            Self {
                epoch_settings: RwLock::new(Some(dummy_epoch_settings)),
            }
        }
    }

    #[async_trait]
    impl LeaderAggregatorClient for DumbAggregatorClient {
        async fn retrieve_epoch_settings(
            &self,
        ) -> StdResult<Option<LeaderAggregatorEpochSettings>> {
            Ok(self.epoch_settings.read().await.clone())
        }
    }
}
416
417#[cfg(test)]
418mod tests {
419    use http::response::Builder as HttpResponseBuilder;
420    use httpmock::prelude::*;
421    use reqwest::IntoUrl;
422    use serde_json::json;
423
424    use mithril_common::messages::CertificateListItemMessage;
425    use mithril_common::test::double::{Dummy, DummyApiVersionDiscriminantSource};
426
427    use crate::test::TestLogger;
428
429    use super::*;
430
431    fn setup_client<U: IntoUrl>(server_url: U) -> AggregatorHTTPClient {
432        let discriminant_source = DummyApiVersionDiscriminantSource::default();
433        let api_version_provider = APIVersionProvider::new(Arc::new(discriminant_source));
434
435        AggregatorHTTPClient::new(
436            server_url.into_url().unwrap(),
437            None,
438            Arc::new(api_version_provider),
439            None,
440            TestLogger::stdout(),
441        )
442    }
443
444    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
445        let server = MockServer::start();
446        let aggregator_endpoint = server.url("");
447        let client = setup_client(&aggregator_endpoint);
448
449        (server, client)
450    }
451
452    fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
453        HttpResponseBuilder::new()
454            .status(status_code)
455            .body(body.into())
456            .unwrap()
457            .into()
458    }
459
460    fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
461        HttpResponseBuilder::new()
462            .status(status_code)
463            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
464            .body(serde_json::to_string(&body).unwrap())
465            .unwrap()
466            .into()
467    }
468
469    macro_rules! assert_error_text_contains {
470        ($error: expr, $expect_contains: expr) => {
471            let error = &$error;
472            assert!(
473                error.contains($expect_contains),
474                "Expected error message to contain '{}'\ngot '{error:?}'",
475                $expect_contains,
476            );
477        };
478    }
479
480    #[tokio::test]
481    async fn test_epoch_settings_ok_200() {
482        let (server, client) = setup_server_and_client();
483        let epoch_settings_expected = EpochSettingsMessage::dummy();
484        let _server_mock = server.mock(|when, then| {
485            when.path("/epoch-settings");
486            then.status(200).body(json!(epoch_settings_expected).to_string());
487        });
488
489        let epoch_settings = client.retrieve_epoch_settings().await;
490        epoch_settings.as_ref().expect("unexpected error");
491        assert_eq!(
492            FromEpochSettingsAdapter::try_adapt(epoch_settings_expected).unwrap(),
493            epoch_settings.unwrap().unwrap()
494        );
495    }
496
497    #[tokio::test]
498    async fn test_epoch_settings_ko_500() {
499        let (server, client) = setup_server_and_client();
500        let _server_mock = server.mock(|when, then| {
501            when.path("/epoch-settings");
502            then.status(500).body("an error occurred");
503        });
504
505        match client.epoch_settings().await.unwrap_err() {
506            AggregatorClientError::RemoteServerTechnical(_) => (),
507            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
508        };
509    }
510
511    #[tokio::test]
512    async fn test_epoch_settings_timeout() {
513        let (server, mut client) = setup_server_and_client();
514        client.timeout_duration = Some(Duration::from_millis(10));
515        let _server_mock = server.mock(|when, then| {
516            when.path("/epoch-settings");
517            then.delay(Duration::from_millis(100));
518        });
519
520        let error = client
521            .epoch_settings()
522            .await
523            .expect_err("retrieve_epoch_settings should fail");
524
525        assert!(
526            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
527            "unexpected error type: {error:?}"
528        );
529    }
530
531    #[tokio::test]
532    async fn test_latest_certificates_list_ok_200() {
533        let (server, client) = setup_server_and_client();
534        let expected_list = vec![
535            CertificateListItemMessage::dummy(),
536            CertificateListItemMessage::dummy(),
537        ];
538        let _server_mock = server.mock(|when, then| {
539            when.path("/certificates");
540            then.status(200).body(json!(expected_list).to_string());
541        });
542
543        let fetched_list = client.latest_certificates_list().await.unwrap();
544
545        assert_eq!(expected_list, fetched_list);
546    }
547
548    #[tokio::test]
549    async fn test_latest_certificates_list_ko_500() {
550        let (server, client) = setup_server_and_client();
551        let _server_mock = server.mock(|when, then| {
552            when.path("/certificates");
553            then.status(500).body("an error occurred");
554        });
555
556        match client.latest_certificates_list().await.unwrap_err() {
557            AggregatorClientError::RemoteServerTechnical(_) => (),
558            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
559        };
560    }
561
562    #[tokio::test]
563    async fn test_latest_certificates_list_timeout() {
564        let (server, mut client) = setup_server_and_client();
565        client.timeout_duration = Some(Duration::from_millis(10));
566        let _server_mock = server.mock(|when, then| {
567            when.path("/certificates");
568            then.delay(Duration::from_millis(100));
569        });
570
571        let error = client
572            .latest_certificates_list()
573            .await
574            .expect_err("retrieve_epoch_settings should fail");
575
576        assert!(
577            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
578            "unexpected error type: {error:?}"
579        );
580    }
581
582    #[tokio::test]
583    async fn test_certificates_details_ok_200() {
584        let (server, client) = setup_server_and_client();
585        let expected_message = CertificateMessage::dummy();
586        let _server_mock = server.mock(|when, then| {
587            when.path(format!("/certificate/{}", expected_message.hash));
588            then.status(200).body(json!(expected_message).to_string());
589        });
590
591        let fetched_message = client.certificate_details(&expected_message.hash).await.unwrap();
592
593        assert_eq!(Some(expected_message), fetched_message);
594    }
595
596    #[tokio::test]
597    async fn test_certificates_details_ok_404() {
598        let (server, client) = setup_server_and_client();
599        let _server_mock = server.mock(|when, then| {
600            when.path("/certificate/not-found");
601            then.status(404);
602        });
603
604        let fetched_message = client.latest_genesis_certificate().await.unwrap();
605
606        assert_eq!(None, fetched_message);
607    }
608
609    #[tokio::test]
610    async fn test_certificates_details_ko_500() {
611        let (server, client) = setup_server_and_client();
612        let _server_mock = server.mock(|when, then| {
613            when.path("/certificate/whatever");
614            then.status(500).body("an error occurred");
615        });
616
617        match client.certificate_details("whatever").await.unwrap_err() {
618            AggregatorClientError::RemoteServerTechnical(_) => (),
619            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
620        };
621    }
622
623    #[tokio::test]
624    async fn test_certificates_details_timeout() {
625        let (server, mut client) = setup_server_and_client();
626        client.timeout_duration = Some(Duration::from_millis(10));
627        let _server_mock = server.mock(|when, then| {
628            when.path("/certificate/whatever");
629            then.delay(Duration::from_millis(100));
630        });
631
632        let error = client
633            .certificate_details("whatever")
634            .await
635            .expect_err("retrieve_epoch_settings should fail");
636
637        assert!(
638            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
639            "unexpected error type: {error:?}"
640        );
641    }
642
643    #[tokio::test]
644    async fn test_latest_genesis_ok_200() {
645        let (server, client) = setup_server_and_client();
646        let genesis_message = CertificateMessage::dummy();
647        let _server_mock = server.mock(|when, then| {
648            when.path("/certificate/genesis");
649            then.status(200).body(json!(genesis_message).to_string());
650        });
651
652        let fetched = client.latest_genesis_certificate().await.unwrap();
653
654        assert_eq!(Some(genesis_message), fetched);
655    }
656
657    #[tokio::test]
658    async fn test_latest_genesis_ok_404() {
659        let (server, client) = setup_server_and_client();
660        let _server_mock = server.mock(|when, then| {
661            when.path("/certificate/genesis");
662            then.status(404);
663        });
664
665        let fetched = client.latest_genesis_certificate().await.unwrap();
666
667        assert_eq!(None, fetched);
668    }
669
670    #[tokio::test]
671    async fn test_latest_genesis_ko_500() {
672        let (server, client) = setup_server_and_client();
673        let _server_mock = server.mock(|when, then| {
674            when.path("/certificate/genesis");
675            then.status(500).body("an error occurred");
676        });
677
678        let error = client.latest_genesis_certificate().await.unwrap_err();
679
680        assert!(
681            matches!(error, AggregatorClientError::RemoteServerTechnical(_)),
682            "Expected Aggregator::RemoteServerTechnical error, got {error:?}"
683        );
684    }
685
686    #[tokio::test]
687    async fn test_latest_genesis_timeout() {
688        let (server, mut client) = setup_server_and_client();
689        client.timeout_duration = Some(Duration::from_millis(10));
690        let _server_mock = server.mock(|when, then| {
691            when.path("/certificate/genesis");
692            then.delay(Duration::from_millis(100));
693        });
694
695        let error = client.latest_genesis_certificate().await.unwrap_err();
696
697        assert!(
698            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
699            "unexpected error type: {error:?}"
700        );
701    }
702
703    #[tokio::test]
704    async fn test_4xx_errors_are_handled_as_remote_server_logical() {
705        let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
706        let handled_error = AggregatorClientError::from_response(response).await;
707
708        assert!(
709            matches!(
710                handled_error,
711                AggregatorClientError::RemoteServerLogical(..)
712            ),
713            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
714        );
715    }
716
717    #[tokio::test]
718    async fn test_5xx_errors_are_handled_as_remote_server_technical() {
719        let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
720        let handled_error = AggregatorClientError::from_response(response).await;
721
722        assert!(
723            matches!(
724                handled_error,
725                AggregatorClientError::RemoteServerTechnical(..)
726            ),
727            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
728        );
729    }
730
731    #[tokio::test]
732    async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text()
733     {
734        let response = build_text_response(StatusCode::OK, "ok text");
735        let handled_error = AggregatorClientError::from_response(response).await;
736
737        assert!(
738            matches!(
739                handled_error,
740                AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
741            ),
742            "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
743        );
744    }
745
746    #[tokio::test]
747    async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
748        let error_text = "An error occurred; please try again later.";
749        let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);
750
751        assert_error_text_contains!(
752            AggregatorClientError::get_root_cause(response).await,
753            "expectation failed: An error occurred; please try again later."
754        );
755    }
756
757    #[tokio::test]
758    async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message()
759     {
760        let client_error = ClientError::new("label", "message");
761        let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);
762
763        assert_error_text_contains!(
764            AggregatorClientError::get_root_cause(response).await,
765            "bad request: label: message"
766        );
767    }
768
769    #[tokio::test]
770    async fn test_root_cause_of_json_formatted_server_error_response_contains_error_label_and_message()
771     {
772        let server_error = ServerError::new("message");
773        let response = build_json_response(StatusCode::BAD_REQUEST, &server_error);
774
775        assert_error_text_contains!(
776            AggregatorClientError::get_root_cause(response).await,
777            "bad request: message"
778        );
779    }
780
781    #[tokio::test]
782    async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
783        let response = build_json_response(
784            StatusCode::INTERNAL_SERVER_ERROR,
785            &json!({ "second": "unknown", "first": "foreign" }),
786        );
787
788        assert_error_text_contains!(
789            AggregatorClientError::get_root_cause(response).await,
790            r#"internal server error: {"first":"foreign","second":"unknown"}"#
791        );
792    }
793
794    #[tokio::test]
795    async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
796        let response = HttpResponseBuilder::new()
797            .status(StatusCode::BAD_REQUEST)
798            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
799            .body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
800            .unwrap()
801            .into();
802
803        let root_cause = AggregatorClientError::get_root_cause(response).await;
804
805        assert_error_text_contains!(root_cause, "bad request");
806        assert!(
807            !root_cause.contains("bad request: "),
808            "Expected error message should not contain additional information \ngot '{root_cause:?}'"
809        );
810    }
811
812    mod warn_if_api_version_mismatch {
813        use std::collections::HashMap;
814
815        use mithril_common::test::api_version_extensions::ApiVersionProviderTestExtension;
816        use mithril_common::test::logging::MemoryDrainForTestInspector;
817
818        use super::*;
819
820        fn version_provider_with_open_api_version<V: Into<String>>(
821            version: V,
822        ) -> APIVersionProvider {
823            let mut version_provider = version_provider_without_open_api_version();
824            let mut open_api_versions = HashMap::new();
825            open_api_versions.insert(
826                "openapi.yaml".to_string(),
827                Version::parse(&version.into()).unwrap(),
828            );
829            version_provider.update_open_api_versions(open_api_versions);
830
831            version_provider
832        }
833
834        fn version_provider_without_open_api_version() -> APIVersionProvider {
835            let mut version_provider =
836                APIVersionProvider::new(Arc::new(DummyApiVersionDiscriminantSource::default()));
837            version_provider.update_open_api_versions(HashMap::new());
838
839            version_provider
840        }
841
842        fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
843            key: K,
844            value: V,
845        ) -> Response {
846            HttpResponseBuilder::new()
847                .header(key.into(), value.into())
848                .body("whatever")
849                .unwrap()
850                .into()
851        }
852
853        fn assert_api_version_warning_logged<L: Into<String>, A: Into<String>>(
854            log_inspector: &MemoryDrainForTestInspector,
855            leader_aggregator_version: L,
856            aggregator_version: A,
857        ) {
858            assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
859            assert!(log_inspector.contains_log(&format!(
860                "leader_aggregator_version={}",
861                leader_aggregator_version.into()
862            )));
863            assert!(
864                log_inspector
865                    .contains_log(&format!("aggregator_version={}", aggregator_version.into()))
866            );
867        }
868
869        #[test]
870        fn test_logs_warning_when_leader_aggregator_api_version_is_newer() {
871            let leader_aggregator_version = "2.0.0";
872            let aggregator_version = "1.0.0";
873            let (logger, log_inspector) = TestLogger::memory();
874            let version_provider = version_provider_with_open_api_version(aggregator_version);
875            let mut client = setup_client("http://whatever");
876            client.api_version_provider = Arc::new(version_provider);
877            client.logger = logger;
878            let response = build_fake_response_with_header(
879                MITHRIL_API_VERSION_HEADER,
880                leader_aggregator_version,
881            );
882
883            assert!(
884                Version::parse(leader_aggregator_version).unwrap()
885                    > Version::parse(aggregator_version).unwrap()
886            );
887
888            client.warn_if_api_version_mismatch(&response);
889
890            assert_api_version_warning_logged(
891                &log_inspector,
892                leader_aggregator_version,
893                aggregator_version,
894            );
895        }
896
897        #[test]
898        fn test_no_warning_logged_when_versions_match() {
899            let version = "1.0.0";
900            let (logger, log_inspector) = TestLogger::memory();
901            let version_provider = version_provider_with_open_api_version(version);
902            let mut client = setup_client("http://whatever");
903            client.api_version_provider = Arc::new(version_provider);
904            client.logger = logger;
905            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);
906
907            client.warn_if_api_version_mismatch(&response);
908
909            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
910        }
911
912        #[test]
913        fn test_no_warning_logged_when_leader_aggregator_api_version_is_older() {
914            let leader_aggregator_version = "1.0.0";
915            let aggregator_version = "2.0.0";
916            let (logger, log_inspector) = TestLogger::memory();
917            let version_provider = version_provider_with_open_api_version(aggregator_version);
918            let mut client = setup_client("http://whatever");
919            client.api_version_provider = Arc::new(version_provider);
920            client.logger = logger;
921            let response = build_fake_response_with_header(
922                MITHRIL_API_VERSION_HEADER,
923                leader_aggregator_version,
924            );
925
926            assert!(
927                Version::parse(leader_aggregator_version).unwrap()
928                    < Version::parse(aggregator_version).unwrap()
929            );
930
931            client.warn_if_api_version_mismatch(&response);
932
933            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
934        }
935
936        #[test]
937        fn test_does_not_log_or_fail_when_header_is_missing() {
938            let (logger, log_inspector) = TestLogger::memory();
939            let mut client = setup_client("http://whatever");
940            client.logger = logger;
941            let response =
942                build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");
943
944            client.warn_if_api_version_mismatch(&response);
945
946            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
947        }
948
949        #[test]
950        fn test_does_not_log_or_fail_when_header_is_not_a_version() {
951            let (logger, log_inspector) = TestLogger::memory();
952            let mut client = setup_client("http://whatever");
953            client.logger = logger;
954            let response =
955                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");
956
957            client.warn_if_api_version_mismatch(&response);
958
959            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
960        }
961
962        #[test]
963        fn test_logs_error_when_aggregator_version_cannot_be_computed() {
964            let (logger, log_inspector) = TestLogger::memory();
965            let version_provider = version_provider_without_open_api_version();
966            let mut client = setup_client("http://whatever");
967            client.api_version_provider = Arc::new(version_provider);
968            client.logger = logger;
969            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");
970
971            client.warn_if_api_version_mismatch(&response);
972
973            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
974        }
975
976        #[tokio::test]
977        async fn test_epoch_settings_ok_200_log_warning_if_api_version_mismatch() {
978            let leader_aggregator_version = "2.0.0";
979            let aggregator_version = "1.0.0";
980            let (server, mut client) = setup_server_and_client();
981            let (logger, log_inspector) = TestLogger::memory();
982            let version_provider = version_provider_with_open_api_version(aggregator_version);
983            client.api_version_provider = Arc::new(version_provider);
984            client.logger = logger;
985            let epoch_settings_expected = EpochSettingsMessage::dummy();
986            let _server_mock = server.mock(|when, then| {
987                when.path("/epoch-settings");
988                then.status(200)
989                    .body(json!(epoch_settings_expected).to_string())
990                    .header(MITHRIL_API_VERSION_HEADER, leader_aggregator_version);
991            });
992
993            assert!(
994                Version::parse(leader_aggregator_version).unwrap()
995                    > Version::parse(aggregator_version).unwrap()
996            );
997
998            client.retrieve_epoch_settings().await.unwrap();
999
1000            assert_api_version_warning_logged(
1001                &log_inspector,
1002                leader_aggregator_version,
1003                aggregator_version,
1004            );
1005        }
1006    }
1007
1008    mod remote_certificate_retriever {
1009        use mithril_common::test::double::fake_data;
1010
1011        use super::*;
1012
1013        #[tokio::test]
1014        async fn test_get_latest_certificate_details() {
1015            let (server, client) = setup_server_and_client();
1016            let expected_certificate = fake_data::certificate("expected");
1017            let latest_message: CertificateMessage =
1018                expected_certificate.clone().try_into().unwrap();
1019            let latest_certificates = vec![
1020                CertificateListItemMessage {
1021                    hash: expected_certificate.hash.clone(),
1022                    ..CertificateListItemMessage::dummy()
1023                },
1024                CertificateListItemMessage::dummy(),
1025                CertificateListItemMessage::dummy(),
1026            ];
1027            let _server_mock = server.mock(|when, then| {
1028                when.path("/certificates");
1029                then.status(200).body(json!(latest_certificates).to_string());
1030            });
1031            let _server_mock = server.mock(|when, then| {
1032                when.path(format!("/certificate/{}", latest_message.hash));
1033                then.status(200).body(json!(latest_message).to_string());
1034            });
1035
1036            let fetched_certificate = client.get_latest_certificate_details().await.unwrap();
1037
1038            assert_eq!(Some(expected_certificate), fetched_certificate);
1039        }
1040
1041        #[tokio::test]
1042        async fn test_get_latest_genesis_certificate() {
1043            let (server, client) = setup_server_and_client();
1044            let genesis_message = CertificateMessage::dummy();
1045            let expected_genesis: Certificate = genesis_message.clone().try_into().unwrap();
1046            let _server_mock = server.mock(|when, then| {
1047                when.path("/certificate/genesis");
1048                then.status(200).body(json!(genesis_message).to_string());
1049            });
1050
1051            let fetched = client.get_genesis_certificate_details().await.unwrap();
1052
1053            assert_eq!(Some(expected_genesis), fetched);
1054        }
1055    }
1056}