mithril_client/
aggregator_client.rs

1//! Mechanisms to exchange data with an Aggregator.
2//!
3//! The [AggregatorClient] trait abstracts how the communication with an Aggregator
4//! is done.
5//! The clients that need to communicate only need to define their request using the
6//! [AggregatorRequest] enum.
7//!
8//! An implementation using HTTP is available: [AggregatorHTTPClient].
9
10use anyhow::{anyhow, Context};
11use async_recursion::async_recursion;
12use async_trait::async_trait;
13use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
14use reqwest::{Response, StatusCode, Url};
15use semver::Version;
16use slog::{debug, Logger};
17use std::collections::HashMap;
18use std::sync::Arc;
19use thiserror::Error;
20use tokio::sync::RwLock;
21
22use mithril_common::entities::{ClientError, ServerError};
23use mithril_common::logging::LoggerExtensions;
24use mithril_common::MITHRIL_API_VERSION_HEADER;
25
26use crate::common::Epoch;
27use crate::{MithrilError, MithrilResult};
28
29/// Error tied with the Aggregator client
30#[derive(Error, Debug)]
31pub enum AggregatorClientError {
32    /// Error raised when querying the aggregator returned a 5XX error.
33    #[error("Internal error of the Aggregator")]
34    RemoteServerTechnical(#[source] MithrilError),
35
36    /// Error raised when querying the aggregator returned a 4XX error.
37    #[error("Invalid request to the Aggregator")]
38    RemoteServerLogical(#[source] MithrilError),
39
40    /// Error raised when the server API version mismatch the client API version.
41    #[error("API version mismatch")]
42    ApiVersionMismatch(#[source] MithrilError),
43
44    /// HTTP subsystem error
45    #[error("HTTP subsystem error")]
46    SubsystemError(#[source] MithrilError),
47}
48
49/// What can be read from an [AggregatorClient].
50#[derive(Debug, Clone, Eq, PartialEq)]
51pub enum AggregatorRequest {
52    /// Get a specific [certificate][crate::MithrilCertificate] from the aggregator
53    GetCertificate {
54        /// Hash of the certificate to retrieve
55        hash: String,
56    },
57
58    /// Lists the aggregator [certificates][crate::MithrilCertificate]
59    ListCertificates,
60
61    /// Get a specific [Mithril stake distribution][crate::MithrilStakeDistribution] from the aggregator
62    GetMithrilStakeDistribution {
63        /// Hash of the Mithril stake distribution to retrieve
64        hash: String,
65    },
66
67    /// Lists the aggregator [Mithril stake distribution][crate::MithrilStakeDistribution]
68    ListMithrilStakeDistributions,
69
70    /// Get a specific [snapshot][crate::Snapshot] from the aggregator
71    GetSnapshot {
72        /// Digest of the snapshot to retrieve
73        digest: String,
74    },
75
76    /// Lists the aggregator [snapshots][crate::Snapshot]
77    ListSnapshots,
78
79    /// Increments the aggregator snapshot download statistics
80    IncrementSnapshotStatistic {
81        /// Snapshot as HTTP request body
82        snapshot: String,
83    },
84
85    /// Get a specific [Cardano database snapshot][crate::CardanoDatabaseSnapshot] from the aggregator
86    #[cfg(feature = "unstable")]
87    GetCardanoDatabaseSnapshot {
88        /// Hash of the snapshot to retrieve
89        hash: String,
90    },
91
92    /// Lists the aggregator [Cardano database snapshots][crate::CardanoDatabaseSnapshot]
93    #[cfg(feature = "unstable")]
94    ListCardanoDatabaseSnapshots,
95
96    /// Increments the aggregator Cardano database snapshot immutable files restored statistics
97    #[cfg(feature = "unstable")]
98    IncrementCardanoDatabaseImmutablesRestoredStatistic {
99        /// Number of immutable files restored
100        number_of_immutables: String,
101    },
102
103    /// Increments the aggregator Cardano database snapshot ancillary files restored statistics
104    #[cfg(feature = "unstable")]
105    IncrementCardanoDatabaseAncillaryStatistic,
106
107    /// Increments the aggregator Cardano database snapshot complete restoration statistics
108    #[cfg(feature = "unstable")]
109    IncrementCardanoDatabaseCompleteRestorationStatistic,
110
111    /// Increments the aggregator Cardano database snapshot partial restoration statistics
112    #[cfg(feature = "unstable")]
113    IncrementCardanoDatabasePartialRestorationStatistic,
114
115    /// Get proofs that the given set of Cardano transactions is included in the global Cardano transactions set
116    GetTransactionsProofs {
117        /// Hashes of the transactions to get proofs for.
118        transactions_hashes: Vec<String>,
119    },
120
121    /// Get a specific [Cardano transaction snapshot][crate::CardanoTransactionSnapshot]
122    GetCardanoTransactionSnapshot {
123        /// Hash of the Cardano transaction snapshot to retrieve
124        hash: String,
125    },
126
127    /// Lists the aggregator [Cardano transaction snapshot][crate::CardanoTransactionSnapshot]
128    ListCardanoTransactionSnapshots,
129
130    /// Get a specific [Cardano stake distribution][crate::CardanoStakeDistribution] from the aggregator by hash
131    GetCardanoStakeDistribution {
132        /// Hash of the Cardano stake distribution to retrieve
133        hash: String,
134    },
135
136    /// Get a specific [Cardano stake distribution][crate::CardanoStakeDistribution] from the aggregator by epoch
137    GetCardanoStakeDistributionByEpoch {
138        /// Epoch at the end of which the Cardano stake distribution is computed by the Cardano node
139        epoch: Epoch,
140    },
141
142    /// Lists the aggregator [Cardano stake distribution][crate::CardanoStakeDistribution]
143    ListCardanoStakeDistributions,
144}
145
146impl AggregatorRequest {
147    /// Get the request route relative to the aggregator root endpoint.
148    pub fn route(&self) -> String {
149        match self {
150            AggregatorRequest::GetCertificate { hash } => {
151                format!("certificate/{hash}")
152            }
153            AggregatorRequest::ListCertificates => "certificates".to_string(),
154            AggregatorRequest::GetMithrilStakeDistribution { hash } => {
155                format!("artifact/mithril-stake-distribution/{hash}")
156            }
157            AggregatorRequest::ListMithrilStakeDistributions => {
158                "artifact/mithril-stake-distributions".to_string()
159            }
160            AggregatorRequest::GetSnapshot { digest } => {
161                format!("artifact/snapshot/{}", digest)
162            }
163            AggregatorRequest::ListSnapshots => "artifact/snapshots".to_string(),
164            AggregatorRequest::IncrementSnapshotStatistic { snapshot: _ } => {
165                "statistics/snapshot".to_string()
166            }
167            #[cfg(feature = "unstable")]
168            AggregatorRequest::GetCardanoDatabaseSnapshot { hash } => {
169                format!("artifact/cardano-database/{}", hash)
170            }
171            #[cfg(feature = "unstable")]
172            AggregatorRequest::ListCardanoDatabaseSnapshots => {
173                "artifact/cardano-database".to_string()
174            }
175            #[cfg(feature = "unstable")]
176            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
177                number_of_immutables: _,
178            } => "statistics/cardano-database/immutable-files-restored".to_string(),
179            #[cfg(feature = "unstable")]
180            AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic => {
181                "statistics/cardano-database/ancillary-files-restored".to_string()
182            }
183            #[cfg(feature = "unstable")]
184            AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic => {
185                "statistics/cardano-database/complete-restoration".to_string()
186            }
187            #[cfg(feature = "unstable")]
188            AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic => {
189                "statistics/cardano-database/partial-restoration".to_string()
190            }
191            AggregatorRequest::GetTransactionsProofs {
192                transactions_hashes,
193            } => format!(
194                "proof/cardano-transaction?transaction_hashes={}",
195                transactions_hashes.join(",")
196            ),
197            AggregatorRequest::GetCardanoTransactionSnapshot { hash } => {
198                format!("artifact/cardano-transaction/{hash}")
199            }
200            AggregatorRequest::ListCardanoTransactionSnapshots => {
201                "artifact/cardano-transactions".to_string()
202            }
203            AggregatorRequest::GetCardanoStakeDistribution { hash } => {
204                format!("artifact/cardano-stake-distribution/{hash}")
205            }
206            AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch } => {
207                format!("artifact/cardano-stake-distribution/epoch/{epoch}")
208            }
209            AggregatorRequest::ListCardanoStakeDistributions => {
210                "artifact/cardano-stake-distributions".to_string()
211            }
212        }
213    }
214
215    /// Get the request body to send to the aggregator
216    pub fn get_body(&self) -> Option<String> {
217        match self {
218            AggregatorRequest::IncrementSnapshotStatistic { snapshot } => {
219                Some(snapshot.to_string())
220            }
221            _ => None,
222        }
223    }
224}
225
226/// API that defines a client for the Aggregator
227#[cfg_attr(test, mockall::automock)]
228#[cfg_attr(target_family = "wasm", async_trait(?Send))]
229#[cfg_attr(not(target_family = "wasm"), async_trait)]
230pub trait AggregatorClient: Sync + Send {
231    /// Get the content back from the Aggregator
232    async fn get_content(
233        &self,
234        request: AggregatorRequest,
235    ) -> Result<String, AggregatorClientError>;
236
237    /// Post information to the Aggregator
238    async fn post_content(
239        &self,
240        request: AggregatorRequest,
241    ) -> Result<String, AggregatorClientError>;
242}
243
244/// Responsible for HTTP transport and API version check.
245pub struct AggregatorHTTPClient {
246    http_client: reqwest::Client,
247    aggregator_endpoint: Url,
248    api_versions: Arc<RwLock<Vec<Version>>>,
249    logger: Logger,
250    http_headers: HeaderMap,
251}
252
253impl AggregatorHTTPClient {
254    /// Constructs a new `AggregatorHTTPClient`
255    pub fn new(
256        aggregator_endpoint: Url,
257        api_versions: Vec<Version>,
258        logger: Logger,
259        custom_headers: Option<HashMap<String, String>>,
260    ) -> MithrilResult<Self> {
261        let http_client = reqwest::ClientBuilder::new()
262            .build()
263            .with_context(|| "Building http client for Aggregator client failed")?;
264
265        // Trailing slash is significant because url::join
266        // (https://docs.rs/url/latest/url/struct.Url.html#method.join) will remove
267        // the 'path' part of the url if it doesn't end with a trailing slash.
268        let aggregator_endpoint = if aggregator_endpoint.as_str().ends_with('/') {
269            aggregator_endpoint
270        } else {
271            let mut url = aggregator_endpoint.clone();
272            url.set_path(&format!("{}/", aggregator_endpoint.path()));
273            url
274        };
275
276        let mut http_headers = HeaderMap::new();
277        if let Some(headers) = custom_headers {
278            for (key, value) in headers.iter() {
279                http_headers.insert(
280                    HeaderName::from_bytes(key.as_bytes())?,
281                    HeaderValue::from_str(value)?,
282                );
283            }
284        }
285
286        Ok(Self {
287            http_client,
288            aggregator_endpoint,
289            api_versions: Arc::new(RwLock::new(api_versions)),
290            logger: logger.new_with_component_name::<Self>(),
291            http_headers,
292        })
293    }
294
295    /// Computes the current api version
296    async fn compute_current_api_version(&self) -> Option<Version> {
297        self.api_versions.read().await.first().cloned()
298    }
299
300    /// Discards the current api version
301    /// It discards the current version if and only if there is at least 2 versions available
302    async fn discard_current_api_version(&self) -> Option<Version> {
303        if self.api_versions.read().await.len() < 2 {
304            return None;
305        }
306        if let Some(current_api_version) = self.compute_current_api_version().await {
307            let mut api_versions = self.api_versions.write().await;
308            if let Some(index) = api_versions
309                .iter()
310                .position(|value| *value == current_api_version)
311            {
312                api_versions.remove(index);
313                return Some(current_api_version);
314            }
315        }
316        None
317    }
318
319    /// Perform a HTTP GET request on the Aggregator and return the given JSON
320    #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
321    #[cfg_attr(not(target_family = "wasm"), async_recursion)]
322    async fn get(&self, url: Url) -> Result<Response, AggregatorClientError> {
323        debug!(self.logger, "GET url='{url}'.");
324        let request_builder = self.http_client.get(url.clone());
325        let current_api_version = self
326            .compute_current_api_version()
327            .await
328            .unwrap()
329            .to_string();
330        debug!(
331            self.logger,
332            "Prepare request with version: {current_api_version}"
333        );
334        let request_builder = request_builder
335            .header(MITHRIL_API_VERSION_HEADER, current_api_version)
336            .headers(self.http_headers.clone());
337
338        let response = request_builder.send().await.map_err(|e| {
339            AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
340                "Cannot perform a GET against the Aggregator HTTP server (url='{url}')"
341            )))
342        })?;
343
344        match response.status() {
345            StatusCode::OK => Ok(response),
346            StatusCode::PRECONDITION_FAILED => {
347                if self.discard_current_api_version().await.is_some()
348                    && !self.api_versions.read().await.is_empty()
349                {
350                    return self.get(url).await;
351                }
352
353                Err(self.handle_api_error(response.headers()).await)
354            }
355            StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
356            status_code if status_code.is_client_error() => {
357                Err(Self::remote_logical_error(response).await)
358            }
359            _ => Err(Self::remote_technical_error(response).await),
360        }
361    }
362
363    #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
364    #[cfg_attr(not(target_family = "wasm"), async_recursion)]
365    async fn post(&self, url: Url, json: &str) -> Result<Response, AggregatorClientError> {
366        debug!(self.logger, "POST url='{url}'"; "json" => json);
367        let request_builder = self.http_client.post(url.to_owned()).body(json.to_owned());
368        let current_api_version = self
369            .compute_current_api_version()
370            .await
371            .unwrap()
372            .to_string();
373        debug!(
374            self.logger,
375            "Prepare request with version: {current_api_version}"
376        );
377        let request_builder = request_builder
378            .header(MITHRIL_API_VERSION_HEADER, current_api_version)
379            .headers(self.http_headers.clone());
380
381        let response = request_builder.send().await.map_err(|e| {
382            AggregatorClientError::SubsystemError(
383                anyhow!(e).context("Error while POSTing data '{json}' to URL='{url}'."),
384            )
385        })?;
386
387        match response.status() {
388            StatusCode::OK | StatusCode::CREATED => Ok(response),
389            StatusCode::PRECONDITION_FAILED => {
390                if self.discard_current_api_version().await.is_some()
391                    && !self.api_versions.read().await.is_empty()
392                {
393                    return self.post(url, json).await;
394                }
395
396                Err(self.handle_api_error(response.headers()).await)
397            }
398            StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
399            status_code if status_code.is_client_error() => {
400                Err(Self::remote_logical_error(response).await)
401            }
402            _ => Err(Self::remote_technical_error(response).await),
403        }
404    }
405
406    fn get_url_for_route(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
407        self.aggregator_endpoint
408            .join(endpoint)
409            .with_context(|| {
410                format!(
411                    "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
412                    self.aggregator_endpoint
413                )
414            })
415            .map_err(AggregatorClientError::SubsystemError)
416    }
417
418    /// API version error handling
419    async fn handle_api_error(&self, response_header: &HeaderMap) -> AggregatorClientError {
420        if let Some(version) = response_header.get(MITHRIL_API_VERSION_HEADER) {
421            AggregatorClientError::ApiVersionMismatch(anyhow!(
422                "server version: '{}', signer version: '{}'",
423                version.to_str().unwrap(),
424                self.compute_current_api_version().await.unwrap()
425            ))
426        } else {
427            AggregatorClientError::ApiVersionMismatch(anyhow!(
428                "version precondition failed, sent version '{}'.",
429                self.compute_current_api_version().await.unwrap()
430            ))
431        }
432    }
433
434    fn not_found_error(url: Url) -> AggregatorClientError {
435        AggregatorClientError::RemoteServerLogical(anyhow!("Url='{url}' not found"))
436    }
437
438    async fn remote_logical_error(response: Response) -> AggregatorClientError {
439        let status_code = response.status();
440        let client_error = response
441            .json::<ClientError>()
442            .await
443            .unwrap_or(ClientError::new(
444                format!("Unhandled error {status_code}"),
445                "",
446            ));
447
448        AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"))
449    }
450
451    async fn remote_technical_error(response: Response) -> AggregatorClientError {
452        let status_code = response.status();
453        let server_error = response
454            .json::<ServerError>()
455            .await
456            .unwrap_or(ServerError::new(format!("Unhandled error {status_code}")));
457
458        AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"))
459    }
460}
461
462#[cfg_attr(target_family = "wasm", async_trait(?Send))]
463#[cfg_attr(not(target_family = "wasm"), async_trait)]
464impl AggregatorClient for AggregatorHTTPClient {
465    async fn get_content(
466        &self,
467        request: AggregatorRequest,
468    ) -> Result<String, AggregatorClientError> {
469        let response = self.get(self.get_url_for_route(&request.route())?).await?;
470        let content = format!("{response:?}");
471
472        response.text().await.map_err(|e| {
473            AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
474                "Could not find a JSON body in the response '{content}'."
475            )))
476        })
477    }
478
479    async fn post_content(
480        &self,
481        request: AggregatorRequest,
482    ) -> Result<String, AggregatorClientError> {
483        let response = self
484            .post(
485                self.get_url_for_route(&request.route())?,
486                &request.get_body().unwrap_or_default(),
487            )
488            .await?;
489
490        response.text().await.map_err(|e| {
491            AggregatorClientError::SubsystemError(
492                anyhow!(e).context("Could not find a text body in the response."),
493            )
494        })
495    }
496}
497
498#[cfg(test)]
499mod tests {
500    use httpmock::MockServer;
501    use reqwest::header::{HeaderName, HeaderValue};
502    use std::collections::HashMap;
503
504    use mithril_common::api_version::APIVersionProvider;
505    use mithril_common::entities::{ClientError, ServerError};
506
507    use crate::test_utils::TestLogger;
508
509    use super::*;
510
511    macro_rules! assert_error_eq {
512        ($left:expr, $right:expr) => {
513            assert_eq!(format!("{:?}", &$left), format!("{:?}", &$right),);
514        };
515    }
516
517    fn setup_client(
518        server_url: &str,
519        api_versions: Vec<Version>,
520        custom_headers: Option<HashMap<String, String>>,
521    ) -> AggregatorHTTPClient {
522        AggregatorHTTPClient::new(
523            Url::parse(server_url).unwrap(),
524            api_versions,
525            TestLogger::stdout(),
526            custom_headers,
527        )
528        .expect("building aggregator http client should not fail")
529    }
530
531    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
532        let server = MockServer::start();
533        let client = setup_client(
534            &server.url(""),
535            APIVersionProvider::compute_all_versions_sorted().unwrap(),
536            None,
537        );
538        (server, client)
539    }
540
541    fn setup_server_and_client_with_custom_headers(
542        custom_headers: HashMap<String, String>,
543    ) -> (MockServer, AggregatorHTTPClient) {
544        let server = MockServer::start();
545        let client = setup_client(
546            &server.url(""),
547            APIVersionProvider::compute_all_versions_sorted().unwrap(),
548            Some(custom_headers),
549        );
550        (server, client)
551    }
552
553    fn mithril_api_version_headers(version: &str) -> HeaderMap {
554        HeaderMap::from_iter([(
555            HeaderName::from_static(MITHRIL_API_VERSION_HEADER),
556            HeaderValue::from_str(version).unwrap(),
557        )])
558    }
559
560    #[test]
561    fn always_append_trailing_slash_at_build() {
562        for (expected, url) in [
563            ("http://www.test.net/", "http://www.test.net/"),
564            ("http://www.test.net/", "http://www.test.net"),
565            (
566                "http://www.test.net/aggregator/",
567                "http://www.test.net/aggregator/",
568            ),
569            (
570                "http://www.test.net/aggregator/",
571                "http://www.test.net/aggregator",
572            ),
573        ] {
574            let url = Url::parse(url).unwrap();
575            let client = AggregatorHTTPClient::new(url, vec![], TestLogger::stdout(), None)
576                .expect("building aggregator http client should not fail");
577
578            assert_eq!(expected, client.aggregator_endpoint.as_str());
579        }
580    }
581
582    #[test]
583    fn deduce_routes_from_request() {
584        assert_eq!(
585            "certificate/abc".to_string(),
586            AggregatorRequest::GetCertificate {
587                hash: "abc".to_string()
588            }
589            .route()
590        );
591
592        assert_eq!(
593            "artifact/mithril-stake-distribution/abc".to_string(),
594            AggregatorRequest::GetMithrilStakeDistribution {
595                hash: "abc".to_string()
596            }
597            .route()
598        );
599
600        assert_eq!(
601            "artifact/mithril-stake-distribution/abc".to_string(),
602            AggregatorRequest::GetMithrilStakeDistribution {
603                hash: "abc".to_string()
604            }
605            .route()
606        );
607
608        assert_eq!(
609            "artifact/mithril-stake-distributions".to_string(),
610            AggregatorRequest::ListMithrilStakeDistributions.route()
611        );
612
613        assert_eq!(
614            "artifact/snapshot/abc".to_string(),
615            AggregatorRequest::GetSnapshot {
616                digest: "abc".to_string()
617            }
618            .route()
619        );
620
621        assert_eq!(
622            "artifact/snapshots".to_string(),
623            AggregatorRequest::ListSnapshots.route()
624        );
625
626        assert_eq!(
627            "statistics/snapshot".to_string(),
628            AggregatorRequest::IncrementSnapshotStatistic {
629                snapshot: "abc".to_string()
630            }
631            .route()
632        );
633
634        #[cfg(feature = "unstable")]
635        assert_eq!(
636            "artifact/cardano-database/abc".to_string(),
637            AggregatorRequest::GetCardanoDatabaseSnapshot {
638                hash: "abc".to_string()
639            }
640            .route()
641        );
642
643        #[cfg(feature = "unstable")]
644        assert_eq!(
645            "artifact/cardano-database".to_string(),
646            AggregatorRequest::ListCardanoDatabaseSnapshots.route()
647        );
648
649        #[cfg(feature = "unstable")]
650        assert_eq!(
651            "statistics/cardano-database/immutable-files-restored".to_string(),
652            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
653                number_of_immutables: "abc".to_string()
654            }
655            .route()
656        );
657
658        #[cfg(feature = "unstable")]
659        assert_eq!(
660            "statistics/cardano-database/ancillary-files-restored".to_string(),
661            AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic.route()
662        );
663
664        #[cfg(feature = "unstable")]
665        assert_eq!(
666            "statistics/cardano-database/complete-restoration".to_string(),
667            AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic.route()
668        );
669
670        #[cfg(feature = "unstable")]
671        assert_eq!(
672            "statistics/cardano-database/partial-restoration".to_string(),
673            AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic.route()
674        );
675
676        assert_eq!(
677            "proof/cardano-transaction?transaction_hashes=abc,def,ghi,jkl".to_string(),
678            AggregatorRequest::GetTransactionsProofs {
679                transactions_hashes: vec![
680                    "abc".to_string(),
681                    "def".to_string(),
682                    "ghi".to_string(),
683                    "jkl".to_string()
684                ]
685            }
686            .route()
687        );
688
689        assert_eq!(
690            "artifact/cardano-transaction/abc".to_string(),
691            AggregatorRequest::GetCardanoTransactionSnapshot {
692                hash: "abc".to_string()
693            }
694            .route()
695        );
696
697        assert_eq!(
698            "artifact/cardano-transactions".to_string(),
699            AggregatorRequest::ListCardanoTransactionSnapshots.route()
700        );
701
702        assert_eq!(
703            "artifact/cardano-stake-distribution/abc".to_string(),
704            AggregatorRequest::GetCardanoStakeDistribution {
705                hash: "abc".to_string()
706            }
707            .route()
708        );
709
710        assert_eq!(
711            "artifact/cardano-stake-distribution/epoch/123".to_string(),
712            AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch: Epoch(123) }.route()
713        );
714
715        assert_eq!(
716            "artifact/cardano-stake-distributions".to_string(),
717            AggregatorRequest::ListCardanoStakeDistributions.route()
718        );
719    }
720
721    #[tokio::test]
722    async fn test_client_handle_4xx_errors() {
723        let client_error = ClientError::new("label", "message");
724
725        let (aggregator, client) = setup_server_and_client();
726        aggregator.mock(|_when, then| {
727            then.status(StatusCode::IM_A_TEAPOT.as_u16())
728                .json_body_obj(&client_error);
729        });
730
731        let expected_error = AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"));
732
733        let get_content_error = client
734            .get_content(AggregatorRequest::ListCertificates)
735            .await
736            .unwrap_err();
737        assert_error_eq!(get_content_error, expected_error);
738
739        let post_content_error = client
740            .post_content(AggregatorRequest::ListCertificates)
741            .await
742            .unwrap_err();
743        assert_error_eq!(post_content_error, expected_error);
744    }
745
746    #[tokio::test]
747    async fn test_client_handle_404_not_found_error() {
748        let client_error = ClientError::new("label", "message");
749
750        let (aggregator, client) = setup_server_and_client();
751        aggregator.mock(|_when, then| {
752            then.status(StatusCode::NOT_FOUND.as_u16())
753                .json_body_obj(&client_error);
754        });
755
756        let expected_error = AggregatorHTTPClient::not_found_error(
757            Url::parse(&format!(
758                "{}/{}",
759                aggregator.base_url(),
760                AggregatorRequest::ListCertificates.route()
761            ))
762            .unwrap(),
763        );
764
765        let get_content_error = client
766            .get_content(AggregatorRequest::ListCertificates)
767            .await
768            .unwrap_err();
769        assert_error_eq!(get_content_error, expected_error);
770
771        let post_content_error = client
772            .post_content(AggregatorRequest::ListCertificates)
773            .await
774            .unwrap_err();
775        assert_error_eq!(post_content_error, expected_error);
776    }
777
778    #[tokio::test]
779    async fn test_client_handle_5xx_errors() {
780        let server_error = ServerError::new("message");
781
782        let (aggregator, client) = setup_server_and_client();
783        aggregator.mock(|_when, then| {
784            then.status(StatusCode::INTERNAL_SERVER_ERROR.as_u16())
785                .json_body_obj(&server_error);
786        });
787
788        let expected_error =
789            AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"));
790
791        let get_content_error = client
792            .get_content(AggregatorRequest::ListCertificates)
793            .await
794            .unwrap_err();
795        assert_error_eq!(get_content_error, expected_error);
796
797        let post_content_error = client
798            .post_content(AggregatorRequest::ListCertificates)
799            .await
800            .unwrap_err();
801        assert_error_eq!(post_content_error, expected_error);
802    }
803
804    #[tokio::test]
805    async fn test_client_handle_412_api_version_mismatch_with_version_in_response_header() {
806        let version = "0.0.0";
807
808        let (aggregator, client) = setup_server_and_client();
809        aggregator.mock(|_when, then| {
810            then.status(StatusCode::PRECONDITION_FAILED.as_u16())
811                .header(MITHRIL_API_VERSION_HEADER, version);
812        });
813
814        let expected_error = client
815            .handle_api_error(&mithril_api_version_headers(version))
816            .await;
817
818        let get_content_error = client
819            .get_content(AggregatorRequest::ListCertificates)
820            .await
821            .unwrap_err();
822        assert_error_eq!(get_content_error, expected_error);
823
824        let post_content_error = client
825            .post_content(AggregatorRequest::ListCertificates)
826            .await
827            .unwrap_err();
828        assert_error_eq!(post_content_error, expected_error);
829    }
830
831    #[tokio::test]
832    async fn test_client_handle_412_api_version_mismatch_without_version_in_response_header() {
833        let (aggregator, client) = setup_server_and_client();
834        aggregator.mock(|_when, then| {
835            then.status(StatusCode::PRECONDITION_FAILED.as_u16());
836        });
837
838        let expected_error = client.handle_api_error(&HeaderMap::new()).await;
839
840        let get_content_error = client
841            .get_content(AggregatorRequest::ListCertificates)
842            .await
843            .unwrap_err();
844        assert_error_eq!(get_content_error, expected_error);
845
846        let post_content_error = client
847            .post_content(AggregatorRequest::ListCertificates)
848            .await
849            .unwrap_err();
850        assert_error_eq!(post_content_error, expected_error);
851    }
852
853    #[tokio::test]
854    async fn test_client_can_fallback_to_a_second_version_when_412_api_version_mismatch() {
855        let bad_version = "0.0.0";
856        let good_version = "1.0.0";
857
858        let aggregator = MockServer::start();
859        let client = setup_client(
860            &aggregator.url(""),
861            vec![
862                Version::parse(bad_version).unwrap(),
863                Version::parse(good_version).unwrap(),
864            ],
865            None,
866        );
867        aggregator.mock(|when, then| {
868            when.header(MITHRIL_API_VERSION_HEADER, bad_version);
869            then.status(StatusCode::PRECONDITION_FAILED.as_u16())
870                .header(MITHRIL_API_VERSION_HEADER, bad_version);
871        });
872        aggregator.mock(|when, then| {
873            when.header(MITHRIL_API_VERSION_HEADER, good_version);
874            then.status(StatusCode::OK.as_u16());
875        });
876
877        assert_eq!(
878            client.compute_current_api_version().await,
879            Some(Version::parse(bad_version).unwrap()),
880            "Bad version should be tried first"
881        );
882
883        client
884            .get_content(AggregatorRequest::ListCertificates)
885            .await
886            .expect("should have run with a fallback version");
887
888        client
889            .post_content(AggregatorRequest::ListCertificates)
890            .await
891            .expect("should have run with a fallback version");
892    }
893
894    #[tokio::test]
895    async fn test_client_with_custom_headers() {
896        let mut http_headers = HashMap::new();
897        http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
898        http_headers.insert("Another-Header".to_string(), "AnotherValue".to_string());
899        let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
900        aggregator.mock(|when, then| {
901            when.header("Custom-Header", "CustomValue")
902                .header("Another-Header", "AnotherValue");
903            then.status(StatusCode::OK.as_u16()).body("ok");
904        });
905
906        client
907            .get_content(AggregatorRequest::ListCertificates)
908            .await
909            .expect("GET request should succeed");
910
911        client
912            .post_content(AggregatorRequest::ListCertificates)
913            .await
914            .expect("GET request should succeed");
915    }
916
917    #[tokio::test]
918    async fn test_client_sends_accept_encoding_header_with_correct_values() {
919        let (aggregator, client) = setup_server_and_client();
920        aggregator.mock(|when, then| {
921            when.matches(|req| {
922                let headers = req.headers.clone().expect("HTTP headers not found");
923                let accept_encoding_header = headers
924                    .iter()
925                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
926                    .expect("Accept-Encoding header not found");
927
928                let header_value = accept_encoding_header.clone().1;
929                ["gzip", "br", "deflate", "zstd"]
930                    .iter()
931                    .all(|&value| header_value.contains(value))
932            });
933
934            then.status(200).body("ok");
935        });
936
937        client
938            .get_content(AggregatorRequest::ListCertificates)
939            .await
940            .expect("GET request should succeed with Accept-Encoding header");
941    }
942
943    #[tokio::test]
944    async fn test_client_with_custom_headers_sends_accept_encoding_header_with_correct_values() {
945        let mut http_headers = HashMap::new();
946        http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
947        let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
948        aggregator.mock(|when, then| {
949            when.matches(|req| {
950                let headers = req.headers.clone().expect("HTTP headers not found");
951                let accept_encoding_header = headers
952                    .iter()
953                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
954                    .expect("Accept-Encoding header not found");
955
956                let header_value = accept_encoding_header.clone().1;
957                ["gzip", "br", "deflate", "zstd"]
958                    .iter()
959                    .all(|&value| header_value.contains(value))
960            });
961
962            then.status(200).body("ok");
963        });
964
965        client
966            .get_content(AggregatorRequest::ListCertificates)
967            .await
968            .expect("GET request should succeed with Accept-Encoding header");
969    }
970}