mithril_client/
aggregator_client.rs

1//! Mechanisms to exchange data with an Aggregator.
2//!
3//! The [AggregatorClient] trait abstracts how the communication with an Aggregator
4//! is done.
5//! The clients that need to communicate only need to define their request using the
6//! [AggregatorRequest] enum.
7//!
8//! An implementation using HTTP is available: [AggregatorHTTPClient].
9
10use anyhow::{Context, anyhow};
11use async_recursion::async_recursion;
12use async_trait::async_trait;
13use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
14use reqwest::{Response, StatusCode, Url};
15use semver::Version;
16use slog::{Logger, debug, error, warn};
17use std::collections::HashMap;
18use std::sync::Arc;
19use thiserror::Error;
20use tokio::sync::RwLock;
21
22use mithril_common::MITHRIL_API_VERSION_HEADER;
23use mithril_common::entities::{ClientError, ServerError};
24use mithril_common::logging::LoggerExtensions;
25use mithril_common::messages::CardanoDatabaseImmutableFilesRestoredMessage;
26
27use crate::common::Epoch;
28use crate::{MithrilError, MithrilResult};
29
30const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = "OpenAPI version may be incompatible, please update Mithril client library to the latest version.";
31
32/// Error tied with the Aggregator client
33#[derive(Error, Debug)]
34pub enum AggregatorClientError {
35    /// Error raised when querying the aggregator returned a 5XX error.
36    #[error("Internal error of the Aggregator")]
37    RemoteServerTechnical(#[source] MithrilError),
38
39    /// Error raised when querying the aggregator returned a 4XX error.
40    #[error("Invalid request to the Aggregator")]
41    RemoteServerLogical(#[source] MithrilError),
42
43    /// HTTP subsystem error
44    #[error("HTTP subsystem error")]
45    SubsystemError(#[source] MithrilError),
46}
47
48/// What can be read from an [AggregatorClient].
49#[derive(Debug, Clone, Eq, PartialEq)]
50#[cfg_attr(test, derive(strum::EnumIter))]
51pub enum AggregatorRequest {
52    /// Get a specific [certificate][crate::MithrilCertificate] from the aggregator
53    GetCertificate {
54        /// Hash of the certificate to retrieve
55        hash: String,
56    },
57
58    /// Lists the aggregator [certificates][crate::MithrilCertificate]
59    ListCertificates,
60
61    /// Get a specific [Mithril stake distribution][crate::MithrilStakeDistribution] from the aggregator
62    GetMithrilStakeDistribution {
63        /// Hash of the Mithril stake distribution to retrieve
64        hash: String,
65    },
66
67    /// Lists the aggregator [Mithril stake distribution][crate::MithrilStakeDistribution]
68    ListMithrilStakeDistributions,
69
70    /// Get a specific [snapshot][crate::Snapshot] from the aggregator
71    GetSnapshot {
72        /// Digest of the snapshot to retrieve
73        digest: String,
74    },
75
76    /// Lists the aggregator [snapshots][crate::Snapshot]
77    ListSnapshots,
78
79    /// Increments the aggregator snapshot download statistics
80    IncrementSnapshotStatistic {
81        /// Snapshot as HTTP request body
82        snapshot: String,
83    },
84
85    /// Get a specific [Cardano database snapshot][crate::CardanoDatabaseSnapshot] from the aggregator
86    GetCardanoDatabaseSnapshot {
87        /// Hash of the snapshot to retrieve
88        hash: String,
89    },
90
91    /// Lists the aggregator [Cardano database snapshots][crate::CardanoDatabaseSnapshot]
92    ListCardanoDatabaseSnapshots,
93
94    /// Increments the aggregator Cardano database snapshot immutable files restored statistics
95    IncrementCardanoDatabaseImmutablesRestoredStatistic {
96        /// Number of immutable files restored
97        number_of_immutables: u64,
98    },
99
100    /// Increments the aggregator Cardano database snapshot ancillary files restored statistics
101    IncrementCardanoDatabaseAncillaryStatistic,
102
103    /// Increments the aggregator Cardano database snapshot complete restoration statistics
104    IncrementCardanoDatabaseCompleteRestorationStatistic,
105
106    /// Increments the aggregator Cardano database snapshot partial restoration statistics
107    IncrementCardanoDatabasePartialRestorationStatistic,
108
109    /// Get proofs that the given set of Cardano transactions is included in the global Cardano transactions set
110    GetTransactionsProofs {
111        /// Hashes of the transactions to get proofs for.
112        transactions_hashes: Vec<String>,
113    },
114
115    /// Get a specific [Cardano transaction snapshot][crate::CardanoTransactionSnapshot]
116    GetCardanoTransactionSnapshot {
117        /// Hash of the Cardano transaction snapshot to retrieve
118        hash: String,
119    },
120
121    /// Lists the aggregator [Cardano transaction snapshot][crate::CardanoTransactionSnapshot]
122    ListCardanoTransactionSnapshots,
123
124    /// Get a specific [Cardano stake distribution][crate::CardanoStakeDistribution] from the aggregator by hash
125    GetCardanoStakeDistribution {
126        /// Hash of the Cardano stake distribution to retrieve
127        hash: String,
128    },
129
130    /// Get a specific [Cardano stake distribution][crate::CardanoStakeDistribution] from the aggregator by epoch
131    GetCardanoStakeDistributionByEpoch {
132        /// Epoch at the end of which the Cardano stake distribution is computed by the Cardano node
133        epoch: Epoch,
134    },
135
136    /// Lists the aggregator [Cardano stake distribution][crate::CardanoStakeDistribution]
137    ListCardanoStakeDistributions,
138
139    /// Get information about the aggregator status
140    Status,
141}
142
143impl AggregatorRequest {
144    /// Get the request route relative to the aggregator root endpoint.
145    pub fn route(&self) -> String {
146        match self {
147            AggregatorRequest::GetCertificate { hash } => {
148                format!("certificate/{hash}")
149            }
150            AggregatorRequest::ListCertificates => "certificates".to_string(),
151            AggregatorRequest::GetMithrilStakeDistribution { hash } => {
152                format!("artifact/mithril-stake-distribution/{hash}")
153            }
154            AggregatorRequest::ListMithrilStakeDistributions => {
155                "artifact/mithril-stake-distributions".to_string()
156            }
157            AggregatorRequest::GetSnapshot { digest } => {
158                format!("artifact/snapshot/{digest}")
159            }
160            AggregatorRequest::ListSnapshots => "artifact/snapshots".to_string(),
161            AggregatorRequest::IncrementSnapshotStatistic { snapshot: _ } => {
162                "statistics/snapshot".to_string()
163            }
164            AggregatorRequest::GetCardanoDatabaseSnapshot { hash } => {
165                format!("artifact/cardano-database/{hash}")
166            }
167            AggregatorRequest::ListCardanoDatabaseSnapshots => {
168                "artifact/cardano-database".to_string()
169            }
170            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
171                number_of_immutables: _,
172            } => "statistics/cardano-database/immutable-files-restored".to_string(),
173            AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic => {
174                "statistics/cardano-database/ancillary-files-restored".to_string()
175            }
176            AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic => {
177                "statistics/cardano-database/complete-restoration".to_string()
178            }
179            AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic => {
180                "statistics/cardano-database/partial-restoration".to_string()
181            }
182            AggregatorRequest::GetTransactionsProofs {
183                transactions_hashes,
184            } => format!(
185                "proof/cardano-transaction?transaction_hashes={}",
186                transactions_hashes.join(",")
187            ),
188            AggregatorRequest::GetCardanoTransactionSnapshot { hash } => {
189                format!("artifact/cardano-transaction/{hash}")
190            }
191            AggregatorRequest::ListCardanoTransactionSnapshots => {
192                "artifact/cardano-transactions".to_string()
193            }
194            AggregatorRequest::GetCardanoStakeDistribution { hash } => {
195                format!("artifact/cardano-stake-distribution/{hash}")
196            }
197            AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch } => {
198                format!("artifact/cardano-stake-distribution/epoch/{epoch}")
199            }
200            AggregatorRequest::ListCardanoStakeDistributions => {
201                "artifact/cardano-stake-distributions".to_string()
202            }
203            AggregatorRequest::Status => "status".to_string(),
204        }
205    }
206
207    /// Get the request body to send to the aggregator
208    pub fn get_body(&self) -> Option<String> {
209        match self {
210            AggregatorRequest::IncrementSnapshotStatistic { snapshot } => {
211                Some(snapshot.to_string())
212            }
213            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
214                number_of_immutables,
215            } => serde_json::to_string(&CardanoDatabaseImmutableFilesRestoredMessage {
216                nb_immutable_files: *number_of_immutables as u32,
217            })
218            .ok(),
219            _ => None,
220        }
221    }
222}
223
224/// API that defines a client for the Aggregator
225#[cfg_attr(test, mockall::automock)]
226#[cfg_attr(target_family = "wasm", async_trait(?Send))]
227#[cfg_attr(not(target_family = "wasm"), async_trait)]
228pub trait AggregatorClient: Sync + Send {
229    /// Get the content back from the Aggregator
230    async fn get_content(
231        &self,
232        request: AggregatorRequest,
233    ) -> Result<String, AggregatorClientError>;
234
235    /// Post information to the Aggregator
236    async fn post_content(
237        &self,
238        request: AggregatorRequest,
239    ) -> Result<String, AggregatorClientError>;
240}
241
242/// Responsible for HTTP transport and API version check.
243pub struct AggregatorHTTPClient {
244    http_client: reqwest::Client,
245    aggregator_endpoint: Url,
246    api_versions: Arc<RwLock<Vec<Version>>>,
247    logger: Logger,
248    http_headers: HeaderMap,
249}
250
251impl AggregatorHTTPClient {
252    /// Constructs a new `AggregatorHTTPClient`
253    pub fn new(
254        aggregator_endpoint: Url,
255        api_versions: Vec<Version>,
256        logger: Logger,
257        custom_headers: Option<HashMap<String, String>>,
258    ) -> MithrilResult<Self> {
259        let http_client = reqwest::ClientBuilder::new()
260            .build()
261            .with_context(|| "Building http client for Aggregator client failed")?;
262
263        // Trailing slash is significant because url::join
264        // (https://docs.rs/url/latest/url/struct.Url.html#method.join) will remove
265        // the 'path' part of the url if it doesn't end with a trailing slash.
266        let aggregator_endpoint = if aggregator_endpoint.as_str().ends_with('/') {
267            aggregator_endpoint
268        } else {
269            let mut url = aggregator_endpoint.clone();
270            url.set_path(&format!("{}/", aggregator_endpoint.path()));
271            url
272        };
273
274        let mut http_headers = HeaderMap::new();
275        if let Some(headers) = custom_headers {
276            for (key, value) in headers.iter() {
277                http_headers.insert(
278                    HeaderName::from_bytes(key.as_bytes())?,
279                    HeaderValue::from_str(value)?,
280                );
281            }
282        }
283
284        Ok(Self {
285            http_client,
286            aggregator_endpoint,
287            api_versions: Arc::new(RwLock::new(api_versions)),
288            logger: logger.new_with_component_name::<Self>(),
289            http_headers,
290        })
291    }
292
293    /// Computes the current api version
294    async fn compute_current_api_version(&self) -> Option<Version> {
295        self.api_versions.read().await.first().cloned()
296    }
297
298    /// Perform a HTTP GET request on the Aggregator and return the given JSON
299    #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
300    #[cfg_attr(not(target_family = "wasm"), async_recursion)]
301    async fn get(&self, url: Url) -> Result<Response, AggregatorClientError> {
302        debug!(self.logger, "GET url='{url}'.");
303        let request_builder = self.http_client.get(url.clone());
304        let current_api_version = self.compute_current_api_version().await.unwrap().to_string();
305        debug!(
306            self.logger,
307            "Prepare request with version: {current_api_version}"
308        );
309        let request_builder = request_builder
310            .header(MITHRIL_API_VERSION_HEADER, current_api_version)
311            .headers(self.http_headers.clone());
312
313        let response = request_builder.send().await.map_err(|e| {
314            AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
315                "Cannot perform a GET against the Aggregator HTTP server (url='{url}')"
316            )))
317        })?;
318
319        match response.status() {
320            StatusCode::OK => {
321                self.warn_if_api_version_mismatch(&response).await;
322
323                Ok(response)
324            }
325            StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
326            status_code if status_code.is_client_error() => {
327                Err(Self::remote_logical_error(response).await)
328            }
329            _ => Err(Self::remote_technical_error(response).await),
330        }
331    }
332
333    #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
334    #[cfg_attr(not(target_family = "wasm"), async_recursion)]
335    async fn post(&self, url: Url, json: &str) -> Result<Response, AggregatorClientError> {
336        debug!(self.logger, "POST url='{url}'"; "json" => json);
337        let request_builder = self.http_client.post(url.to_owned()).body(json.to_owned());
338        let current_api_version = self.compute_current_api_version().await.unwrap().to_string();
339        debug!(
340            self.logger,
341            "Prepare request with version: {current_api_version}"
342        );
343        let request_builder = request_builder
344            .header(MITHRIL_API_VERSION_HEADER, current_api_version)
345            .headers(self.http_headers.clone());
346
347        let response = request_builder.send().await.map_err(|e| {
348            AggregatorClientError::SubsystemError(
349                anyhow!(e).context("Error while POSTing data '{json}' to URL='{url}'."),
350            )
351        })?;
352
353        match response.status() {
354            StatusCode::OK | StatusCode::CREATED => {
355                self.warn_if_api_version_mismatch(&response).await;
356
357                Ok(response)
358            }
359            StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
360            status_code if status_code.is_client_error() => {
361                Err(Self::remote_logical_error(response).await)
362            }
363            _ => Err(Self::remote_technical_error(response).await),
364        }
365    }
366
367    fn get_url_for_route(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
368        self.aggregator_endpoint
369            .join(endpoint)
370            .with_context(|| {
371                format!(
372                    "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
373                    self.aggregator_endpoint
374                )
375            })
376            .map_err(AggregatorClientError::SubsystemError)
377    }
378
379    fn not_found_error(url: Url) -> AggregatorClientError {
380        AggregatorClientError::RemoteServerLogical(anyhow!("Url='{url}' not found"))
381    }
382
383    async fn remote_logical_error(response: Response) -> AggregatorClientError {
384        let status_code = response.status();
385        let client_error = response.json::<ClientError>().await.unwrap_or(ClientError::new(
386            format!("Unhandled error {status_code}"),
387            "",
388        ));
389
390        AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"))
391    }
392
393    async fn remote_technical_error(response: Response) -> AggregatorClientError {
394        let status_code = response.status();
395        let server_error = response
396            .json::<ServerError>()
397            .await
398            .unwrap_or(ServerError::new(format!("Unhandled error {status_code}")));
399
400        AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"))
401    }
402
403    /// Check API version mismatch and log a warning if the aggregator's version is more recent.
404    async fn warn_if_api_version_mismatch(&self, response: &Response) {
405        let aggregator_version = response
406            .headers()
407            .get(MITHRIL_API_VERSION_HEADER)
408            .and_then(|v| v.to_str().ok())
409            .and_then(|s| Version::parse(s).ok());
410
411        let client_version = self.compute_current_api_version().await;
412
413        match (aggregator_version, client_version) {
414            (Some(aggregator), Some(client)) if client < aggregator => {
415                warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
416                    "aggregator_version" => %aggregator,
417                    "client_version" => %client,
418                );
419            }
420            (Some(_), None) => {
421                error!(
422                    self.logger,
423                    "Failed to compute the current client API version"
424                );
425            }
426            _ => {}
427        }
428    }
429}
430
431#[cfg_attr(target_family = "wasm", async_trait(?Send))]
432#[cfg_attr(not(target_family = "wasm"), async_trait)]
433impl AggregatorClient for AggregatorHTTPClient {
434    async fn get_content(
435        &self,
436        request: AggregatorRequest,
437    ) -> Result<String, AggregatorClientError> {
438        let response = self.get(self.get_url_for_route(&request.route())?).await?;
439        let content = format!("{response:?}");
440
441        response.text().await.map_err(|e| {
442            AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
443                "Could not find a JSON body in the response '{content}'."
444            )))
445        })
446    }
447
448    async fn post_content(
449        &self,
450        request: AggregatorRequest,
451    ) -> Result<String, AggregatorClientError> {
452        let response = self
453            .post(
454                self.get_url_for_route(&request.route())?,
455                &request.get_body().unwrap_or_default(),
456            )
457            .await?;
458
459        response.text().await.map_err(|e| {
460            AggregatorClientError::SubsystemError(
461                anyhow!(e).context("Could not find a text body in the response."),
462            )
463        })
464    }
465}
466
467#[cfg(test)]
468mod tests {
469    use httpmock::MockServer;
470    use std::collections::HashMap;
471    use strum::IntoEnumIterator;
472
473    use mithril_common::api_version::APIVersionProvider;
474    use mithril_common::entities::{ClientError, ServerError};
475
476    use crate::test_utils::TestLogger;
477
478    use super::*;
479
480    macro_rules! assert_error_eq {
481        ($left:expr, $right:expr) => {
482            assert_eq!(format!("{:?}", &$left), format!("{:?}", &$right),);
483        };
484    }
485
486    fn setup_client(
487        server_url: &str,
488        api_versions: Vec<Version>,
489        custom_headers: Option<HashMap<String, String>>,
490    ) -> AggregatorHTTPClient {
491        AggregatorHTTPClient::new(
492            Url::parse(server_url).unwrap(),
493            api_versions,
494            TestLogger::stdout(),
495            custom_headers,
496        )
497        .expect("building aggregator http client should not fail")
498    }
499
500    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
501        let server = MockServer::start();
502        let client = setup_client(
503            &server.url(""),
504            APIVersionProvider::compute_all_versions_sorted(),
505            None,
506        );
507        (server, client)
508    }
509
510    fn setup_server_and_client_with_custom_headers(
511        custom_headers: HashMap<String, String>,
512    ) -> (MockServer, AggregatorHTTPClient) {
513        let server = MockServer::start();
514        let client = setup_client(
515            &server.url(""),
516            APIVersionProvider::compute_all_versions_sorted(),
517            Some(custom_headers),
518        );
519        (server, client)
520    }
521
522    #[test]
523    fn always_append_trailing_slash_at_build() {
524        for (expected, url) in [
525            ("http://www.test.net/", "http://www.test.net/"),
526            ("http://www.test.net/", "http://www.test.net"),
527            (
528                "http://www.test.net/aggregator/",
529                "http://www.test.net/aggregator/",
530            ),
531            (
532                "http://www.test.net/aggregator/",
533                "http://www.test.net/aggregator",
534            ),
535        ] {
536            let url = Url::parse(url).unwrap();
537            let client = AggregatorHTTPClient::new(url, vec![], TestLogger::stdout(), None)
538                .expect("building aggregator http client should not fail");
539
540            assert_eq!(expected, client.aggregator_endpoint.as_str());
541        }
542    }
543
544    #[test]
545    fn deduce_routes_from_request() {
546        assert_eq!(
547            "certificate/abc".to_string(),
548            AggregatorRequest::GetCertificate {
549                hash: "abc".to_string()
550            }
551            .route()
552        );
553
554        assert_eq!(
555            "artifact/mithril-stake-distribution/abc".to_string(),
556            AggregatorRequest::GetMithrilStakeDistribution {
557                hash: "abc".to_string()
558            }
559            .route()
560        );
561
562        assert_eq!(
563            "artifact/mithril-stake-distribution/abc".to_string(),
564            AggregatorRequest::GetMithrilStakeDistribution {
565                hash: "abc".to_string()
566            }
567            .route()
568        );
569
570        assert_eq!(
571            "artifact/mithril-stake-distributions".to_string(),
572            AggregatorRequest::ListMithrilStakeDistributions.route()
573        );
574
575        assert_eq!(
576            "artifact/snapshot/abc".to_string(),
577            AggregatorRequest::GetSnapshot {
578                digest: "abc".to_string()
579            }
580            .route()
581        );
582
583        assert_eq!(
584            "artifact/snapshots".to_string(),
585            AggregatorRequest::ListSnapshots.route()
586        );
587
588        assert_eq!(
589            "statistics/snapshot".to_string(),
590            AggregatorRequest::IncrementSnapshotStatistic {
591                snapshot: "abc".to_string()
592            }
593            .route()
594        );
595
596        assert_eq!(
597            "artifact/cardano-database/abc".to_string(),
598            AggregatorRequest::GetCardanoDatabaseSnapshot {
599                hash: "abc".to_string()
600            }
601            .route()
602        );
603
604        assert_eq!(
605            "artifact/cardano-database".to_string(),
606            AggregatorRequest::ListCardanoDatabaseSnapshots.route()
607        );
608
609        assert_eq!(
610            "statistics/cardano-database/immutable-files-restored".to_string(),
611            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
612                number_of_immutables: 58
613            }
614            .route()
615        );
616
617        assert_eq!(
618            "statistics/cardano-database/ancillary-files-restored".to_string(),
619            AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic.route()
620        );
621
622        assert_eq!(
623            "statistics/cardano-database/complete-restoration".to_string(),
624            AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic.route()
625        );
626
627        assert_eq!(
628            "statistics/cardano-database/partial-restoration".to_string(),
629            AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic.route()
630        );
631
632        assert_eq!(
633            "proof/cardano-transaction?transaction_hashes=abc,def,ghi,jkl".to_string(),
634            AggregatorRequest::GetTransactionsProofs {
635                transactions_hashes: vec![
636                    "abc".to_string(),
637                    "def".to_string(),
638                    "ghi".to_string(),
639                    "jkl".to_string()
640                ]
641            }
642            .route()
643        );
644
645        assert_eq!(
646            "artifact/cardano-transaction/abc".to_string(),
647            AggregatorRequest::GetCardanoTransactionSnapshot {
648                hash: "abc".to_string()
649            }
650            .route()
651        );
652
653        assert_eq!(
654            "artifact/cardano-transactions".to_string(),
655            AggregatorRequest::ListCardanoTransactionSnapshots.route()
656        );
657
658        assert_eq!(
659            "artifact/cardano-stake-distribution/abc".to_string(),
660            AggregatorRequest::GetCardanoStakeDistribution {
661                hash: "abc".to_string()
662            }
663            .route()
664        );
665
666        assert_eq!(
667            "artifact/cardano-stake-distribution/epoch/123".to_string(),
668            AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch: Epoch(123) }.route()
669        );
670
671        assert_eq!(
672            "artifact/cardano-stake-distributions".to_string(),
673            AggregatorRequest::ListCardanoStakeDistributions.route()
674        );
675
676        assert_eq!("status".to_string(), AggregatorRequest::Status.route());
677    }
678
679    #[test]
680    fn deduce_body_from_request() {
681        fn that_should_not_have_body(req: &AggregatorRequest) -> bool {
682            !matches!(
683                req,
684                AggregatorRequest::IncrementSnapshotStatistic { .. }
685                    | AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic { .. }
686            )
687        }
688
689        assert_eq!(
690            Some(r#"{"key":"value"}"#.to_string()),
691            AggregatorRequest::IncrementSnapshotStatistic {
692                snapshot: r#"{"key":"value"}"#.to_string()
693            }
694            .get_body()
695        );
696
697        assert_eq!(
698            Some(
699                serde_json::to_string(&CardanoDatabaseImmutableFilesRestoredMessage {
700                    nb_immutable_files: 432,
701                })
702                .unwrap()
703            ),
704            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
705                number_of_immutables: 432,
706            }
707            .get_body()
708        );
709
710        for req_that_should_not_have_body in
711            AggregatorRequest::iter().filter(that_should_not_have_body)
712        {
713            assert_eq!(None, req_that_should_not_have_body.get_body());
714        }
715    }
716
717    #[tokio::test]
718    async fn test_client_handle_4xx_errors() {
719        let client_error = ClientError::new("label", "message");
720
721        let (aggregator, client) = setup_server_and_client();
722        aggregator.mock(|_when, then| {
723            then.status(StatusCode::IM_A_TEAPOT.as_u16())
724                .json_body_obj(&client_error);
725        });
726
727        let expected_error = AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"));
728
729        let get_content_error = client
730            .get_content(AggregatorRequest::ListCertificates)
731            .await
732            .unwrap_err();
733        assert_error_eq!(get_content_error, expected_error);
734
735        let post_content_error = client
736            .post_content(AggregatorRequest::ListCertificates)
737            .await
738            .unwrap_err();
739        assert_error_eq!(post_content_error, expected_error);
740    }
741
742    #[tokio::test]
743    async fn test_client_handle_404_not_found_error() {
744        let client_error = ClientError::new("label", "message");
745
746        let (aggregator, client) = setup_server_and_client();
747        aggregator.mock(|_when, then| {
748            then.status(StatusCode::NOT_FOUND.as_u16())
749                .json_body_obj(&client_error);
750        });
751
752        let expected_error = AggregatorHTTPClient::not_found_error(
753            Url::parse(&format!(
754                "{}/{}",
755                aggregator.base_url(),
756                AggregatorRequest::ListCertificates.route()
757            ))
758            .unwrap(),
759        );
760
761        let get_content_error = client
762            .get_content(AggregatorRequest::ListCertificates)
763            .await
764            .unwrap_err();
765        assert_error_eq!(get_content_error, expected_error);
766
767        let post_content_error = client
768            .post_content(AggregatorRequest::ListCertificates)
769            .await
770            .unwrap_err();
771        assert_error_eq!(post_content_error, expected_error);
772    }
773
774    #[tokio::test]
775    async fn test_client_handle_5xx_errors() {
776        let server_error = ServerError::new("message");
777
778        let (aggregator, client) = setup_server_and_client();
779        aggregator.mock(|_when, then| {
780            then.status(StatusCode::INTERNAL_SERVER_ERROR.as_u16())
781                .json_body_obj(&server_error);
782        });
783
784        let expected_error =
785            AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"));
786
787        let get_content_error = client
788            .get_content(AggregatorRequest::ListCertificates)
789            .await
790            .unwrap_err();
791        assert_error_eq!(get_content_error, expected_error);
792
793        let post_content_error = client
794            .post_content(AggregatorRequest::ListCertificates)
795            .await
796            .unwrap_err();
797        assert_error_eq!(post_content_error, expected_error);
798    }
799
800    #[tokio::test]
801    async fn test_client_with_custom_headers() {
802        let mut http_headers = HashMap::new();
803        http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
804        http_headers.insert("Another-Header".to_string(), "AnotherValue".to_string());
805        let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
806        aggregator.mock(|when, then| {
807            when.header("Custom-Header", "CustomValue")
808                .header("Another-Header", "AnotherValue");
809            then.status(StatusCode::OK.as_u16()).body("ok");
810        });
811
812        client
813            .get_content(AggregatorRequest::ListCertificates)
814            .await
815            .expect("GET request should succeed");
816
817        client
818            .post_content(AggregatorRequest::ListCertificates)
819            .await
820            .expect("GET request should succeed");
821    }
822
823    #[tokio::test]
824    async fn test_client_sends_accept_encoding_header_with_correct_values() {
825        let (aggregator, client) = setup_server_and_client();
826        aggregator.mock(|when, then| {
827            when.matches(|req| {
828                let headers = req.headers.clone().expect("HTTP headers not found");
829                let accept_encoding_header = headers
830                    .iter()
831                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
832                    .expect("Accept-Encoding header not found");
833
834                let header_value = accept_encoding_header.clone().1;
835                ["gzip", "br", "deflate", "zstd"]
836                    .iter()
837                    .all(|&value| header_value.contains(value))
838            });
839
840            then.status(200).body("ok");
841        });
842
843        client
844            .get_content(AggregatorRequest::ListCertificates)
845            .await
846            .expect("GET request should succeed with Accept-Encoding header");
847    }
848
849    #[tokio::test]
850    async fn test_client_with_custom_headers_sends_accept_encoding_header_with_correct_values() {
851        let mut http_headers = HashMap::new();
852        http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
853        let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
854        aggregator.mock(|when, then| {
855            when.matches(|req| {
856                let headers = req.headers.clone().expect("HTTP headers not found");
857                let accept_encoding_header = headers
858                    .iter()
859                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
860                    .expect("Accept-Encoding header not found");
861
862                let header_value = accept_encoding_header.clone().1;
863                ["gzip", "br", "deflate", "zstd"]
864                    .iter()
865                    .all(|&value| header_value.contains(value))
866            });
867
868            then.status(200).body("ok");
869        });
870
871        client
872            .get_content(AggregatorRequest::ListCertificates)
873            .await
874            .expect("GET request should succeed with Accept-Encoding header");
875    }
876
877    mod warn_if_api_version_mismatch {
878        use http::response::Builder as HttpResponseBuilder;
879
880        use mithril_common::test_utils::MemoryDrainForTestInspector;
881
882        use super::*;
883
884        fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
885            key: K,
886            value: V,
887        ) -> Response {
888            HttpResponseBuilder::new()
889                .header(key.into(), value.into())
890                .body("whatever")
891                .unwrap()
892                .into()
893        }
894
895        fn assert_api_version_warning_logged<A: Into<String>, S: Into<String>>(
896            log_inspector: &MemoryDrainForTestInspector,
897            aggregator_version: A,
898            client_version: S,
899        ) {
900            assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
901            assert!(
902                log_inspector
903                    .contains_log(&format!("aggregator_version={}", aggregator_version.into()))
904            );
905            assert!(
906                log_inspector.contains_log(&format!("client_version={}", client_version.into()))
907            );
908        }
909
910        #[tokio::test]
911        async fn test_logs_warning_when_aggregator_api_version_is_newer() {
912            let aggregator_version = "2.0.0";
913            let client_version = "1.0.0";
914            let (logger, log_inspector) = TestLogger::memory();
915            let mut client = setup_client(
916                "http://whatever",
917                vec![Version::parse(client_version).unwrap()],
918                None,
919            );
920            client.logger = logger;
921            let response =
922                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
923
924            assert!(
925                Version::parse(aggregator_version).unwrap()
926                    > Version::parse(client_version).unwrap()
927            );
928
929            client.warn_if_api_version_mismatch(&response).await;
930
931            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
932        }
933
934        #[tokio::test]
935        async fn test_no_warning_logged_when_versions_match() {
936            let version = "1.0.0";
937            let (logger, log_inspector) = TestLogger::memory();
938            let mut client = setup_client(
939                "http://whatever",
940                vec![Version::parse(version).unwrap()],
941                None,
942            );
943            client.logger = logger;
944            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);
945
946            client.warn_if_api_version_mismatch(&response).await;
947
948            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
949        }
950
951        #[tokio::test]
952        async fn test_no_warning_logged_when_aggregator_api_version_is_older() {
953            let aggregator_version = "1.0.0";
954            let client_version = "2.0.0";
955            let (logger, log_inspector) = TestLogger::memory();
956            let mut client = setup_client(
957                "http://whatever",
958                vec![Version::parse(client_version).unwrap()],
959                None,
960            );
961            client.logger = logger;
962            let response =
963                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
964
965            assert!(
966                Version::parse(aggregator_version).unwrap()
967                    < Version::parse(client_version).unwrap()
968            );
969
970            client.warn_if_api_version_mismatch(&response).await;
971
972            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
973        }
974
975        #[tokio::test]
976        async fn test_does_not_log_or_fail_when_header_is_missing() {
977            let (logger, log_inspector) = TestLogger::memory();
978            let mut client = setup_client(
979                "http://whatever",
980                APIVersionProvider::compute_all_versions_sorted(),
981                None,
982            );
983            client.logger = logger;
984            let response =
985                build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");
986
987            client.warn_if_api_version_mismatch(&response).await;
988
989            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
990        }
991
992        #[tokio::test]
993        async fn test_does_not_log_or_fail_when_header_is_not_a_version() {
994            let (logger, log_inspector) = TestLogger::memory();
995            let mut client = setup_client(
996                "http://whatever",
997                APIVersionProvider::compute_all_versions_sorted(),
998                None,
999            );
1000            client.logger = logger;
1001            let response =
1002                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");
1003
1004            client.warn_if_api_version_mismatch(&response).await;
1005
1006            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1007        }
1008
1009        #[tokio::test]
1010        async fn test_logs_error_when_client_version_cannot_be_computed() {
1011            let (logger, log_inspector) = TestLogger::memory();
1012            let mut client = setup_client("http://whatever", vec![], None);
1013            client.logger = logger;
1014            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");
1015
1016            client.warn_if_api_version_mismatch(&response).await;
1017
1018            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1019        }
1020
1021        #[tokio::test]
1022        async fn test_client_get_log_warning_if_api_version_mismatch() {
1023            let aggregator_version = "2.0.0";
1024            let client_version = "1.0.0";
1025            let (server, mut client) = setup_server_and_client();
1026            let (logger, log_inspector) = TestLogger::memory();
1027            client.api_versions =
1028                Arc::new(RwLock::new(vec![Version::parse(client_version).unwrap()]));
1029            client.logger = logger;
1030            server.mock(|_, then| {
1031                then.status(StatusCode::OK.as_u16())
1032                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1033            });
1034
1035            assert!(
1036                Version::parse(aggregator_version).unwrap()
1037                    > Version::parse(client_version).unwrap()
1038            );
1039
1040            client.get(Url::parse(&server.base_url()).unwrap()).await.unwrap();
1041
1042            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
1043        }
1044
1045        #[tokio::test]
1046        async fn test_client_post_log_warning_if_api_version_mismatch() {
1047            let aggregator_version = "2.0.0";
1048            let client_version = "1.0.0";
1049            let (server, mut client) = setup_server_and_client();
1050            let (logger, log_inspector) = TestLogger::memory();
1051            client.api_versions =
1052                Arc::new(RwLock::new(vec![Version::parse(client_version).unwrap()]));
1053            client.logger = logger;
1054            server.mock(|_, then| {
1055                then.status(StatusCode::OK.as_u16())
1056                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1057            });
1058
1059            assert!(
1060                Version::parse(aggregator_version).unwrap()
1061                    > Version::parse(client_version).unwrap()
1062            );
1063
1064            client
1065                .post(Url::parse(&server.base_url()).unwrap(), "whatever")
1066                .await
1067                .unwrap();
1068
1069            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
1070        }
1071    }
1072}