mithril_client/
aggregator_client.rs

1//! Mechanisms to exchange data with an Aggregator.
2//!
3//! The [AggregatorClient] trait abstracts how the communication with an Aggregator
4//! is done.
5//! The clients that need to communicate only need to define their request using the
6//! [AggregatorRequest] enum.
7//!
8//! An implementation using HTTP is available: [AggregatorHTTPClient].
9
10use anyhow::{Context, anyhow};
11use async_recursion::async_recursion;
12use async_trait::async_trait;
13use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
14use reqwest::{Response, StatusCode, Url};
15use semver::Version;
16use slog::{Logger, debug, error, warn};
17use std::collections::HashMap;
18use std::sync::Arc;
19use thiserror::Error;
20use tokio::sync::RwLock;
21
22use mithril_common::MITHRIL_API_VERSION_HEADER;
23use mithril_common::entities::{ClientError, ServerError};
24use mithril_common::logging::LoggerExtensions;
25use mithril_common::messages::CardanoDatabaseImmutableFilesRestoredMessage;
26
27use crate::common::Epoch;
28use crate::{MithrilError, MithrilResult};
29
/// Warning logged when the aggregator advertises a more recent OpenAPI version than the one
/// used by this client.
const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = concat!(
    "OpenAPI version may be incompatible, ",
    "please update Mithril client library to the latest version."
);
31
32/// Error tied with the Aggregator client
33#[derive(Error, Debug)]
34pub enum AggregatorClientError {
35    /// Error raised when querying the aggregator returned a 5XX error.
36    #[error("Internal error of the Aggregator")]
37    RemoteServerTechnical(#[source] MithrilError),
38
39    /// Error raised when querying the aggregator returned a 4XX error.
40    #[error("Invalid request to the Aggregator")]
41    RemoteServerLogical(#[source] MithrilError),
42
43    /// HTTP subsystem error
44    #[error("HTTP subsystem error")]
45    SubsystemError(#[source] MithrilError),
46}
47
48/// What can be read from an [AggregatorClient].
49#[derive(Debug, Clone, Eq, PartialEq)]
50#[cfg_attr(test, derive(strum::EnumIter))]
51pub enum AggregatorRequest {
52    /// Get a specific [certificate][crate::MithrilCertificate] from the aggregator
53    GetCertificate {
54        /// Hash of the certificate to retrieve
55        hash: String,
56    },
57
58    /// Lists the aggregator [certificates][crate::MithrilCertificate]
59    ListCertificates,
60
61    /// Get a specific [Mithril stake distribution][crate::MithrilStakeDistribution] from the aggregator
62    GetMithrilStakeDistribution {
63        /// Hash of the Mithril stake distribution to retrieve
64        hash: String,
65    },
66
67    /// Lists the aggregator [Mithril stake distribution][crate::MithrilStakeDistribution]
68    ListMithrilStakeDistributions,
69
70    /// Get a specific [snapshot][crate::Snapshot] from the aggregator
71    GetSnapshot {
72        /// Digest of the snapshot to retrieve
73        digest: String,
74    },
75
76    /// Lists the aggregator [snapshots][crate::Snapshot]
77    ListSnapshots,
78
79    /// Increments the aggregator snapshot download statistics
80    IncrementSnapshotStatistic {
81        /// Snapshot as HTTP request body
82        snapshot: String,
83    },
84
85    /// Get a specific [Cardano database snapshot][crate::CardanoDatabaseSnapshot] from the aggregator
86    GetCardanoDatabaseSnapshot {
87        /// Hash of the snapshot to retrieve
88        hash: String,
89    },
90
91    /// Lists the aggregator [Cardano database snapshots][crate::CardanoDatabaseSnapshot]
92    ListCardanoDatabaseSnapshots,
93
94    /// Lists the aggregator [Cardano database snapshots][crate::CardanoDatabaseSnapshot] for an epoch
95    ListCardanoDatabaseSnapshotByEpoch {
96        /// Epoch of the Cardano Database snapshots
97        epoch: Epoch,
98    },
99
100    /// Lists the aggregator [Cardano database snapshots][crate::CardanoDatabaseSnapshot] for the latest epoch
101    ListCardanoDatabaseSnapshotForLatestEpoch {
102        /// Optional offset to subtract to the latest epoch
103        offset: Option<u64>,
104    },
105
106    /// Increments the aggregator Cardano database snapshot immutable files restored statistics
107    IncrementCardanoDatabaseImmutablesRestoredStatistic {
108        /// Number of immutable files restored
109        number_of_immutables: u64,
110    },
111
112    /// Increments the aggregator Cardano database snapshot ancillary files restored statistics
113    IncrementCardanoDatabaseAncillaryStatistic,
114
115    /// Increments the aggregator Cardano database snapshot complete restoration statistics
116    IncrementCardanoDatabaseCompleteRestorationStatistic,
117
118    /// Increments the aggregator Cardano database snapshot partial restoration statistics
119    IncrementCardanoDatabasePartialRestorationStatistic,
120
121    /// Get proofs that the given set of Cardano transactions is included in the global Cardano transactions set
122    GetTransactionsProofs {
123        /// Hashes of the transactions to get proofs for.
124        transactions_hashes: Vec<String>,
125    },
126
127    /// Get a specific [Cardano transaction snapshot][crate::CardanoTransactionSnapshot]
128    GetCardanoTransactionSnapshot {
129        /// Hash of the Cardano transaction snapshot to retrieve
130        hash: String,
131    },
132
133    /// Lists the aggregator [Cardano transaction snapshot][crate::CardanoTransactionSnapshot]
134    ListCardanoTransactionSnapshots,
135
136    /// Get a specific [Cardano stake distribution][crate::CardanoStakeDistribution] from the aggregator by hash
137    GetCardanoStakeDistribution {
138        /// Hash of the Cardano stake distribution to retrieve
139        hash: String,
140    },
141
142    /// Get a specific [Cardano stake distribution][crate::CardanoStakeDistribution] from the aggregator by epoch
143    GetCardanoStakeDistributionByEpoch {
144        /// Epoch at the end of which the Cardano stake distribution is computed by the Cardano node
145        epoch: Epoch,
146    },
147
148    /// Get a specific [Cardano stake distribution][crate::CardanoStakeDistribution] from the aggregator for the latest epoch
149    GetCardanoStakeDistributionForLatestEpoch {
150        /// Optional offset to subtract to the latest epoch
151        offset: Option<u64>,
152    },
153
154    /// Lists the aggregator [Cardano stake distribution][crate::CardanoStakeDistribution]
155    ListCardanoStakeDistributions,
156
157    /// Get information about the aggregator status
158    Status,
159}
160
161impl AggregatorRequest {
162    /// Get the request route relative to the aggregator root endpoint.
163    pub fn route(&self) -> String {
164        match self {
165            AggregatorRequest::GetCertificate { hash } => {
166                format!("certificate/{hash}")
167            }
168            AggregatorRequest::ListCertificates => "certificates".to_string(),
169            AggregatorRequest::GetMithrilStakeDistribution { hash } => {
170                format!("artifact/mithril-stake-distribution/{hash}")
171            }
172            AggregatorRequest::ListMithrilStakeDistributions => {
173                "artifact/mithril-stake-distributions".to_string()
174            }
175            AggregatorRequest::GetSnapshot { digest } => {
176                format!("artifact/snapshot/{digest}")
177            }
178            AggregatorRequest::ListSnapshots => "artifact/snapshots".to_string(),
179            AggregatorRequest::IncrementSnapshotStatistic { snapshot: _ } => {
180                "statistics/snapshot".to_string()
181            }
182            AggregatorRequest::GetCardanoDatabaseSnapshot { hash } => {
183                format!("artifact/cardano-database/{hash}")
184            }
185            AggregatorRequest::ListCardanoDatabaseSnapshots => {
186                "artifact/cardano-database".to_string()
187            }
188            AggregatorRequest::ListCardanoDatabaseSnapshotByEpoch { epoch } => {
189                format!("artifact/cardano-database/epoch/{epoch}")
190            }
191            AggregatorRequest::ListCardanoDatabaseSnapshotForLatestEpoch { offset } => {
192                format!(
193                    "artifact/cardano-database/epoch/latest{}",
194                    offset.map(|o| format!("-{o}")).unwrap_or_default()
195                )
196            }
197            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
198                number_of_immutables: _,
199            } => "statistics/cardano-database/immutable-files-restored".to_string(),
200            AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic => {
201                "statistics/cardano-database/ancillary-files-restored".to_string()
202            }
203            AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic => {
204                "statistics/cardano-database/complete-restoration".to_string()
205            }
206            AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic => {
207                "statistics/cardano-database/partial-restoration".to_string()
208            }
209            AggregatorRequest::GetTransactionsProofs {
210                transactions_hashes,
211            } => format!(
212                "proof/cardano-transaction?transaction_hashes={}",
213                transactions_hashes.join(",")
214            ),
215            AggregatorRequest::GetCardanoTransactionSnapshot { hash } => {
216                format!("artifact/cardano-transaction/{hash}")
217            }
218            AggregatorRequest::ListCardanoTransactionSnapshots => {
219                "artifact/cardano-transactions".to_string()
220            }
221            AggregatorRequest::GetCardanoStakeDistribution { hash } => {
222                format!("artifact/cardano-stake-distribution/{hash}")
223            }
224            AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch } => {
225                format!("artifact/cardano-stake-distribution/epoch/{epoch}")
226            }
227            AggregatorRequest::GetCardanoStakeDistributionForLatestEpoch { offset } => {
228                format!(
229                    "artifact/cardano-stake-distribution/epoch/latest{}",
230                    offset.map(|o| format!("-{o}")).unwrap_or_default()
231                )
232            }
233            AggregatorRequest::ListCardanoStakeDistributions => {
234                "artifact/cardano-stake-distributions".to_string()
235            }
236            AggregatorRequest::Status => "status".to_string(),
237        }
238    }
239
240    /// Get the request body to send to the aggregator
241    pub fn get_body(&self) -> Option<String> {
242        match self {
243            AggregatorRequest::IncrementSnapshotStatistic { snapshot } => {
244                Some(snapshot.to_string())
245            }
246            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
247                number_of_immutables,
248            } => serde_json::to_string(&CardanoDatabaseImmutableFilesRestoredMessage {
249                nb_immutable_files: *number_of_immutables as u32,
250            })
251            .ok(),
252            _ => None,
253        }
254    }
255}
256
257/// API that defines a client for the Aggregator
258#[cfg_attr(test, mockall::automock)]
259#[cfg_attr(target_family = "wasm", async_trait(?Send))]
260#[cfg_attr(not(target_family = "wasm"), async_trait)]
261pub trait AggregatorClient: Sync + Send {
262    /// Get the content back from the Aggregator
263    async fn get_content(
264        &self,
265        request: AggregatorRequest,
266    ) -> Result<String, AggregatorClientError>;
267
268    /// Post information to the Aggregator
269    async fn post_content(
270        &self,
271        request: AggregatorRequest,
272    ) -> Result<String, AggregatorClientError>;
273}
274
275/// Responsible for HTTP transport and API version check.
276pub struct AggregatorHTTPClient {
277    http_client: reqwest::Client,
278    aggregator_endpoint: Url,
279    api_versions: Arc<RwLock<Vec<Version>>>,
280    logger: Logger,
281    http_headers: HeaderMap,
282}
283
284impl AggregatorHTTPClient {
285    /// Constructs a new `AggregatorHTTPClient`
286    pub fn new(
287        aggregator_endpoint: Url,
288        api_versions: Vec<Version>,
289        logger: Logger,
290        custom_headers: Option<HashMap<String, String>>,
291    ) -> MithrilResult<Self> {
292        let http_client = reqwest::ClientBuilder::new()
293            .build()
294            .with_context(|| "Building http client for Aggregator client failed")?;
295
296        // Trailing slash is significant because url::join
297        // (https://docs.rs/url/latest/url/struct.Url.html#method.join) will remove
298        // the 'path' part of the url if it doesn't end with a trailing slash.
299        let aggregator_endpoint = if aggregator_endpoint.as_str().ends_with('/') {
300            aggregator_endpoint
301        } else {
302            let mut url = aggregator_endpoint.clone();
303            url.set_path(&format!("{}/", aggregator_endpoint.path()));
304            url
305        };
306
307        let mut http_headers = HeaderMap::new();
308        if let Some(headers) = custom_headers {
309            for (key, value) in headers.iter() {
310                http_headers.insert(
311                    HeaderName::from_bytes(key.as_bytes())?,
312                    HeaderValue::from_str(value)?,
313                );
314            }
315        }
316
317        Ok(Self {
318            http_client,
319            aggregator_endpoint,
320            api_versions: Arc::new(RwLock::new(api_versions)),
321            logger: logger.new_with_component_name::<Self>(),
322            http_headers,
323        })
324    }
325
326    /// Computes the current api version
327    async fn compute_current_api_version(&self) -> Option<Version> {
328        self.api_versions.read().await.first().cloned()
329    }
330
331    /// Perform a HTTP GET request on the Aggregator and return the given JSON
332    #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
333    #[cfg_attr(not(target_family = "wasm"), async_recursion)]
334    async fn get(&self, url: Url) -> Result<Response, AggregatorClientError> {
335        debug!(self.logger, "GET url='{url}'.");
336        let request_builder = self.http_client.get(url.clone());
337        let current_api_version = self.compute_current_api_version().await.unwrap().to_string();
338        debug!(
339            self.logger,
340            "Prepare request with version: {current_api_version}"
341        );
342        let request_builder = request_builder
343            .header(MITHRIL_API_VERSION_HEADER, current_api_version)
344            .headers(self.http_headers.clone());
345
346        let response = request_builder.send().await.map_err(|e| {
347            AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
348                "Cannot perform a GET against the Aggregator HTTP server (url='{url}')"
349            )))
350        })?;
351
352        match response.status() {
353            StatusCode::OK => {
354                self.warn_if_api_version_mismatch(&response).await;
355
356                Ok(response)
357            }
358            StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
359            status_code if status_code.is_client_error() => {
360                Err(Self::remote_logical_error(response).await)
361            }
362            _ => Err(Self::remote_technical_error(response).await),
363        }
364    }
365
366    #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
367    #[cfg_attr(not(target_family = "wasm"), async_recursion)]
368    async fn post(&self, url: Url, json: &str) -> Result<Response, AggregatorClientError> {
369        debug!(self.logger, "POST url='{url}'"; "json" => json);
370        let request_builder = self.http_client.post(url.to_owned()).body(json.to_owned());
371        let current_api_version = self.compute_current_api_version().await.unwrap().to_string();
372        debug!(
373            self.logger,
374            "Prepare request with version: {current_api_version}"
375        );
376        let request_builder = request_builder
377            .header(MITHRIL_API_VERSION_HEADER, current_api_version)
378            .headers(self.http_headers.clone());
379
380        let response = request_builder.send().await.map_err(|e| {
381            AggregatorClientError::SubsystemError(
382                anyhow!(e).context("Error while POSTing data '{json}' to URL='{url}'."),
383            )
384        })?;
385
386        match response.status() {
387            StatusCode::OK | StatusCode::CREATED => {
388                self.warn_if_api_version_mismatch(&response).await;
389
390                Ok(response)
391            }
392            StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
393            status_code if status_code.is_client_error() => {
394                Err(Self::remote_logical_error(response).await)
395            }
396            _ => Err(Self::remote_technical_error(response).await),
397        }
398    }
399
400    fn get_url_for_route(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
401        self.aggregator_endpoint
402            .join(endpoint)
403            .with_context(|| {
404                format!(
405                    "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
406                    self.aggregator_endpoint
407                )
408            })
409            .map_err(AggregatorClientError::SubsystemError)
410    }
411
412    fn not_found_error(url: Url) -> AggregatorClientError {
413        AggregatorClientError::RemoteServerLogical(anyhow!("Url='{url}' not found"))
414    }
415
416    async fn remote_logical_error(response: Response) -> AggregatorClientError {
417        let status_code = response.status();
418        let client_error = response.json::<ClientError>().await.unwrap_or(ClientError::new(
419            format!("Unhandled error {status_code}"),
420            "",
421        ));
422
423        AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"))
424    }
425
426    async fn remote_technical_error(response: Response) -> AggregatorClientError {
427        let status_code = response.status();
428        let server_error = response
429            .json::<ServerError>()
430            .await
431            .unwrap_or(ServerError::new(format!("Unhandled error {status_code}")));
432
433        AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"))
434    }
435
436    /// Check API version mismatch and log a warning if the aggregator's version is more recent.
437    async fn warn_if_api_version_mismatch(&self, response: &Response) {
438        let aggregator_version = response
439            .headers()
440            .get(MITHRIL_API_VERSION_HEADER)
441            .and_then(|v| v.to_str().ok())
442            .and_then(|s| Version::parse(s).ok());
443
444        let client_version = self.compute_current_api_version().await;
445
446        match (aggregator_version, client_version) {
447            (Some(aggregator), Some(client)) if client < aggregator => {
448                warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
449                    "aggregator_version" => %aggregator,
450                    "client_version" => %client,
451                );
452            }
453            (Some(_), None) => {
454                error!(
455                    self.logger,
456                    "Failed to compute the current client API version"
457                );
458            }
459            _ => {}
460        }
461    }
462}
463
464#[cfg_attr(target_family = "wasm", async_trait(?Send))]
465#[cfg_attr(not(target_family = "wasm"), async_trait)]
466impl AggregatorClient for AggregatorHTTPClient {
467    async fn get_content(
468        &self,
469        request: AggregatorRequest,
470    ) -> Result<String, AggregatorClientError> {
471        let response = self.get(self.get_url_for_route(&request.route())?).await?;
472        let content = format!("{response:?}");
473
474        response.text().await.map_err(|e| {
475            AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
476                "Could not find a JSON body in the response '{content}'."
477            )))
478        })
479    }
480
481    async fn post_content(
482        &self,
483        request: AggregatorRequest,
484    ) -> Result<String, AggregatorClientError> {
485        let response = self
486            .post(
487                self.get_url_for_route(&request.route())?,
488                &request.get_body().unwrap_or_default(),
489            )
490            .await?;
491
492        response.text().await.map_err(|e| {
493            AggregatorClientError::SubsystemError(
494                anyhow!(e).context("Could not find a text body in the response."),
495            )
496        })
497    }
498}
499
500#[cfg(test)]
501mod tests {
502    use httpmock::MockServer;
503    use std::collections::HashMap;
504    use strum::IntoEnumIterator;
505
506    use mithril_common::api_version::APIVersionProvider;
507    use mithril_common::entities::{ClientError, ServerError};
508
509    use crate::test_utils::TestLogger;
510
511    use super::*;
512
513    macro_rules! assert_error_eq {
514        ($left:expr, $right:expr) => {
515            assert_eq!(format!("{:?}", &$left), format!("{:?}", &$right),);
516        };
517    }
518
519    fn setup_client(
520        server_url: &str,
521        api_versions: Vec<Version>,
522        custom_headers: Option<HashMap<String, String>>,
523    ) -> AggregatorHTTPClient {
524        AggregatorHTTPClient::new(
525            Url::parse(server_url).unwrap(),
526            api_versions,
527            TestLogger::stdout(),
528            custom_headers,
529        )
530        .expect("building aggregator http client should not fail")
531    }
532
533    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
534        let server = MockServer::start();
535        let client = setup_client(
536            &server.url(""),
537            APIVersionProvider::compute_all_versions_sorted(),
538            None,
539        );
540        (server, client)
541    }
542
543    fn setup_server_and_client_with_custom_headers(
544        custom_headers: HashMap<String, String>,
545    ) -> (MockServer, AggregatorHTTPClient) {
546        let server = MockServer::start();
547        let client = setup_client(
548            &server.url(""),
549            APIVersionProvider::compute_all_versions_sorted(),
550            Some(custom_headers),
551        );
552        (server, client)
553    }
554
555    #[test]
556    fn always_append_trailing_slash_at_build() {
557        for (expected, url) in [
558            ("http://www.test.net/", "http://www.test.net/"),
559            ("http://www.test.net/", "http://www.test.net"),
560            (
561                "http://www.test.net/aggregator/",
562                "http://www.test.net/aggregator/",
563            ),
564            (
565                "http://www.test.net/aggregator/",
566                "http://www.test.net/aggregator",
567            ),
568        ] {
569            let url = Url::parse(url).unwrap();
570            let client = AggregatorHTTPClient::new(url, vec![], TestLogger::stdout(), None)
571                .expect("building aggregator http client should not fail");
572
573            assert_eq!(expected, client.aggregator_endpoint.as_str());
574        }
575    }
576
577    #[test]
578    fn deduce_routes_from_request() {
579        assert_eq!(
580            "certificate/abc".to_string(),
581            AggregatorRequest::GetCertificate {
582                hash: "abc".to_string()
583            }
584            .route()
585        );
586
587        assert_eq!(
588            "artifact/mithril-stake-distribution/abc".to_string(),
589            AggregatorRequest::GetMithrilStakeDistribution {
590                hash: "abc".to_string()
591            }
592            .route()
593        );
594
595        assert_eq!(
596            "artifact/mithril-stake-distribution/abc".to_string(),
597            AggregatorRequest::GetMithrilStakeDistribution {
598                hash: "abc".to_string()
599            }
600            .route()
601        );
602
603        assert_eq!(
604            "artifact/mithril-stake-distributions".to_string(),
605            AggregatorRequest::ListMithrilStakeDistributions.route()
606        );
607
608        assert_eq!(
609            "artifact/snapshot/abc".to_string(),
610            AggregatorRequest::GetSnapshot {
611                digest: "abc".to_string()
612            }
613            .route()
614        );
615
616        assert_eq!(
617            "artifact/snapshots".to_string(),
618            AggregatorRequest::ListSnapshots.route()
619        );
620
621        assert_eq!(
622            "statistics/snapshot".to_string(),
623            AggregatorRequest::IncrementSnapshotStatistic {
624                snapshot: "abc".to_string()
625            }
626            .route()
627        );
628
629        assert_eq!(
630            "artifact/cardano-database/abc".to_string(),
631            AggregatorRequest::GetCardanoDatabaseSnapshot {
632                hash: "abc".to_string()
633            }
634            .route()
635        );
636
637        assert_eq!(
638            "artifact/cardano-database".to_string(),
639            AggregatorRequest::ListCardanoDatabaseSnapshots.route()
640        );
641
642        assert_eq!(
643            "artifact/cardano-database/epoch/5".to_string(),
644            AggregatorRequest::ListCardanoDatabaseSnapshotByEpoch { epoch: Epoch(5) }.route()
645        );
646
647        assert_eq!(
648            "artifact/cardano-database/epoch/latest".to_string(),
649            AggregatorRequest::ListCardanoDatabaseSnapshotForLatestEpoch { offset: None }.route()
650        );
651
652        assert_eq!(
653            "artifact/cardano-database/epoch/latest-6".to_string(),
654            AggregatorRequest::ListCardanoDatabaseSnapshotForLatestEpoch { offset: Some(6) }
655                .route()
656        );
657
658        assert_eq!(
659            "statistics/cardano-database/immutable-files-restored".to_string(),
660            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
661                number_of_immutables: 58
662            }
663            .route()
664        );
665
666        assert_eq!(
667            "statistics/cardano-database/ancillary-files-restored".to_string(),
668            AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic.route()
669        );
670
671        assert_eq!(
672            "statistics/cardano-database/complete-restoration".to_string(),
673            AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic.route()
674        );
675
676        assert_eq!(
677            "statistics/cardano-database/partial-restoration".to_string(),
678            AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic.route()
679        );
680
681        assert_eq!(
682            "proof/cardano-transaction?transaction_hashes=abc,def,ghi,jkl".to_string(),
683            AggregatorRequest::GetTransactionsProofs {
684                transactions_hashes: vec![
685                    "abc".to_string(),
686                    "def".to_string(),
687                    "ghi".to_string(),
688                    "jkl".to_string()
689                ]
690            }
691            .route()
692        );
693
694        assert_eq!(
695            "artifact/cardano-transaction/abc".to_string(),
696            AggregatorRequest::GetCardanoTransactionSnapshot {
697                hash: "abc".to_string()
698            }
699            .route()
700        );
701
702        assert_eq!(
703            "artifact/cardano-transactions".to_string(),
704            AggregatorRequest::ListCardanoTransactionSnapshots.route()
705        );
706
707        assert_eq!(
708            "artifact/cardano-stake-distribution/abc".to_string(),
709            AggregatorRequest::GetCardanoStakeDistribution {
710                hash: "abc".to_string()
711            }
712            .route()
713        );
714
715        assert_eq!(
716            "artifact/cardano-stake-distribution/epoch/123".to_string(),
717            AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch: Epoch(123) }.route()
718        );
719
720        assert_eq!(
721            "artifact/cardano-stake-distribution/epoch/latest".to_string(),
722            AggregatorRequest::GetCardanoStakeDistributionForLatestEpoch { offset: None }.route()
723        );
724
725        assert_eq!(
726            "artifact/cardano-stake-distribution/epoch/latest-8".to_string(),
727            AggregatorRequest::GetCardanoStakeDistributionForLatestEpoch { offset: Some(8) }
728                .route()
729        );
730
731        assert_eq!(
732            "artifact/cardano-stake-distributions".to_string(),
733            AggregatorRequest::ListCardanoStakeDistributions.route()
734        );
735
736        assert_eq!("status".to_string(), AggregatorRequest::Status.route());
737    }
738
739    #[test]
740    fn deduce_body_from_request() {
741        fn that_should_not_have_body(req: &AggregatorRequest) -> bool {
742            !matches!(
743                req,
744                AggregatorRequest::IncrementSnapshotStatistic { .. }
745                    | AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic { .. }
746            )
747        }
748
749        assert_eq!(
750            Some(r#"{"key":"value"}"#.to_string()),
751            AggregatorRequest::IncrementSnapshotStatistic {
752                snapshot: r#"{"key":"value"}"#.to_string()
753            }
754            .get_body()
755        );
756
757        assert_eq!(
758            Some(
759                serde_json::to_string(&CardanoDatabaseImmutableFilesRestoredMessage {
760                    nb_immutable_files: 432,
761                })
762                .unwrap()
763            ),
764            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
765                number_of_immutables: 432,
766            }
767            .get_body()
768        );
769
770        for req_that_should_not_have_body in
771            AggregatorRequest::iter().filter(that_should_not_have_body)
772        {
773            assert_eq!(None, req_that_should_not_have_body.get_body());
774        }
775    }
776
777    #[tokio::test]
778    async fn test_client_handle_4xx_errors() {
779        let client_error = ClientError::new("label", "message");
780
781        let (aggregator, client) = setup_server_and_client();
782        aggregator.mock(|_when, then| {
783            then.status(StatusCode::IM_A_TEAPOT.as_u16())
784                .json_body_obj(&client_error);
785        });
786
787        let expected_error = AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"));
788
789        let get_content_error = client
790            .get_content(AggregatorRequest::ListCertificates)
791            .await
792            .unwrap_err();
793        assert_error_eq!(get_content_error, expected_error);
794
795        let post_content_error = client
796            .post_content(AggregatorRequest::ListCertificates)
797            .await
798            .unwrap_err();
799        assert_error_eq!(post_content_error, expected_error);
800    }
801
802    #[tokio::test]
803    async fn test_client_handle_404_not_found_error() {
804        let client_error = ClientError::new("label", "message");
805
806        let (aggregator, client) = setup_server_and_client();
807        aggregator.mock(|_when, then| {
808            then.status(StatusCode::NOT_FOUND.as_u16())
809                .json_body_obj(&client_error);
810        });
811
812        let expected_error = AggregatorHTTPClient::not_found_error(
813            Url::parse(&format!(
814                "{}/{}",
815                aggregator.base_url(),
816                AggregatorRequest::ListCertificates.route()
817            ))
818            .unwrap(),
819        );
820
821        let get_content_error = client
822            .get_content(AggregatorRequest::ListCertificates)
823            .await
824            .unwrap_err();
825        assert_error_eq!(get_content_error, expected_error);
826
827        let post_content_error = client
828            .post_content(AggregatorRequest::ListCertificates)
829            .await
830            .unwrap_err();
831        assert_error_eq!(post_content_error, expected_error);
832    }
833
834    #[tokio::test]
835    async fn test_client_handle_5xx_errors() {
836        let server_error = ServerError::new("message");
837
838        let (aggregator, client) = setup_server_and_client();
839        aggregator.mock(|_when, then| {
840            then.status(StatusCode::INTERNAL_SERVER_ERROR.as_u16())
841                .json_body_obj(&server_error);
842        });
843
844        let expected_error =
845            AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"));
846
847        let get_content_error = client
848            .get_content(AggregatorRequest::ListCertificates)
849            .await
850            .unwrap_err();
851        assert_error_eq!(get_content_error, expected_error);
852
853        let post_content_error = client
854            .post_content(AggregatorRequest::ListCertificates)
855            .await
856            .unwrap_err();
857        assert_error_eq!(post_content_error, expected_error);
858    }
859
860    #[tokio::test]
861    async fn test_client_with_custom_headers() {
862        let mut http_headers = HashMap::new();
863        http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
864        http_headers.insert("Another-Header".to_string(), "AnotherValue".to_string());
865        let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
866        aggregator.mock(|when, then| {
867            when.header("Custom-Header", "CustomValue")
868                .header("Another-Header", "AnotherValue");
869            then.status(StatusCode::OK.as_u16()).body("ok");
870        });
871
872        client
873            .get_content(AggregatorRequest::ListCertificates)
874            .await
875            .expect("GET request should succeed");
876
877        client
878            .post_content(AggregatorRequest::ListCertificates)
879            .await
880            .expect("GET request should succeed");
881    }
882
883    #[tokio::test]
884    async fn test_client_sends_accept_encoding_header_with_correct_values() {
885        let (aggregator, client) = setup_server_and_client();
886        aggregator.mock(|when, then| {
887            when.matches(|req| {
888                let headers = req.headers();
889                let accept_encoding_header = headers
890                    .get("accept-encoding")
891                    .expect("Accept-Encoding header not found");
892
893                ["gzip", "br", "deflate", "zstd"].iter().all(|&encoding| {
894                    accept_encoding_header.to_str().is_ok_and(|h| h.contains(encoding))
895                })
896            });
897
898            then.status(200).body("ok");
899        });
900
901        client
902            .get_content(AggregatorRequest::ListCertificates)
903            .await
904            .expect("GET request should succeed with Accept-Encoding header");
905    }
906
907    #[tokio::test]
908    async fn test_client_with_custom_headers_sends_accept_encoding_header_with_correct_values() {
909        let mut http_headers = HashMap::new();
910        http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
911        let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
912        aggregator.mock(|when, then| {
913            when.matches(|req| {
914                let headers = req.headers();
915                let accept_encoding_header = headers
916                    .get("accept-encoding")
917                    .expect("Accept-Encoding header not found");
918
919                ["gzip", "br", "deflate", "zstd"].iter().all(|&encoding| {
920                    accept_encoding_header.to_str().is_ok_and(|h| h.contains(encoding))
921                })
922            });
923
924            then.status(200).body("ok");
925        });
926
927        client
928            .get_content(AggregatorRequest::ListCertificates)
929            .await
930            .expect("GET request should succeed with Accept-Encoding header");
931    }
932
933    mod warn_if_api_version_mismatch {
934        use http::response::Builder as HttpResponseBuilder;
935
936        use mithril_common::test::logging::MemoryDrainForTestInspector;
937
938        use super::*;
939
940        fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
941            key: K,
942            value: V,
943        ) -> Response {
944            HttpResponseBuilder::new()
945                .header(key.into(), value.into())
946                .body("whatever")
947                .unwrap()
948                .into()
949        }
950
951        fn assert_api_version_warning_logged<A: Into<String>, S: Into<String>>(
952            log_inspector: &MemoryDrainForTestInspector,
953            aggregator_version: A,
954            client_version: S,
955        ) {
956            assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
957            assert!(
958                log_inspector
959                    .contains_log(&format!("aggregator_version={}", aggregator_version.into()))
960            );
961            assert!(
962                log_inspector.contains_log(&format!("client_version={}", client_version.into()))
963            );
964        }
965
966        #[tokio::test]
967        async fn test_logs_warning_when_aggregator_api_version_is_newer() {
968            let aggregator_version = "2.0.0";
969            let client_version = "1.0.0";
970            let (logger, log_inspector) = TestLogger::memory();
971            let mut client = setup_client(
972                "http://whatever",
973                vec![Version::parse(client_version).unwrap()],
974                None,
975            );
976            client.logger = logger;
977            let response =
978                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
979
980            assert!(
981                Version::parse(aggregator_version).unwrap()
982                    > Version::parse(client_version).unwrap()
983            );
984
985            client.warn_if_api_version_mismatch(&response).await;
986
987            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
988        }
989
990        #[tokio::test]
991        async fn test_no_warning_logged_when_versions_match() {
992            let version = "1.0.0";
993            let (logger, log_inspector) = TestLogger::memory();
994            let mut client = setup_client(
995                "http://whatever",
996                vec![Version::parse(version).unwrap()],
997                None,
998            );
999            client.logger = logger;
1000            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);
1001
1002            client.warn_if_api_version_mismatch(&response).await;
1003
1004            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1005        }
1006
1007        #[tokio::test]
1008        async fn test_no_warning_logged_when_aggregator_api_version_is_older() {
1009            let aggregator_version = "1.0.0";
1010            let client_version = "2.0.0";
1011            let (logger, log_inspector) = TestLogger::memory();
1012            let mut client = setup_client(
1013                "http://whatever",
1014                vec![Version::parse(client_version).unwrap()],
1015                None,
1016            );
1017            client.logger = logger;
1018            let response =
1019                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1020
1021            assert!(
1022                Version::parse(aggregator_version).unwrap()
1023                    < Version::parse(client_version).unwrap()
1024            );
1025
1026            client.warn_if_api_version_mismatch(&response).await;
1027
1028            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1029        }
1030
1031        #[tokio::test]
1032        async fn test_does_not_log_or_fail_when_header_is_missing() {
1033            let (logger, log_inspector) = TestLogger::memory();
1034            let mut client = setup_client(
1035                "http://whatever",
1036                APIVersionProvider::compute_all_versions_sorted(),
1037                None,
1038            );
1039            client.logger = logger;
1040            let response =
1041                build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");
1042
1043            client.warn_if_api_version_mismatch(&response).await;
1044
1045            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1046        }
1047
1048        #[tokio::test]
1049        async fn test_does_not_log_or_fail_when_header_is_not_a_version() {
1050            let (logger, log_inspector) = TestLogger::memory();
1051            let mut client = setup_client(
1052                "http://whatever",
1053                APIVersionProvider::compute_all_versions_sorted(),
1054                None,
1055            );
1056            client.logger = logger;
1057            let response =
1058                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");
1059
1060            client.warn_if_api_version_mismatch(&response).await;
1061
1062            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1063        }
1064
1065        #[tokio::test]
1066        async fn test_logs_error_when_client_version_cannot_be_computed() {
1067            let (logger, log_inspector) = TestLogger::memory();
1068            let mut client = setup_client("http://whatever", vec![], None);
1069            client.logger = logger;
1070            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");
1071
1072            client.warn_if_api_version_mismatch(&response).await;
1073
1074            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1075        }
1076
1077        #[tokio::test]
1078        async fn test_client_get_log_warning_if_api_version_mismatch() {
1079            let aggregator_version = "2.0.0";
1080            let client_version = "1.0.0";
1081            let (server, mut client) = setup_server_and_client();
1082            let (logger, log_inspector) = TestLogger::memory();
1083            client.api_versions =
1084                Arc::new(RwLock::new(vec![Version::parse(client_version).unwrap()]));
1085            client.logger = logger;
1086            server.mock(|_, then| {
1087                then.status(StatusCode::OK.as_u16())
1088                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1089            });
1090
1091            assert!(
1092                Version::parse(aggregator_version).unwrap()
1093                    > Version::parse(client_version).unwrap()
1094            );
1095
1096            client.get(Url::parse(&server.base_url()).unwrap()).await.unwrap();
1097
1098            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
1099        }
1100
1101        #[tokio::test]
1102        async fn test_client_post_log_warning_if_api_version_mismatch() {
1103            let aggregator_version = "2.0.0";
1104            let client_version = "1.0.0";
1105            let (server, mut client) = setup_server_and_client();
1106            let (logger, log_inspector) = TestLogger::memory();
1107            client.api_versions =
1108                Arc::new(RwLock::new(vec![Version::parse(client_version).unwrap()]));
1109            client.logger = logger;
1110            server.mock(|_, then| {
1111                then.status(StatusCode::OK.as_u16())
1112                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1113            });
1114
1115            assert!(
1116                Version::parse(aggregator_version).unwrap()
1117                    > Version::parse(client_version).unwrap()
1118            );
1119
1120            client
1121                .post(Url::parse(&server.base_url()).unwrap(), "whatever")
1122                .await
1123                .unwrap();
1124
1125            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
1126        }
1127    }
1128}