mithril_aggregator/services/
aggregator_client.rs

1use anyhow::{Context, anyhow};
2use async_trait::async_trait;
3use reqwest::header::{self, HeaderValue};
4use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode, Url};
5
6use semver::Version;
7use slog::{Logger, debug, error, warn};
8use std::{io, sync::Arc, time::Duration};
9use thiserror::Error;
10
11use mithril_common::{
12    MITHRIL_AGGREGATOR_VERSION_HEADER, MITHRIL_API_VERSION_HEADER, StdError, StdResult,
13    api_version::APIVersionProvider,
14    certificate_chain::{CertificateRetriever, CertificateRetrieverError},
15    entities::{Certificate, ClientError, ServerError},
16    logging::LoggerExtensions,
17    messages::{
18        CertificateListMessage, CertificateMessage, EpochSettingsMessage, TryFromMessageAdapter,
19    },
20};
21
22use crate::entities::LeaderAggregatorEpochSettings;
23use crate::message_adapters::FromEpochSettingsAdapter;
24use crate::services::{LeaderAggregatorClient, RemoteCertificateRetriever};
25
26const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");
27
28const API_VERSION_MISMATCH_WARNING_MESSAGE: &str =
29    "OpenAPI version may be incompatible, please update your Mithril node to the latest version.";
30
31/// Error structure for the Aggregator Client.
32#[derive(Error, Debug)]
33pub enum AggregatorClientError {
34    /// The aggregator host has returned a technical error.
35    #[error("remote server technical error")]
36    RemoteServerTechnical(#[source] StdError),
37
38    /// The aggregator host responded it cannot fulfill our request.
39    #[error("remote server logical error")]
40    RemoteServerLogical(#[source] StdError),
41
42    /// Could not reach aggregator.
43    #[error("remote server unreachable")]
44    RemoteServerUnreachable(#[source] StdError),
45
46    /// Unhandled status code
47    #[error("unhandled status code: {0}, response text: {1}")]
48    UnhandledStatusCode(StatusCode, String),
49
50    /// Could not parse response.
51    #[error("json parsing failed")]
52    JsonParseFailed(#[source] StdError),
53
54    /// Mostly network errors.
55    #[error("Input/Output error")]
56    IOError(#[from] io::Error),
57
58    /// HTTP client creation error
59    #[error("HTTP client creation failed")]
60    HTTPClientCreation(#[source] StdError),
61
62    /// Proxy creation error
63    #[error("proxy creation failed")]
64    ProxyCreation(#[source] StdError),
65
66    /// Adapter error
67    #[error("adapter failed")]
68    Adapter(#[source] StdError),
69}
70
71impl AggregatorClientError {
72    /// Create an `AggregatorClientError` from a response.
73    ///
74    /// This method is meant to be used after handling domain-specific cases leaving only
75    /// 4xx or 5xx status codes.
76    /// Otherwise, it will return an `UnhandledStatusCode` error.
77    pub async fn from_response(response: Response) -> Self {
78        let error_code = response.status();
79
80        if error_code.is_client_error() {
81            let root_cause = Self::get_root_cause(response).await;
82            Self::RemoteServerLogical(anyhow!(root_cause))
83        } else if error_code.is_server_error() {
84            let root_cause = Self::get_root_cause(response).await;
85            Self::RemoteServerTechnical(anyhow!(root_cause))
86        } else {
87            let response_text = response.text().await.unwrap_or_default();
88            Self::UnhandledStatusCode(error_code, response_text)
89        }
90    }
91
92    async fn get_root_cause(response: Response) -> String {
93        let error_code = response.status();
94        let canonical_reason = error_code.canonical_reason().unwrap_or_default().to_lowercase();
95        let is_json = response
96            .headers()
97            .get(header::CONTENT_TYPE)
98            .is_some_and(|ct| JSON_CONTENT_TYPE == ct);
99
100        if is_json {
101            let json_value: serde_json::Value = response.json().await.unwrap_or_default();
102
103            if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
104                format!(
105                    "{}: {}: {}",
106                    canonical_reason, client_error.label, client_error.message
107                )
108            } else if let Ok(server_error) =
109                serde_json::from_value::<ServerError>(json_value.clone())
110            {
111                format!("{}: {}", canonical_reason, server_error.message)
112            } else if json_value.is_null() {
113                canonical_reason.to_string()
114            } else {
115                format!("{canonical_reason}: {json_value}")
116            }
117        } else {
118            let response_text = response.text().await.unwrap_or_default();
119            format!("{canonical_reason}: {response_text}")
120        }
121    }
122}
123
124/// AggregatorHTTPClient is a http client for an aggregator
125pub struct AggregatorHTTPClient {
126    aggregator_endpoint: Url,
127    relay_endpoint: Option<String>,
128    api_version_provider: Arc<APIVersionProvider>,
129    timeout_duration: Option<Duration>,
130    logger: Logger,
131}
132
133impl AggregatorHTTPClient {
134    /// AggregatorHTTPClient factory
135    pub fn new(
136        aggregator_endpoint: Url,
137        relay_endpoint: Option<String>,
138        api_version_provider: Arc<APIVersionProvider>,
139        timeout_duration: Option<Duration>,
140        logger: Logger,
141    ) -> Self {
142        let logger = logger.new_with_component_name::<Self>();
143        debug!(logger, "New AggregatorHTTPClient created");
144
145        // Trailing slash is significant because url::join
146        // (https://docs.rs/url/latest/url/struct.Url.html#method.join) will remove
147        // the 'path' part of the url if it doesn't end with a trailing slash.
148        let aggregator_endpoint = if aggregator_endpoint.as_str().ends_with('/') {
149            aggregator_endpoint
150        } else {
151            let mut url = aggregator_endpoint.clone();
152            url.set_path(&format!("{}/", aggregator_endpoint.path()));
153            url
154        };
155
156        Self {
157            aggregator_endpoint,
158            relay_endpoint,
159            api_version_provider,
160            timeout_duration,
161            logger,
162        }
163    }
164
165    fn join_aggregator_endpoint(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
166        self.aggregator_endpoint
167            .join(endpoint)
168            .with_context(|| {
169                format!(
170                    "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
171                    self.aggregator_endpoint
172                )
173            })
174            .map_err(AggregatorClientError::HTTPClientCreation)
175    }
176
177    fn prepare_http_client(&self) -> Result<Client, AggregatorClientError> {
178        let client = match &self.relay_endpoint {
179            Some(relay_endpoint) => Client::builder()
180                .proxy(
181                    Proxy::all(relay_endpoint)
182                        .map_err(|e| AggregatorClientError::ProxyCreation(anyhow!(e)))?,
183                )
184                .build()
185                .map_err(|e| AggregatorClientError::HTTPClientCreation(anyhow!(e)))?,
186            None => Client::new(),
187        };
188
189        Ok(client)
190    }
191
192    /// Forge a client request adding protocol version in the headers.
193    pub fn prepare_request_builder(&self, request_builder: RequestBuilder) -> RequestBuilder {
194        let request_builder = request_builder
195            .header(
196                MITHRIL_API_VERSION_HEADER,
197                self.api_version_provider
198                    .compute_current_version()
199                    .unwrap()
200                    .to_string(),
201            )
202            .header(MITHRIL_AGGREGATOR_VERSION_HEADER, env!("CARGO_PKG_VERSION"));
203
204        if let Some(duration) = self.timeout_duration {
205            request_builder.timeout(duration)
206        } else {
207            request_builder
208        }
209    }
210
211    /// Check API version mismatch and log a warning if the leader aggregator's version is more recent.
212    fn warn_if_api_version_mismatch(&self, response: &Response) {
213        let leader_version = response
214            .headers()
215            .get(MITHRIL_API_VERSION_HEADER)
216            .and_then(|v| v.to_str().ok())
217            .and_then(|s| Version::parse(s).ok());
218
219        let follower_version = self.api_version_provider.compute_current_version();
220
221        match (leader_version, follower_version) {
222            (Some(leader), Ok(follower)) if follower < leader => {
223                warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
224                    "leader_aggregator_version" => %leader,
225                    "aggregator_version" => %follower,
226                );
227            }
228            (Some(_), Err(error)) => {
229                error!(
230                    self.logger,
231                    "Failed to compute the current aggregator API version";
232                    "error" => error.to_string()
233                );
234            }
235            _ => {}
236        }
237    }
238}
239
240// Route specifics methods
241impl AggregatorHTTPClient {
242    async fn epoch_settings(
243        &self,
244    ) -> Result<Option<LeaderAggregatorEpochSettings>, AggregatorClientError> {
245        debug!(self.logger, "Retrieve epoch settings");
246        let url = self.join_aggregator_endpoint("epoch-settings")?;
247        let response = self
248            .prepare_request_builder(self.prepare_http_client()?.get(url))
249            .send()
250            .await;
251
252        match response {
253            Ok(response) => match response.status() {
254                StatusCode::OK => {
255                    self.warn_if_api_version_mismatch(&response);
256                    match response.json::<EpochSettingsMessage>().await {
257                        Ok(message) => {
258                            let epoch_settings = FromEpochSettingsAdapter::try_adapt(message)
259                                .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
260                            Ok(Some(epoch_settings))
261                        }
262                        Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
263                    }
264                }
265                _ => Err(AggregatorClientError::from_response(response).await),
266            },
267            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
268        }
269    }
270
271    async fn latest_certificates_list(
272        &self,
273    ) -> Result<CertificateListMessage, AggregatorClientError> {
274        debug!(self.logger, "Retrieve latest certificates list");
275        let url = self.join_aggregator_endpoint("certificates")?;
276        let response = self
277            .prepare_request_builder(self.prepare_http_client()?.get(url))
278            .send()
279            .await;
280
281        match response {
282            Ok(response) => match response.status() {
283                StatusCode::OK => {
284                    self.warn_if_api_version_mismatch(&response);
285                    match response.json::<CertificateListMessage>().await {
286                        Ok(message) => Ok(message),
287                        Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
288                    }
289                }
290                _ => Err(AggregatorClientError::from_response(response).await),
291            },
292            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
293        }
294    }
295
296    async fn certificate_details(
297        &self,
298        certificate_hash: &str,
299    ) -> Result<Option<CertificateMessage>, AggregatorClientError> {
300        debug!(self.logger, "Retrieve certificate details"; "certificate_hash" => %certificate_hash);
301        let url = self.join_aggregator_endpoint(&format!("certificate/{certificate_hash}"))?;
302        let response = self
303            .prepare_request_builder(self.prepare_http_client()?.get(url))
304            .send()
305            .await;
306
307        match response {
308            Ok(response) => match response.status() {
309                StatusCode::OK => {
310                    self.warn_if_api_version_mismatch(&response);
311                    match response.json::<CertificateMessage>().await {
312                        Ok(message) => Ok(Some(message)),
313                        Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
314                    }
315                }
316                StatusCode::NOT_FOUND => Ok(None),
317                _ => Err(AggregatorClientError::from_response(response).await),
318            },
319            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
320        }
321    }
322
323    async fn latest_genesis_certificate(
324        &self,
325    ) -> Result<Option<CertificateMessage>, AggregatorClientError> {
326        self.certificate_details("genesis").await
327    }
328}
329
330#[async_trait]
331impl LeaderAggregatorClient for AggregatorHTTPClient {
332    async fn retrieve_epoch_settings(&self) -> StdResult<Option<LeaderAggregatorEpochSettings>> {
333        let epoch_settings = self.epoch_settings().await?;
334        Ok(epoch_settings)
335    }
336}
337
338#[async_trait]
339impl CertificateRetriever for AggregatorHTTPClient {
340    async fn get_certificate_details(
341        &self,
342        certificate_hash: &str,
343    ) -> Result<Certificate, CertificateRetrieverError> {
344        let message = self
345            .certificate_details(certificate_hash)
346            .await
347            .with_context(|| {
348                format!("Failed to retrieve certificate with hash: '{certificate_hash}'")
349            })
350            .map_err(CertificateRetrieverError)?
351            .ok_or(CertificateRetrieverError(anyhow!(
352                "Certificate does not exist: '{certificate_hash}'"
353            )))?;
354
355        message.try_into().map_err(CertificateRetrieverError)
356    }
357}
358
359#[async_trait]
360impl RemoteCertificateRetriever for AggregatorHTTPClient {
361    async fn get_latest_certificate_details(&self) -> StdResult<Option<Certificate>> {
362        let latest_certificates_list = self.latest_certificates_list().await?;
363
364        match latest_certificates_list.first() {
365            None => Ok(None),
366            Some(latest_certificate_list_item) => {
367                let latest_certificate_message =
368                    self.certificate_details(&latest_certificate_list_item.hash).await?;
369                latest_certificate_message.map(TryInto::try_into).transpose()
370            }
371        }
372    }
373
374    async fn get_genesis_certificate_details(&self) -> StdResult<Option<Certificate>> {
375        match self.latest_genesis_certificate().await? {
376            Some(message) => Ok(Some(message.try_into()?)),
377            None => Ok(None),
378        }
379    }
380}
381
382#[cfg(test)]
383mod tests {
384    use http::response::Builder as HttpResponseBuilder;
385    use httpmock::prelude::*;
386    use reqwest::IntoUrl;
387    use serde_json::json;
388
389    use mithril_common::messages::CertificateListItemMessage;
390    use mithril_common::test::double::{Dummy, DummyApiVersionDiscriminantSource};
391
392    use crate::test::TestLogger;
393
394    use super::*;
395
396    fn setup_client<U: IntoUrl>(server_url: U) -> AggregatorHTTPClient {
397        let discriminant_source = DummyApiVersionDiscriminantSource::default();
398        let api_version_provider = APIVersionProvider::new(Arc::new(discriminant_source));
399
400        AggregatorHTTPClient::new(
401            server_url.into_url().unwrap(),
402            None,
403            Arc::new(api_version_provider),
404            None,
405            TestLogger::stdout(),
406        )
407    }
408
409    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
410        let server = MockServer::start();
411        let aggregator_endpoint = server.url("");
412        let client = setup_client(&aggregator_endpoint);
413
414        (server, client)
415    }
416
417    fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
418        HttpResponseBuilder::new()
419            .status(status_code)
420            .body(body.into())
421            .unwrap()
422            .into()
423    }
424
425    fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
426        HttpResponseBuilder::new()
427            .status(status_code)
428            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
429            .body(serde_json::to_string(&body).unwrap())
430            .unwrap()
431            .into()
432    }
433
434    macro_rules! assert_error_text_contains {
435        ($error: expr, $expect_contains: expr) => {
436            let error = &$error;
437            assert!(
438                error.contains($expect_contains),
439                "Expected error message to contain '{}'\ngot '{error:?}'",
440                $expect_contains,
441            );
442        };
443    }
444
445    #[tokio::test]
446    async fn test_epoch_settings_ok_200() {
447        let (server, client) = setup_server_and_client();
448        let epoch_settings_expected = EpochSettingsMessage::dummy();
449        let _server_mock = server.mock(|when, then| {
450            when.path("/epoch-settings");
451            then.status(200).body(json!(epoch_settings_expected).to_string());
452        });
453
454        let epoch_settings = client.retrieve_epoch_settings().await;
455        epoch_settings.as_ref().expect("unexpected error");
456        assert_eq!(
457            FromEpochSettingsAdapter::try_adapt(epoch_settings_expected).unwrap(),
458            epoch_settings.unwrap().unwrap()
459        );
460    }
461
462    #[tokio::test]
463    async fn test_epoch_settings_ko_500() {
464        let (server, client) = setup_server_and_client();
465        let _server_mock = server.mock(|when, then| {
466            when.path("/epoch-settings");
467            then.status(500).body("an error occurred");
468        });
469
470        match client.epoch_settings().await.unwrap_err() {
471            AggregatorClientError::RemoteServerTechnical(_) => (),
472            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
473        };
474    }
475
476    #[tokio::test]
477    async fn test_epoch_settings_timeout() {
478        let (server, mut client) = setup_server_and_client();
479        client.timeout_duration = Some(Duration::from_millis(10));
480        let _server_mock = server.mock(|when, then| {
481            when.path("/epoch-settings");
482            then.delay(Duration::from_millis(100));
483        });
484
485        let error = client
486            .epoch_settings()
487            .await
488            .expect_err("retrieve_epoch_settings should fail");
489
490        assert!(
491            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
492            "unexpected error type: {error:?}"
493        );
494    }
495
496    #[tokio::test]
497    async fn test_latest_certificates_list_ok_200() {
498        let (server, client) = setup_server_and_client();
499        let expected_list = vec![
500            CertificateListItemMessage::dummy(),
501            CertificateListItemMessage::dummy(),
502        ];
503        let _server_mock = server.mock(|when, then| {
504            when.path("/certificates");
505            then.status(200).body(json!(expected_list).to_string());
506        });
507
508        let fetched_list = client.latest_certificates_list().await.unwrap();
509
510        assert_eq!(expected_list, fetched_list);
511    }
512
513    #[tokio::test]
514    async fn test_latest_certificates_list_ko_500() {
515        let (server, client) = setup_server_and_client();
516        let _server_mock = server.mock(|when, then| {
517            when.path("/certificates");
518            then.status(500).body("an error occurred");
519        });
520
521        match client.latest_certificates_list().await.unwrap_err() {
522            AggregatorClientError::RemoteServerTechnical(_) => (),
523            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
524        };
525    }
526
527    #[tokio::test]
528    async fn test_latest_certificates_list_timeout() {
529        let (server, mut client) = setup_server_and_client();
530        client.timeout_duration = Some(Duration::from_millis(10));
531        let _server_mock = server.mock(|when, then| {
532            when.path("/certificates");
533            then.delay(Duration::from_millis(100));
534        });
535
536        let error = client
537            .latest_certificates_list()
538            .await
539            .expect_err("retrieve_epoch_settings should fail");
540
541        assert!(
542            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
543            "unexpected error type: {error:?}"
544        );
545    }
546
547    #[tokio::test]
548    async fn test_certificates_details_ok_200() {
549        let (server, client) = setup_server_and_client();
550        let expected_message = CertificateMessage::dummy();
551        let _server_mock = server.mock(|when, then| {
552            when.path(format!("/certificate/{}", expected_message.hash));
553            then.status(200).body(json!(expected_message).to_string());
554        });
555
556        let fetched_message = client.certificate_details(&expected_message.hash).await.unwrap();
557
558        assert_eq!(Some(expected_message), fetched_message);
559    }
560
561    #[tokio::test]
562    async fn test_certificates_details_ok_404() {
563        let (server, client) = setup_server_and_client();
564        let _server_mock = server.mock(|when, then| {
565            when.path("/certificate/not-found");
566            then.status(404);
567        });
568
569        let fetched_message = client.latest_genesis_certificate().await.unwrap();
570
571        assert_eq!(None, fetched_message);
572    }
573
574    #[tokio::test]
575    async fn test_certificates_details_ko_500() {
576        let (server, client) = setup_server_and_client();
577        let _server_mock = server.mock(|when, then| {
578            when.path("/certificate/whatever");
579            then.status(500).body("an error occurred");
580        });
581
582        match client.certificate_details("whatever").await.unwrap_err() {
583            AggregatorClientError::RemoteServerTechnical(_) => (),
584            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
585        };
586    }
587
588    #[tokio::test]
589    async fn test_certificates_details_timeout() {
590        let (server, mut client) = setup_server_and_client();
591        client.timeout_duration = Some(Duration::from_millis(10));
592        let _server_mock = server.mock(|when, then| {
593            when.path("/certificate/whatever");
594            then.delay(Duration::from_millis(100));
595        });
596
597        let error = client
598            .certificate_details("whatever")
599            .await
600            .expect_err("retrieve_epoch_settings should fail");
601
602        assert!(
603            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
604            "unexpected error type: {error:?}"
605        );
606    }
607
608    #[tokio::test]
609    async fn test_latest_genesis_ok_200() {
610        let (server, client) = setup_server_and_client();
611        let genesis_message = CertificateMessage::dummy();
612        let _server_mock = server.mock(|when, then| {
613            when.path("/certificate/genesis");
614            then.status(200).body(json!(genesis_message).to_string());
615        });
616
617        let fetched = client.latest_genesis_certificate().await.unwrap();
618
619        assert_eq!(Some(genesis_message), fetched);
620    }
621
622    #[tokio::test]
623    async fn test_latest_genesis_ok_404() {
624        let (server, client) = setup_server_and_client();
625        let _server_mock = server.mock(|when, then| {
626            when.path("/certificate/genesis");
627            then.status(404);
628        });
629
630        let fetched = client.latest_genesis_certificate().await.unwrap();
631
632        assert_eq!(None, fetched);
633    }
634
635    #[tokio::test]
636    async fn test_latest_genesis_ko_500() {
637        let (server, client) = setup_server_and_client();
638        let _server_mock = server.mock(|when, then| {
639            when.path("/certificate/genesis");
640            then.status(500).body("an error occurred");
641        });
642
643        let error = client.latest_genesis_certificate().await.unwrap_err();
644
645        assert!(
646            matches!(error, AggregatorClientError::RemoteServerTechnical(_)),
647            "Expected Aggregator::RemoteServerTechnical error, got {error:?}"
648        );
649    }
650
651    #[tokio::test]
652    async fn test_latest_genesis_timeout() {
653        let (server, mut client) = setup_server_and_client();
654        client.timeout_duration = Some(Duration::from_millis(10));
655        let _server_mock = server.mock(|when, then| {
656            when.path("/certificate/genesis");
657            then.delay(Duration::from_millis(100));
658        });
659
660        let error = client.latest_genesis_certificate().await.unwrap_err();
661
662        assert!(
663            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
664            "unexpected error type: {error:?}"
665        );
666    }
667
668    #[tokio::test]
669    async fn test_4xx_errors_are_handled_as_remote_server_logical() {
670        let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
671        let handled_error = AggregatorClientError::from_response(response).await;
672
673        assert!(
674            matches!(
675                handled_error,
676                AggregatorClientError::RemoteServerLogical(..)
677            ),
678            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
679        );
680    }
681
682    #[tokio::test]
683    async fn test_5xx_errors_are_handled_as_remote_server_technical() {
684        let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
685        let handled_error = AggregatorClientError::from_response(response).await;
686
687        assert!(
688            matches!(
689                handled_error,
690                AggregatorClientError::RemoteServerTechnical(..)
691            ),
692            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
693        );
694    }
695
696    #[tokio::test]
697    async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text()
698     {
699        let response = build_text_response(StatusCode::OK, "ok text");
700        let handled_error = AggregatorClientError::from_response(response).await;
701
702        assert!(
703            matches!(
704                handled_error,
705                AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
706            ),
707            "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
708        );
709    }
710
711    #[tokio::test]
712    async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
713        let error_text = "An error occurred; please try again later.";
714        let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);
715
716        assert_error_text_contains!(
717            AggregatorClientError::get_root_cause(response).await,
718            "expectation failed: An error occurred; please try again later."
719        );
720    }
721
722    #[tokio::test]
723    async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message()
724     {
725        let client_error = ClientError::new("label", "message");
726        let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);
727
728        assert_error_text_contains!(
729            AggregatorClientError::get_root_cause(response).await,
730            "bad request: label: message"
731        );
732    }
733
734    #[tokio::test]
735    async fn test_root_cause_of_json_formatted_server_error_response_contains_error_label_and_message()
736     {
737        let server_error = ServerError::new("message");
738        let response = build_json_response(StatusCode::BAD_REQUEST, &server_error);
739
740        assert_error_text_contains!(
741            AggregatorClientError::get_root_cause(response).await,
742            "bad request: message"
743        );
744    }
745
746    #[tokio::test]
747    async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
748        let response = build_json_response(
749            StatusCode::INTERNAL_SERVER_ERROR,
750            &json!({ "second": "unknown", "first": "foreign" }),
751        );
752
753        assert_error_text_contains!(
754            AggregatorClientError::get_root_cause(response).await,
755            r#"internal server error: {"first":"foreign","second":"unknown"}"#
756        );
757    }
758
759    #[tokio::test]
760    async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
761        let response = HttpResponseBuilder::new()
762            .status(StatusCode::BAD_REQUEST)
763            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
764            .body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
765            .unwrap()
766            .into();
767
768        let root_cause = AggregatorClientError::get_root_cause(response).await;
769
770        assert_error_text_contains!(root_cause, "bad request");
771        assert!(
772            !root_cause.contains("bad request: "),
773            "Expected error message should not contain additional information \ngot '{root_cause:?}'"
774        );
775    }
776
777    mod warn_if_api_version_mismatch {
778        use std::collections::HashMap;
779
780        use mithril_common::test::api_version_extensions::ApiVersionProviderTestExtension;
781        use mithril_common::test::logging::MemoryDrainForTestInspector;
782
783        use super::*;
784
785        fn version_provider_with_open_api_version<V: Into<String>>(
786            version: V,
787        ) -> APIVersionProvider {
788            let mut version_provider = version_provider_without_open_api_version();
789            let mut open_api_versions = HashMap::new();
790            open_api_versions.insert(
791                "openapi.yaml".to_string(),
792                Version::parse(&version.into()).unwrap(),
793            );
794            version_provider.update_open_api_versions(open_api_versions);
795
796            version_provider
797        }
798
799        fn version_provider_without_open_api_version() -> APIVersionProvider {
800            let mut version_provider =
801                APIVersionProvider::new(Arc::new(DummyApiVersionDiscriminantSource::default()));
802            version_provider.update_open_api_versions(HashMap::new());
803
804            version_provider
805        }
806
807        fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
808            key: K,
809            value: V,
810        ) -> Response {
811            HttpResponseBuilder::new()
812                .header(key.into(), value.into())
813                .body("whatever")
814                .unwrap()
815                .into()
816        }
817
818        fn assert_api_version_warning_logged<L: Into<String>, A: Into<String>>(
819            log_inspector: &MemoryDrainForTestInspector,
820            leader_aggregator_version: L,
821            aggregator_version: A,
822        ) {
823            assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
824            assert!(log_inspector.contains_log(&format!(
825                "leader_aggregator_version={}",
826                leader_aggregator_version.into()
827            )));
828            assert!(
829                log_inspector
830                    .contains_log(&format!("aggregator_version={}", aggregator_version.into()))
831            );
832        }
833
834        #[test]
835        fn test_logs_warning_when_leader_aggregator_api_version_is_newer() {
836            let leader_aggregator_version = "2.0.0";
837            let aggregator_version = "1.0.0";
838            let (logger, log_inspector) = TestLogger::memory();
839            let version_provider = version_provider_with_open_api_version(aggregator_version);
840            let mut client = setup_client("http://whatever");
841            client.api_version_provider = Arc::new(version_provider);
842            client.logger = logger;
843            let response = build_fake_response_with_header(
844                MITHRIL_API_VERSION_HEADER,
845                leader_aggregator_version,
846            );
847
848            assert!(
849                Version::parse(leader_aggregator_version).unwrap()
850                    > Version::parse(aggregator_version).unwrap()
851            );
852
853            client.warn_if_api_version_mismatch(&response);
854
855            assert_api_version_warning_logged(
856                &log_inspector,
857                leader_aggregator_version,
858                aggregator_version,
859            );
860        }
861
862        #[test]
863        fn test_no_warning_logged_when_versions_match() {
864            let version = "1.0.0";
865            let (logger, log_inspector) = TestLogger::memory();
866            let version_provider = version_provider_with_open_api_version(version);
867            let mut client = setup_client("http://whatever");
868            client.api_version_provider = Arc::new(version_provider);
869            client.logger = logger;
870            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);
871
872            client.warn_if_api_version_mismatch(&response);
873
874            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
875        }
876
877        #[test]
878        fn test_no_warning_logged_when_leader_aggregator_api_version_is_older() {
879            let leader_aggregator_version = "1.0.0";
880            let aggregator_version = "2.0.0";
881            let (logger, log_inspector) = TestLogger::memory();
882            let version_provider = version_provider_with_open_api_version(aggregator_version);
883            let mut client = setup_client("http://whatever");
884            client.api_version_provider = Arc::new(version_provider);
885            client.logger = logger;
886            let response = build_fake_response_with_header(
887                MITHRIL_API_VERSION_HEADER,
888                leader_aggregator_version,
889            );
890
891            assert!(
892                Version::parse(leader_aggregator_version).unwrap()
893                    < Version::parse(aggregator_version).unwrap()
894            );
895
896            client.warn_if_api_version_mismatch(&response);
897
898            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
899        }
900
901        #[test]
902        fn test_does_not_log_or_fail_when_header_is_missing() {
903            let (logger, log_inspector) = TestLogger::memory();
904            let mut client = setup_client("http://whatever");
905            client.logger = logger;
906            let response =
907                build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");
908
909            client.warn_if_api_version_mismatch(&response);
910
911            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
912        }
913
914        #[test]
915        fn test_does_not_log_or_fail_when_header_is_not_a_version() {
916            let (logger, log_inspector) = TestLogger::memory();
917            let mut client = setup_client("http://whatever");
918            client.logger = logger;
919            let response =
920                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");
921
922            client.warn_if_api_version_mismatch(&response);
923
924            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
925        }
926
927        #[test]
928        fn test_logs_error_when_aggregator_version_cannot_be_computed() {
929            let (logger, log_inspector) = TestLogger::memory();
930            let version_provider = version_provider_without_open_api_version();
931            let mut client = setup_client("http://whatever");
932            client.api_version_provider = Arc::new(version_provider);
933            client.logger = logger;
934            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");
935
936            client.warn_if_api_version_mismatch(&response);
937
938            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
939        }
940
941        #[tokio::test]
942        async fn test_epoch_settings_ok_200_log_warning_if_api_version_mismatch() {
943            let leader_aggregator_version = "2.0.0";
944            let aggregator_version = "1.0.0";
945            let (server, mut client) = setup_server_and_client();
946            let (logger, log_inspector) = TestLogger::memory();
947            let version_provider = version_provider_with_open_api_version(aggregator_version);
948            client.api_version_provider = Arc::new(version_provider);
949            client.logger = logger;
950            let epoch_settings_expected = EpochSettingsMessage::dummy();
951            let _server_mock = server.mock(|when, then| {
952                when.path("/epoch-settings");
953                then.status(200)
954                    .body(json!(epoch_settings_expected).to_string())
955                    .header(MITHRIL_API_VERSION_HEADER, leader_aggregator_version);
956            });
957
958            assert!(
959                Version::parse(leader_aggregator_version).unwrap()
960                    > Version::parse(aggregator_version).unwrap()
961            );
962
963            client.retrieve_epoch_settings().await.unwrap();
964
965            assert_api_version_warning_logged(
966                &log_inspector,
967                leader_aggregator_version,
968                aggregator_version,
969            );
970        }
971    }
972
973    mod remote_certificate_retriever {
974        use mithril_common::test::double::fake_data;
975
976        use super::*;
977
978        #[tokio::test]
979        async fn test_get_latest_certificate_details() {
980            let (server, client) = setup_server_and_client();
981            let expected_certificate = fake_data::certificate("expected");
982            let latest_message: CertificateMessage =
983                expected_certificate.clone().try_into().unwrap();
984            let latest_certificates = vec![
985                CertificateListItemMessage {
986                    hash: expected_certificate.hash.clone(),
987                    ..CertificateListItemMessage::dummy()
988                },
989                CertificateListItemMessage::dummy(),
990                CertificateListItemMessage::dummy(),
991            ];
992            let _server_mock = server.mock(|when, then| {
993                when.path("/certificates");
994                then.status(200).body(json!(latest_certificates).to_string());
995            });
996            let _server_mock = server.mock(|when, then| {
997                when.path(format!("/certificate/{}", latest_message.hash));
998                then.status(200).body(json!(latest_message).to_string());
999            });
1000
1001            let fetched_certificate = client.get_latest_certificate_details().await.unwrap();
1002
1003            assert_eq!(Some(expected_certificate), fetched_certificate);
1004        }
1005
1006        #[tokio::test]
1007        async fn test_get_latest_genesis_certificate() {
1008            let (server, client) = setup_server_and_client();
1009            let genesis_message = CertificateMessage::dummy();
1010            let expected_genesis: Certificate = genesis_message.clone().try_into().unwrap();
1011            let _server_mock = server.mock(|when, then| {
1012                when.path("/certificate/genesis");
1013                then.status(200).body(json!(genesis_message).to_string());
1014            });
1015
1016            let fetched = client.get_genesis_certificate_details().await.unwrap();
1017
1018            assert_eq!(Some(expected_genesis), fetched);
1019        }
1020    }
1021}