mithril_client/
aggregator_client.rs

1//! Mechanisms to exchange data with an Aggregator.
2//!
3//! The [AggregatorClient] trait abstracts how the communication with an Aggregator
4//! is done.
5//! The clients that need to communicate only need to define their request using the
6//! [AggregatorRequest] enum.
7//!
8//! An implementation using HTTP is available: [AggregatorHTTPClient].
9
10use anyhow::{anyhow, Context};
11use async_recursion::async_recursion;
12use async_trait::async_trait;
13use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
14use reqwest::{Response, StatusCode, Url};
15use semver::Version;
16use slog::{debug, error, warn, Logger};
17use std::collections::HashMap;
18use std::sync::Arc;
19use thiserror::Error;
20use tokio::sync::RwLock;
21
22use mithril_common::entities::{ClientError, ServerError};
23use mithril_common::logging::LoggerExtensions;
24use mithril_common::messages::CardanoDatabaseImmutableFilesRestoredMessage;
25use mithril_common::MITHRIL_API_VERSION_HEADER;
26
27use crate::common::Epoch;
28use crate::{MithrilError, MithrilResult};
29
/// Warning logged when the Aggregator advertises a more recent API version than the one used by
/// this client.
const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = concat!(
    "OpenAPI version may be incompatible, ",
    "please update Mithril client library to the latest version."
);
32
33/// Error tied with the Aggregator client
34#[derive(Error, Debug)]
35pub enum AggregatorClientError {
36    /// Error raised when querying the aggregator returned a 5XX error.
37    #[error("Internal error of the Aggregator")]
38    RemoteServerTechnical(#[source] MithrilError),
39
40    /// Error raised when querying the aggregator returned a 4XX error.
41    #[error("Invalid request to the Aggregator")]
42    RemoteServerLogical(#[source] MithrilError),
43
44    /// HTTP subsystem error
45    #[error("HTTP subsystem error")]
46    SubsystemError(#[source] MithrilError),
47}
48
/// Requests that can be sent to an [AggregatorClient].
///
/// Each variant maps to a route relative to the Aggregator root endpoint (see
/// [AggregatorRequest::route]); the statistics variants may also carry a request body (see
/// [AggregatorRequest::get_body]).
// NOTE: in tests `strum::EnumIter` walks the variants in declaration order.
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(test, derive(strum::EnumIter))]
pub enum AggregatorRequest {
    /// Get a specific [certificate][crate::MithrilCertificate] from the aggregator
    GetCertificate {
        /// Hash of the certificate to retrieve
        hash: String,
    },

    /// Lists the aggregator [certificates][crate::MithrilCertificate]
    ListCertificates,

    /// Get a specific [Mithril stake distribution][crate::MithrilStakeDistribution] from the aggregator
    GetMithrilStakeDistribution {
        /// Hash of the Mithril stake distribution to retrieve
        hash: String,
    },

    /// Lists the aggregator [Mithril stake distributions][crate::MithrilStakeDistribution]
    ListMithrilStakeDistributions,

    /// Get a specific [snapshot][crate::Snapshot] from the aggregator
    GetSnapshot {
        /// Digest of the snapshot to retrieve
        digest: String,
    },

    /// Lists the aggregator [snapshots][crate::Snapshot]
    ListSnapshots,

    /// Increments the aggregator snapshot download statistics
    IncrementSnapshotStatistic {
        /// Body of the HTTP request, sent verbatim.
        // NOTE(review): presumably a serialized snapshot message — confirm with callers.
        snapshot: String,
    },

    /// Get a specific [Cardano database snapshot][crate::CardanoDatabaseSnapshot] from the aggregator
    GetCardanoDatabaseSnapshot {
        /// Hash of the snapshot to retrieve
        hash: String,
    },

    /// Lists the aggregator [Cardano database snapshots][crate::CardanoDatabaseSnapshot]
    ListCardanoDatabaseSnapshots,

    /// Increments the aggregator Cardano database snapshot immutable files restored statistics
    IncrementCardanoDatabaseImmutablesRestoredStatistic {
        /// Number of immutable files restored (sent as a `u32` in the JSON request body)
        number_of_immutables: u64,
    },

    /// Increments the aggregator Cardano database snapshot ancillary files restored statistics
    IncrementCardanoDatabaseAncillaryStatistic,

    /// Increments the aggregator Cardano database snapshot complete restoration statistics
    IncrementCardanoDatabaseCompleteRestorationStatistic,

    /// Increments the aggregator Cardano database snapshot partial restoration statistics
    IncrementCardanoDatabasePartialRestorationStatistic,

    /// Get proofs that the given set of Cardano transactions is included in the global Cardano transactions set
    GetTransactionsProofs {
        /// Hashes of the transactions to get proofs for, sent comma-separated in the
        /// `transaction_hashes` query parameter.
        transactions_hashes: Vec<String>,
    },

    /// Get a specific [Cardano transaction snapshot][crate::CardanoTransactionSnapshot]
    GetCardanoTransactionSnapshot {
        /// Hash of the Cardano transaction snapshot to retrieve
        hash: String,
    },

    /// Lists the aggregator [Cardano transaction snapshots][crate::CardanoTransactionSnapshot]
    ListCardanoTransactionSnapshots,

    /// Get a specific [Cardano stake distribution][crate::CardanoStakeDistribution] from the aggregator by hash
    GetCardanoStakeDistribution {
        /// Hash of the Cardano stake distribution to retrieve
        hash: String,
    },

    /// Get a specific [Cardano stake distribution][crate::CardanoStakeDistribution] from the aggregator by epoch
    GetCardanoStakeDistributionByEpoch {
        /// Epoch at the end of which the Cardano stake distribution is computed by the Cardano node
        epoch: Epoch,
    },

    /// Lists the aggregator [Cardano stake distributions][crate::CardanoStakeDistribution]
    ListCardanoStakeDistributions,
}
140
141impl AggregatorRequest {
142    /// Get the request route relative to the aggregator root endpoint.
143    pub fn route(&self) -> String {
144        match self {
145            AggregatorRequest::GetCertificate { hash } => {
146                format!("certificate/{hash}")
147            }
148            AggregatorRequest::ListCertificates => "certificates".to_string(),
149            AggregatorRequest::GetMithrilStakeDistribution { hash } => {
150                format!("artifact/mithril-stake-distribution/{hash}")
151            }
152            AggregatorRequest::ListMithrilStakeDistributions => {
153                "artifact/mithril-stake-distributions".to_string()
154            }
155            AggregatorRequest::GetSnapshot { digest } => {
156                format!("artifact/snapshot/{digest}")
157            }
158            AggregatorRequest::ListSnapshots => "artifact/snapshots".to_string(),
159            AggregatorRequest::IncrementSnapshotStatistic { snapshot: _ } => {
160                "statistics/snapshot".to_string()
161            }
162            AggregatorRequest::GetCardanoDatabaseSnapshot { hash } => {
163                format!("artifact/cardano-database/{hash}")
164            }
165            AggregatorRequest::ListCardanoDatabaseSnapshots => {
166                "artifact/cardano-database".to_string()
167            }
168            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
169                number_of_immutables: _,
170            } => "statistics/cardano-database/immutable-files-restored".to_string(),
171            AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic => {
172                "statistics/cardano-database/ancillary-files-restored".to_string()
173            }
174            AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic => {
175                "statistics/cardano-database/complete-restoration".to_string()
176            }
177            AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic => {
178                "statistics/cardano-database/partial-restoration".to_string()
179            }
180            AggregatorRequest::GetTransactionsProofs {
181                transactions_hashes,
182            } => format!(
183                "proof/cardano-transaction?transaction_hashes={}",
184                transactions_hashes.join(",")
185            ),
186            AggregatorRequest::GetCardanoTransactionSnapshot { hash } => {
187                format!("artifact/cardano-transaction/{hash}")
188            }
189            AggregatorRequest::ListCardanoTransactionSnapshots => {
190                "artifact/cardano-transactions".to_string()
191            }
192            AggregatorRequest::GetCardanoStakeDistribution { hash } => {
193                format!("artifact/cardano-stake-distribution/{hash}")
194            }
195            AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch } => {
196                format!("artifact/cardano-stake-distribution/epoch/{epoch}")
197            }
198            AggregatorRequest::ListCardanoStakeDistributions => {
199                "artifact/cardano-stake-distributions".to_string()
200            }
201        }
202    }
203
204    /// Get the request body to send to the aggregator
205    pub fn get_body(&self) -> Option<String> {
206        match self {
207            AggregatorRequest::IncrementSnapshotStatistic { snapshot } => {
208                Some(snapshot.to_string())
209            }
210            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
211                number_of_immutables,
212            } => serde_json::to_string(&CardanoDatabaseImmutableFilesRestoredMessage {
213                nb_immutable_files: *number_of_immutables as u32,
214            })
215            .ok(),
216            _ => None,
217        }
218    }
219}
220
221/// API that defines a client for the Aggregator
222#[cfg_attr(test, mockall::automock)]
223#[cfg_attr(target_family = "wasm", async_trait(?Send))]
224#[cfg_attr(not(target_family = "wasm"), async_trait)]
225pub trait AggregatorClient: Sync + Send {
226    /// Get the content back from the Aggregator
227    async fn get_content(
228        &self,
229        request: AggregatorRequest,
230    ) -> Result<String, AggregatorClientError>;
231
232    /// Post information to the Aggregator
233    async fn post_content(
234        &self,
235        request: AggregatorRequest,
236    ) -> Result<String, AggregatorClientError>;
237}
238
/// Responsible for HTTP transport and API version check.
///
/// HTTP implementation of [AggregatorClient], built on top of a `reqwest` client.
pub struct AggregatorHTTPClient {
    /// Underlying HTTP client used to send every request.
    http_client: reqwest::Client,
    /// Root URL of the Aggregator; its path always ends with a `/` (enforced in
    /// [AggregatorHTTPClient::new]) so that routes can be joined to it.
    aggregator_endpoint: Url,
    /// API versions of the client; the first one is sent in the [MITHRIL_API_VERSION_HEADER]
    /// header of every request and compared with the version advertised by the Aggregator.
    api_versions: Arc<RwLock<Vec<Version>>>,
    logger: Logger,
    /// Custom headers added to every request.
    http_headers: HeaderMap,
}
247
248impl AggregatorHTTPClient {
249    /// Constructs a new `AggregatorHTTPClient`
250    pub fn new(
251        aggregator_endpoint: Url,
252        api_versions: Vec<Version>,
253        logger: Logger,
254        custom_headers: Option<HashMap<String, String>>,
255    ) -> MithrilResult<Self> {
256        let http_client = reqwest::ClientBuilder::new()
257            .build()
258            .with_context(|| "Building http client for Aggregator client failed")?;
259
260        // Trailing slash is significant because url::join
261        // (https://docs.rs/url/latest/url/struct.Url.html#method.join) will remove
262        // the 'path' part of the url if it doesn't end with a trailing slash.
263        let aggregator_endpoint = if aggregator_endpoint.as_str().ends_with('/') {
264            aggregator_endpoint
265        } else {
266            let mut url = aggregator_endpoint.clone();
267            url.set_path(&format!("{}/", aggregator_endpoint.path()));
268            url
269        };
270
271        let mut http_headers = HeaderMap::new();
272        if let Some(headers) = custom_headers {
273            for (key, value) in headers.iter() {
274                http_headers.insert(
275                    HeaderName::from_bytes(key.as_bytes())?,
276                    HeaderValue::from_str(value)?,
277                );
278            }
279        }
280
281        Ok(Self {
282            http_client,
283            aggregator_endpoint,
284            api_versions: Arc::new(RwLock::new(api_versions)),
285            logger: logger.new_with_component_name::<Self>(),
286            http_headers,
287        })
288    }
289
290    /// Computes the current api version
291    async fn compute_current_api_version(&self) -> Option<Version> {
292        self.api_versions.read().await.first().cloned()
293    }
294
295    /// Perform a HTTP GET request on the Aggregator and return the given JSON
296    #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
297    #[cfg_attr(not(target_family = "wasm"), async_recursion)]
298    async fn get(&self, url: Url) -> Result<Response, AggregatorClientError> {
299        debug!(self.logger, "GET url='{url}'.");
300        let request_builder = self.http_client.get(url.clone());
301        let current_api_version = self
302            .compute_current_api_version()
303            .await
304            .unwrap()
305            .to_string();
306        debug!(
307            self.logger,
308            "Prepare request with version: {current_api_version}"
309        );
310        let request_builder = request_builder
311            .header(MITHRIL_API_VERSION_HEADER, current_api_version)
312            .headers(self.http_headers.clone());
313
314        let response = request_builder.send().await.map_err(|e| {
315            AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
316                "Cannot perform a GET against the Aggregator HTTP server (url='{url}')"
317            )))
318        })?;
319
320        match response.status() {
321            StatusCode::OK => {
322                self.warn_if_api_version_mismatch(&response).await;
323
324                Ok(response)
325            }
326            StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
327            status_code if status_code.is_client_error() => {
328                Err(Self::remote_logical_error(response).await)
329            }
330            _ => Err(Self::remote_technical_error(response).await),
331        }
332    }
333
334    #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
335    #[cfg_attr(not(target_family = "wasm"), async_recursion)]
336    async fn post(&self, url: Url, json: &str) -> Result<Response, AggregatorClientError> {
337        debug!(self.logger, "POST url='{url}'"; "json" => json);
338        let request_builder = self.http_client.post(url.to_owned()).body(json.to_owned());
339        let current_api_version = self
340            .compute_current_api_version()
341            .await
342            .unwrap()
343            .to_string();
344        debug!(
345            self.logger,
346            "Prepare request with version: {current_api_version}"
347        );
348        let request_builder = request_builder
349            .header(MITHRIL_API_VERSION_HEADER, current_api_version)
350            .headers(self.http_headers.clone());
351
352        let response = request_builder.send().await.map_err(|e| {
353            AggregatorClientError::SubsystemError(
354                anyhow!(e).context("Error while POSTing data '{json}' to URL='{url}'."),
355            )
356        })?;
357
358        match response.status() {
359            StatusCode::OK | StatusCode::CREATED => {
360                self.warn_if_api_version_mismatch(&response).await;
361
362                Ok(response)
363            }
364            StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
365            status_code if status_code.is_client_error() => {
366                Err(Self::remote_logical_error(response).await)
367            }
368            _ => Err(Self::remote_technical_error(response).await),
369        }
370    }
371
372    fn get_url_for_route(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
373        self.aggregator_endpoint
374            .join(endpoint)
375            .with_context(|| {
376                format!(
377                    "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
378                    self.aggregator_endpoint
379                )
380            })
381            .map_err(AggregatorClientError::SubsystemError)
382    }
383
384    fn not_found_error(url: Url) -> AggregatorClientError {
385        AggregatorClientError::RemoteServerLogical(anyhow!("Url='{url}' not found"))
386    }
387
388    async fn remote_logical_error(response: Response) -> AggregatorClientError {
389        let status_code = response.status();
390        let client_error = response
391            .json::<ClientError>()
392            .await
393            .unwrap_or(ClientError::new(
394                format!("Unhandled error {status_code}"),
395                "",
396            ));
397
398        AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"))
399    }
400
401    async fn remote_technical_error(response: Response) -> AggregatorClientError {
402        let status_code = response.status();
403        let server_error = response
404            .json::<ServerError>()
405            .await
406            .unwrap_or(ServerError::new(format!("Unhandled error {status_code}")));
407
408        AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"))
409    }
410
411    /// Check API version mismatch and log a warning if the aggregator's version is more recent.
412    async fn warn_if_api_version_mismatch(&self, response: &Response) {
413        let aggregator_version = response
414            .headers()
415            .get(MITHRIL_API_VERSION_HEADER)
416            .and_then(|v| v.to_str().ok())
417            .and_then(|s| Version::parse(s).ok());
418
419        let client_version = self.compute_current_api_version().await;
420
421        match (aggregator_version, client_version) {
422            (Some(aggregator), Some(client)) if client < aggregator => {
423                warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
424                    "aggregator_version" => %aggregator,
425                    "client_version" => %client,
426                );
427            }
428            (Some(_), None) => {
429                error!(
430                    self.logger,
431                    "Failed to compute the current client API version"
432                );
433            }
434            _ => {}
435        }
436    }
437}
438
439#[cfg_attr(target_family = "wasm", async_trait(?Send))]
440#[cfg_attr(not(target_family = "wasm"), async_trait)]
441impl AggregatorClient for AggregatorHTTPClient {
442    async fn get_content(
443        &self,
444        request: AggregatorRequest,
445    ) -> Result<String, AggregatorClientError> {
446        let response = self.get(self.get_url_for_route(&request.route())?).await?;
447        let content = format!("{response:?}");
448
449        response.text().await.map_err(|e| {
450            AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
451                "Could not find a JSON body in the response '{content}'."
452            )))
453        })
454    }
455
456    async fn post_content(
457        &self,
458        request: AggregatorRequest,
459    ) -> Result<String, AggregatorClientError> {
460        let response = self
461            .post(
462                self.get_url_for_route(&request.route())?,
463                &request.get_body().unwrap_or_default(),
464            )
465            .await?;
466
467        response.text().await.map_err(|e| {
468            AggregatorClientError::SubsystemError(
469                anyhow!(e).context("Could not find a text body in the response."),
470            )
471        })
472    }
473}
474
475#[cfg(test)]
476mod tests {
477    use httpmock::MockServer;
478    use std::collections::HashMap;
479    use strum::IntoEnumIterator;
480
481    use mithril_common::api_version::APIVersionProvider;
482    use mithril_common::entities::{ClientError, ServerError};
483
484    use crate::test_utils::TestLogger;
485
486    use super::*;
487
488    macro_rules! assert_error_eq {
489        ($left:expr, $right:expr) => {
490            assert_eq!(format!("{:?}", &$left), format!("{:?}", &$right),);
491        };
492    }
493
494    fn setup_client(
495        server_url: &str,
496        api_versions: Vec<Version>,
497        custom_headers: Option<HashMap<String, String>>,
498    ) -> AggregatorHTTPClient {
499        AggregatorHTTPClient::new(
500            Url::parse(server_url).unwrap(),
501            api_versions,
502            TestLogger::stdout(),
503            custom_headers,
504        )
505        .expect("building aggregator http client should not fail")
506    }
507
508    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
509        let server = MockServer::start();
510        let client = setup_client(
511            &server.url(""),
512            APIVersionProvider::compute_all_versions_sorted(),
513            None,
514        );
515        (server, client)
516    }
517
518    fn setup_server_and_client_with_custom_headers(
519        custom_headers: HashMap<String, String>,
520    ) -> (MockServer, AggregatorHTTPClient) {
521        let server = MockServer::start();
522        let client = setup_client(
523            &server.url(""),
524            APIVersionProvider::compute_all_versions_sorted(),
525            Some(custom_headers),
526        );
527        (server, client)
528    }
529
530    #[test]
531    fn always_append_trailing_slash_at_build() {
532        for (expected, url) in [
533            ("http://www.test.net/", "http://www.test.net/"),
534            ("http://www.test.net/", "http://www.test.net"),
535            (
536                "http://www.test.net/aggregator/",
537                "http://www.test.net/aggregator/",
538            ),
539            (
540                "http://www.test.net/aggregator/",
541                "http://www.test.net/aggregator",
542            ),
543        ] {
544            let url = Url::parse(url).unwrap();
545            let client = AggregatorHTTPClient::new(url, vec![], TestLogger::stdout(), None)
546                .expect("building aggregator http client should not fail");
547
548            assert_eq!(expected, client.aggregator_endpoint.as_str());
549        }
550    }
551
552    #[test]
553    fn deduce_routes_from_request() {
554        assert_eq!(
555            "certificate/abc".to_string(),
556            AggregatorRequest::GetCertificate {
557                hash: "abc".to_string()
558            }
559            .route()
560        );
561
562        assert_eq!(
563            "artifact/mithril-stake-distribution/abc".to_string(),
564            AggregatorRequest::GetMithrilStakeDistribution {
565                hash: "abc".to_string()
566            }
567            .route()
568        );
569
570        assert_eq!(
571            "artifact/mithril-stake-distribution/abc".to_string(),
572            AggregatorRequest::GetMithrilStakeDistribution {
573                hash: "abc".to_string()
574            }
575            .route()
576        );
577
578        assert_eq!(
579            "artifact/mithril-stake-distributions".to_string(),
580            AggregatorRequest::ListMithrilStakeDistributions.route()
581        );
582
583        assert_eq!(
584            "artifact/snapshot/abc".to_string(),
585            AggregatorRequest::GetSnapshot {
586                digest: "abc".to_string()
587            }
588            .route()
589        );
590
591        assert_eq!(
592            "artifact/snapshots".to_string(),
593            AggregatorRequest::ListSnapshots.route()
594        );
595
596        assert_eq!(
597            "statistics/snapshot".to_string(),
598            AggregatorRequest::IncrementSnapshotStatistic {
599                snapshot: "abc".to_string()
600            }
601            .route()
602        );
603
604        assert_eq!(
605            "artifact/cardano-database/abc".to_string(),
606            AggregatorRequest::GetCardanoDatabaseSnapshot {
607                hash: "abc".to_string()
608            }
609            .route()
610        );
611
612        assert_eq!(
613            "artifact/cardano-database".to_string(),
614            AggregatorRequest::ListCardanoDatabaseSnapshots.route()
615        );
616
617        assert_eq!(
618            "statistics/cardano-database/immutable-files-restored".to_string(),
619            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
620                number_of_immutables: 58
621            }
622            .route()
623        );
624
625        assert_eq!(
626            "statistics/cardano-database/ancillary-files-restored".to_string(),
627            AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic.route()
628        );
629
630        assert_eq!(
631            "statistics/cardano-database/complete-restoration".to_string(),
632            AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic.route()
633        );
634
635        assert_eq!(
636            "statistics/cardano-database/partial-restoration".to_string(),
637            AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic.route()
638        );
639
640        assert_eq!(
641            "proof/cardano-transaction?transaction_hashes=abc,def,ghi,jkl".to_string(),
642            AggregatorRequest::GetTransactionsProofs {
643                transactions_hashes: vec![
644                    "abc".to_string(),
645                    "def".to_string(),
646                    "ghi".to_string(),
647                    "jkl".to_string()
648                ]
649            }
650            .route()
651        );
652
653        assert_eq!(
654            "artifact/cardano-transaction/abc".to_string(),
655            AggregatorRequest::GetCardanoTransactionSnapshot {
656                hash: "abc".to_string()
657            }
658            .route()
659        );
660
661        assert_eq!(
662            "artifact/cardano-transactions".to_string(),
663            AggregatorRequest::ListCardanoTransactionSnapshots.route()
664        );
665
666        assert_eq!(
667            "artifact/cardano-stake-distribution/abc".to_string(),
668            AggregatorRequest::GetCardanoStakeDistribution {
669                hash: "abc".to_string()
670            }
671            .route()
672        );
673
674        assert_eq!(
675            "artifact/cardano-stake-distribution/epoch/123".to_string(),
676            AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch: Epoch(123) }.route()
677        );
678
679        assert_eq!(
680            "artifact/cardano-stake-distributions".to_string(),
681            AggregatorRequest::ListCardanoStakeDistributions.route()
682        );
683    }
684
685    #[test]
686    fn deduce_body_from_request() {
687        fn that_should_not_have_body(req: &AggregatorRequest) -> bool {
688            !matches!(
689                req,
690                AggregatorRequest::IncrementSnapshotStatistic { .. }
691                    | AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic { .. }
692            )
693        }
694
695        assert_eq!(
696            Some(r#"{"key":"value"}"#.to_string()),
697            AggregatorRequest::IncrementSnapshotStatistic {
698                snapshot: r#"{"key":"value"}"#.to_string()
699            }
700            .get_body()
701        );
702
703        assert_eq!(
704            Some(
705                serde_json::to_string(&CardanoDatabaseImmutableFilesRestoredMessage {
706                    nb_immutable_files: 432,
707                })
708                .unwrap()
709            ),
710            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
711                number_of_immutables: 432,
712            }
713            .get_body()
714        );
715
716        for req_that_should_not_have_body in
717            AggregatorRequest::iter().filter(that_should_not_have_body)
718        {
719            assert_eq!(None, req_that_should_not_have_body.get_body());
720        }
721    }
722
723    #[tokio::test]
724    async fn test_client_handle_4xx_errors() {
725        let client_error = ClientError::new("label", "message");
726
727        let (aggregator, client) = setup_server_and_client();
728        aggregator.mock(|_when, then| {
729            then.status(StatusCode::IM_A_TEAPOT.as_u16())
730                .json_body_obj(&client_error);
731        });
732
733        let expected_error = AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"));
734
735        let get_content_error = client
736            .get_content(AggregatorRequest::ListCertificates)
737            .await
738            .unwrap_err();
739        assert_error_eq!(get_content_error, expected_error);
740
741        let post_content_error = client
742            .post_content(AggregatorRequest::ListCertificates)
743            .await
744            .unwrap_err();
745        assert_error_eq!(post_content_error, expected_error);
746    }
747
748    #[tokio::test]
749    async fn test_client_handle_404_not_found_error() {
750        let client_error = ClientError::new("label", "message");
751
752        let (aggregator, client) = setup_server_and_client();
753        aggregator.mock(|_when, then| {
754            then.status(StatusCode::NOT_FOUND.as_u16())
755                .json_body_obj(&client_error);
756        });
757
758        let expected_error = AggregatorHTTPClient::not_found_error(
759            Url::parse(&format!(
760                "{}/{}",
761                aggregator.base_url(),
762                AggregatorRequest::ListCertificates.route()
763            ))
764            .unwrap(),
765        );
766
767        let get_content_error = client
768            .get_content(AggregatorRequest::ListCertificates)
769            .await
770            .unwrap_err();
771        assert_error_eq!(get_content_error, expected_error);
772
773        let post_content_error = client
774            .post_content(AggregatorRequest::ListCertificates)
775            .await
776            .unwrap_err();
777        assert_error_eq!(post_content_error, expected_error);
778    }
779
780    #[tokio::test]
781    async fn test_client_handle_5xx_errors() {
782        let server_error = ServerError::new("message");
783
784        let (aggregator, client) = setup_server_and_client();
785        aggregator.mock(|_when, then| {
786            then.status(StatusCode::INTERNAL_SERVER_ERROR.as_u16())
787                .json_body_obj(&server_error);
788        });
789
790        let expected_error =
791            AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"));
792
793        let get_content_error = client
794            .get_content(AggregatorRequest::ListCertificates)
795            .await
796            .unwrap_err();
797        assert_error_eq!(get_content_error, expected_error);
798
799        let post_content_error = client
800            .post_content(AggregatorRequest::ListCertificates)
801            .await
802            .unwrap_err();
803        assert_error_eq!(post_content_error, expected_error);
804    }
805
806    #[tokio::test]
807    async fn test_client_with_custom_headers() {
808        let mut http_headers = HashMap::new();
809        http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
810        http_headers.insert("Another-Header".to_string(), "AnotherValue".to_string());
811        let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
812        aggregator.mock(|when, then| {
813            when.header("Custom-Header", "CustomValue")
814                .header("Another-Header", "AnotherValue");
815            then.status(StatusCode::OK.as_u16()).body("ok");
816        });
817
818        client
819            .get_content(AggregatorRequest::ListCertificates)
820            .await
821            .expect("GET request should succeed");
822
823        client
824            .post_content(AggregatorRequest::ListCertificates)
825            .await
826            .expect("GET request should succeed");
827    }
828
829    #[tokio::test]
830    async fn test_client_sends_accept_encoding_header_with_correct_values() {
831        let (aggregator, client) = setup_server_and_client();
832        aggregator.mock(|when, then| {
833            when.matches(|req| {
834                let headers = req.headers.clone().expect("HTTP headers not found");
835                let accept_encoding_header = headers
836                    .iter()
837                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
838                    .expect("Accept-Encoding header not found");
839
840                let header_value = accept_encoding_header.clone().1;
841                ["gzip", "br", "deflate", "zstd"]
842                    .iter()
843                    .all(|&value| header_value.contains(value))
844            });
845
846            then.status(200).body("ok");
847        });
848
849        client
850            .get_content(AggregatorRequest::ListCertificates)
851            .await
852            .expect("GET request should succeed with Accept-Encoding header");
853    }
854
855    #[tokio::test]
856    async fn test_client_with_custom_headers_sends_accept_encoding_header_with_correct_values() {
857        let mut http_headers = HashMap::new();
858        http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
859        let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
860        aggregator.mock(|when, then| {
861            when.matches(|req| {
862                let headers = req.headers.clone().expect("HTTP headers not found");
863                let accept_encoding_header = headers
864                    .iter()
865                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
866                    .expect("Accept-Encoding header not found");
867
868                let header_value = accept_encoding_header.clone().1;
869                ["gzip", "br", "deflate", "zstd"]
870                    .iter()
871                    .all(|&value| header_value.contains(value))
872            });
873
874            then.status(200).body("ok");
875        });
876
877        client
878            .get_content(AggregatorRequest::ListCertificates)
879            .await
880            .expect("GET request should succeed with Accept-Encoding header");
881    }
882
883    mod warn_if_api_version_mismatch {
884        use http::response::Builder as HttpResponseBuilder;
885
886        use mithril_common::test_utils::MemoryDrainForTestInspector;
887
888        use super::*;
889
890        fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
891            key: K,
892            value: V,
893        ) -> Response {
894            HttpResponseBuilder::new()
895                .header(key.into(), value.into())
896                .body("whatever")
897                .unwrap()
898                .into()
899        }
900
901        fn assert_api_version_warning_logged<A: Into<String>, S: Into<String>>(
902            log_inspector: &MemoryDrainForTestInspector,
903            aggregator_version: A,
904            client_version: S,
905        ) {
906            assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
907            assert!(log_inspector
908                .contains_log(&format!("aggregator_version={}", aggregator_version.into())));
909            assert!(
910                log_inspector.contains_log(&format!("client_version={}", client_version.into()))
911            );
912        }
913
914        #[tokio::test]
915        async fn test_logs_warning_when_aggregator_api_version_is_newer() {
916            let aggregator_version = "2.0.0";
917            let client_version = "1.0.0";
918            let (logger, log_inspector) = TestLogger::memory();
919            let mut client = setup_client(
920                "http://whatever",
921                vec![Version::parse(client_version).unwrap()],
922                None,
923            );
924            client.logger = logger;
925            let response =
926                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
927
928            assert!(
929                Version::parse(aggregator_version).unwrap()
930                    > Version::parse(client_version).unwrap()
931            );
932
933            client.warn_if_api_version_mismatch(&response).await;
934
935            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
936        }
937
938        #[tokio::test]
939        async fn test_no_warning_logged_when_versions_match() {
940            let version = "1.0.0";
941            let (logger, log_inspector) = TestLogger::memory();
942            let mut client = setup_client(
943                "http://whatever",
944                vec![Version::parse(version).unwrap()],
945                None,
946            );
947            client.logger = logger;
948            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);
949
950            client.warn_if_api_version_mismatch(&response).await;
951
952            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
953        }
954
955        #[tokio::test]
956        async fn test_no_warning_logged_when_aggregator_api_version_is_older() {
957            let aggregator_version = "1.0.0";
958            let client_version = "2.0.0";
959            let (logger, log_inspector) = TestLogger::memory();
960            let mut client = setup_client(
961                "http://whatever",
962                vec![Version::parse(client_version).unwrap()],
963                None,
964            );
965            client.logger = logger;
966            let response =
967                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
968
969            assert!(
970                Version::parse(aggregator_version).unwrap()
971                    < Version::parse(client_version).unwrap()
972            );
973
974            client.warn_if_api_version_mismatch(&response).await;
975
976            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
977        }
978
979        #[tokio::test]
980        async fn test_does_not_log_or_fail_when_header_is_missing() {
981            let (logger, log_inspector) = TestLogger::memory();
982            let mut client = setup_client(
983                "http://whatever",
984                APIVersionProvider::compute_all_versions_sorted(),
985                None,
986            );
987            client.logger = logger;
988            let response =
989                build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");
990
991            client.warn_if_api_version_mismatch(&response).await;
992
993            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
994        }
995
996        #[tokio::test]
997        async fn test_does_not_log_or_fail_when_header_is_not_a_version() {
998            let (logger, log_inspector) = TestLogger::memory();
999            let mut client = setup_client(
1000                "http://whatever",
1001                APIVersionProvider::compute_all_versions_sorted(),
1002                None,
1003            );
1004            client.logger = logger;
1005            let response =
1006                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");
1007
1008            client.warn_if_api_version_mismatch(&response).await;
1009
1010            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1011        }
1012
1013        #[tokio::test]
1014        async fn test_logs_error_when_client_version_cannot_be_computed() {
1015            let (logger, log_inspector) = TestLogger::memory();
1016            let mut client = setup_client("http://whatever", vec![], None);
1017            client.logger = logger;
1018            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");
1019
1020            client.warn_if_api_version_mismatch(&response).await;
1021
1022            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1023        }
1024
1025        #[tokio::test]
1026        async fn test_client_get_log_warning_if_api_version_mismatch() {
1027            let aggregator_version = "2.0.0";
1028            let client_version = "1.0.0";
1029            let (server, mut client) = setup_server_and_client();
1030            let (logger, log_inspector) = TestLogger::memory();
1031            client.api_versions =
1032                Arc::new(RwLock::new(vec![Version::parse(client_version).unwrap()]));
1033            client.logger = logger;
1034            server.mock(|_, then| {
1035                then.status(StatusCode::OK.as_u16())
1036                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1037            });
1038
1039            assert!(
1040                Version::parse(aggregator_version).unwrap()
1041                    > Version::parse(client_version).unwrap()
1042            );
1043
1044            client
1045                .get(Url::parse(&server.base_url()).unwrap())
1046                .await
1047                .unwrap();
1048
1049            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
1050        }
1051
1052        #[tokio::test]
1053        async fn test_client_post_log_warning_if_api_version_mismatch() {
1054            let aggregator_version = "2.0.0";
1055            let client_version = "1.0.0";
1056            let (server, mut client) = setup_server_and_client();
1057            let (logger, log_inspector) = TestLogger::memory();
1058            client.api_versions =
1059                Arc::new(RwLock::new(vec![Version::parse(client_version).unwrap()]));
1060            client.logger = logger;
1061            server.mock(|_, then| {
1062                then.status(StatusCode::OK.as_u16())
1063                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1064            });
1065
1066            assert!(
1067                Version::parse(aggregator_version).unwrap()
1068                    > Version::parse(client_version).unwrap()
1069            );
1070
1071            client
1072                .post(Url::parse(&server.base_url()).unwrap(), "whatever")
1073                .await
1074                .unwrap();
1075
1076            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
1077        }
1078    }
1079}