1use anyhow::{anyhow, Context};
11use async_recursion::async_recursion;
12use async_trait::async_trait;
13use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
14use reqwest::{Response, StatusCode, Url};
15use semver::Version;
16use slog::{debug, Logger};
17use std::collections::HashMap;
18use std::sync::Arc;
19use thiserror::Error;
20use tokio::sync::RwLock;
21
22use mithril_common::entities::{ClientError, ServerError};
23use mithril_common::logging::LoggerExtensions;
24use mithril_common::MITHRIL_API_VERSION_HEADER;
25
26use crate::common::Epoch;
27use crate::{MithrilError, MithrilResult};
28
/// Errors returned by an [`AggregatorClient`].
#[derive(Error, Debug)]
pub enum AggregatorClientError {
    /// The aggregator answered with a status that is neither a handled success nor a client
    /// error (for example a 5xx status).
    #[error("Internal error of the Aggregator")]
    RemoteServerTechnical(#[source] MithrilError),

    /// The aggregator rejected the request with a client error status (4xx other than 412),
    /// including when the requested URL was not found (404).
    #[error("Invalid request to the Aggregator")]
    RemoteServerLogical(#[source] MithrilError),

    /// The aggregator answered `412 Precondition Failed` to the advertised API version and no
    /// fallback version was left to try.
    #[error("API version mismatch")]
    ApiVersionMismatch(#[source] MithrilError),

    /// Failure of the HTTP layer itself: invalid URL, request that could not be sent, or
    /// response body that could not be read.
    #[error("HTTP subsystem error")]
    SubsystemError(#[source] MithrilError),
}
48
/// Requests that can be sent to the aggregator; each variant maps to a route relative to the
/// aggregator endpoint (see `AggregatorRequest::route`).
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum AggregatorRequest {
    /// Get a certificate by its hash (`certificate/{hash}`).
    GetCertificate {
        /// Hash of the certificate to retrieve.
        hash: String,
    },

    /// List certificates (`certificates`).
    ListCertificates,

    /// Get a Mithril stake distribution by its hash
    /// (`artifact/mithril-stake-distribution/{hash}`).
    GetMithrilStakeDistribution {
        /// Hash of the Mithril stake distribution to retrieve.
        hash: String,
    },

    /// List Mithril stake distributions (`artifact/mithril-stake-distributions`).
    ListMithrilStakeDistributions,

    /// Get a snapshot by its digest (`artifact/snapshot/{digest}`).
    GetSnapshot {
        /// Digest of the snapshot to retrieve.
        digest: String,
    },

    /// List snapshots (`artifact/snapshots`).
    ListSnapshots,

    /// Increment the snapshot download statistic (`statistics/snapshot`).
    IncrementSnapshotStatistic {
        /// Payload sent verbatim as the request body
        /// (presumably a serialized snapshot message — confirm with callers).
        snapshot: String,
    },

    /// Get a Cardano database snapshot by its hash (`artifact/cardano-database/{hash}`).
    #[cfg(feature = "unstable")]
    GetCardanoDatabaseSnapshot {
        /// Hash of the Cardano database snapshot to retrieve.
        hash: String,
    },

    /// List Cardano database snapshots (`artifact/cardano-database`).
    #[cfg(feature = "unstable")]
    ListCardanoDatabaseSnapshots,

    /// Increment the Cardano database restored immutable files statistic
    /// (`statistics/cardano-database/immutable-files-restored`).
    #[cfg(feature = "unstable")]
    IncrementCardanoDatabaseImmutablesRestoredStatistic {
        /// Number of restored immutable files, as a string.
        number_of_immutables: String,
    },

    /// Increment the Cardano database restored ancillary files statistic
    /// (`statistics/cardano-database/ancillary-files-restored`).
    #[cfg(feature = "unstable")]
    IncrementCardanoDatabaseAncillaryStatistic,

    /// Increment the Cardano database complete restoration statistic
    /// (`statistics/cardano-database/complete-restoration`).
    #[cfg(feature = "unstable")]
    IncrementCardanoDatabaseCompleteRestorationStatistic,

    /// Increment the Cardano database partial restoration statistic
    /// (`statistics/cardano-database/partial-restoration`).
    #[cfg(feature = "unstable")]
    IncrementCardanoDatabasePartialRestorationStatistic,

    /// Get proofs for a set of Cardano transactions
    /// (`proof/cardano-transaction?transaction_hashes=...`).
    GetTransactionsProofs {
        /// Hashes of the transactions to prove, joined with `,` in the query string.
        transactions_hashes: Vec<String>,
    },

    /// Get a Cardano transactions snapshot by its hash
    /// (`artifact/cardano-transaction/{hash}`).
    GetCardanoTransactionSnapshot {
        /// Hash of the Cardano transactions snapshot to retrieve.
        hash: String,
    },

    /// List Cardano transactions snapshots (`artifact/cardano-transactions`).
    ListCardanoTransactionSnapshots,

    /// Get a Cardano stake distribution by its hash
    /// (`artifact/cardano-stake-distribution/{hash}`).
    GetCardanoStakeDistribution {
        /// Hash of the Cardano stake distribution to retrieve.
        hash: String,
    },

    /// Get a Cardano stake distribution by epoch
    /// (`artifact/cardano-stake-distribution/epoch/{epoch}`).
    GetCardanoStakeDistributionByEpoch {
        /// Epoch of the Cardano stake distribution to retrieve.
        epoch: Epoch,
    },

    /// List Cardano stake distributions (`artifact/cardano-stake-distributions`).
    ListCardanoStakeDistributions,
}
145
146impl AggregatorRequest {
147 pub fn route(&self) -> String {
149 match self {
150 AggregatorRequest::GetCertificate { hash } => {
151 format!("certificate/{hash}")
152 }
153 AggregatorRequest::ListCertificates => "certificates".to_string(),
154 AggregatorRequest::GetMithrilStakeDistribution { hash } => {
155 format!("artifact/mithril-stake-distribution/{hash}")
156 }
157 AggregatorRequest::ListMithrilStakeDistributions => {
158 "artifact/mithril-stake-distributions".to_string()
159 }
160 AggregatorRequest::GetSnapshot { digest } => {
161 format!("artifact/snapshot/{}", digest)
162 }
163 AggregatorRequest::ListSnapshots => "artifact/snapshots".to_string(),
164 AggregatorRequest::IncrementSnapshotStatistic { snapshot: _ } => {
165 "statistics/snapshot".to_string()
166 }
167 #[cfg(feature = "unstable")]
168 AggregatorRequest::GetCardanoDatabaseSnapshot { hash } => {
169 format!("artifact/cardano-database/{}", hash)
170 }
171 #[cfg(feature = "unstable")]
172 AggregatorRequest::ListCardanoDatabaseSnapshots => {
173 "artifact/cardano-database".to_string()
174 }
175 #[cfg(feature = "unstable")]
176 AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
177 number_of_immutables: _,
178 } => "statistics/cardano-database/immutable-files-restored".to_string(),
179 #[cfg(feature = "unstable")]
180 AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic => {
181 "statistics/cardano-database/ancillary-files-restored".to_string()
182 }
183 #[cfg(feature = "unstable")]
184 AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic => {
185 "statistics/cardano-database/complete-restoration".to_string()
186 }
187 #[cfg(feature = "unstable")]
188 AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic => {
189 "statistics/cardano-database/partial-restoration".to_string()
190 }
191 AggregatorRequest::GetTransactionsProofs {
192 transactions_hashes,
193 } => format!(
194 "proof/cardano-transaction?transaction_hashes={}",
195 transactions_hashes.join(",")
196 ),
197 AggregatorRequest::GetCardanoTransactionSnapshot { hash } => {
198 format!("artifact/cardano-transaction/{hash}")
199 }
200 AggregatorRequest::ListCardanoTransactionSnapshots => {
201 "artifact/cardano-transactions".to_string()
202 }
203 AggregatorRequest::GetCardanoStakeDistribution { hash } => {
204 format!("artifact/cardano-stake-distribution/{hash}")
205 }
206 AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch } => {
207 format!("artifact/cardano-stake-distribution/epoch/{epoch}")
208 }
209 AggregatorRequest::ListCardanoStakeDistributions => {
210 "artifact/cardano-stake-distributions".to_string()
211 }
212 }
213 }
214
215 pub fn get_body(&self) -> Option<String> {
217 match self {
218 AggregatorRequest::IncrementSnapshotStatistic { snapshot } => {
219 Some(snapshot.to_string())
220 }
221 _ => None,
222 }
223 }
224}
225
/// Client able to exchange content with an aggregator.
#[cfg_attr(test, mockall::automock)]
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
pub trait AggregatorClient: Sync + Send {
    /// Sends `request` to the aggregator as a GET and returns the response body.
    async fn get_content(
        &self,
        request: AggregatorRequest,
    ) -> Result<String, AggregatorClientError>;

    /// Sends `request` to the aggregator as a POST and returns the response body.
    async fn post_content(
        &self,
        request: AggregatorRequest,
    ) -> Result<String, AggregatorClientError>;
}
243
/// [`AggregatorClient`] implementation backed by `reqwest`, with API version negotiation.
pub struct AggregatorHTTPClient {
    /// Underlying HTTP client.
    http_client: reqwest::Client,
    /// Base URL of the aggregator; `new` ensures it ends with `/` so routes are joined under it.
    aggregator_endpoint: Url,
    /// API versions to try, in order; the first one is sent in the API version header.
    api_versions: Arc<RwLock<Vec<Version>>>,
    /// Component logger.
    logger: Logger,
    /// Custom headers added to every request.
    http_headers: HeaderMap,
}
252
253impl AggregatorHTTPClient {
254 pub fn new(
256 aggregator_endpoint: Url,
257 api_versions: Vec<Version>,
258 logger: Logger,
259 custom_headers: Option<HashMap<String, String>>,
260 ) -> MithrilResult<Self> {
261 let http_client = reqwest::ClientBuilder::new()
262 .build()
263 .with_context(|| "Building http client for Aggregator client failed")?;
264
265 let aggregator_endpoint = if aggregator_endpoint.as_str().ends_with('/') {
269 aggregator_endpoint
270 } else {
271 let mut url = aggregator_endpoint.clone();
272 url.set_path(&format!("{}/", aggregator_endpoint.path()));
273 url
274 };
275
276 let mut http_headers = HeaderMap::new();
277 if let Some(headers) = custom_headers {
278 for (key, value) in headers.iter() {
279 http_headers.insert(
280 HeaderName::from_bytes(key.as_bytes())?,
281 HeaderValue::from_str(value)?,
282 );
283 }
284 }
285
286 Ok(Self {
287 http_client,
288 aggregator_endpoint,
289 api_versions: Arc::new(RwLock::new(api_versions)),
290 logger: logger.new_with_component_name::<Self>(),
291 http_headers,
292 })
293 }
294
295 async fn compute_current_api_version(&self) -> Option<Version> {
297 self.api_versions.read().await.first().cloned()
298 }
299
300 async fn discard_current_api_version(&self) -> Option<Version> {
303 if self.api_versions.read().await.len() < 2 {
304 return None;
305 }
306 if let Some(current_api_version) = self.compute_current_api_version().await {
307 let mut api_versions = self.api_versions.write().await;
308 if let Some(index) = api_versions
309 .iter()
310 .position(|value| *value == current_api_version)
311 {
312 api_versions.remove(index);
313 return Some(current_api_version);
314 }
315 }
316 None
317 }
318
319 #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
321 #[cfg_attr(not(target_family = "wasm"), async_recursion)]
322 async fn get(&self, url: Url) -> Result<Response, AggregatorClientError> {
323 debug!(self.logger, "GET url='{url}'.");
324 let request_builder = self.http_client.get(url.clone());
325 let current_api_version = self
326 .compute_current_api_version()
327 .await
328 .unwrap()
329 .to_string();
330 debug!(
331 self.logger,
332 "Prepare request with version: {current_api_version}"
333 );
334 let request_builder = request_builder
335 .header(MITHRIL_API_VERSION_HEADER, current_api_version)
336 .headers(self.http_headers.clone());
337
338 let response = request_builder.send().await.map_err(|e| {
339 AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
340 "Cannot perform a GET against the Aggregator HTTP server (url='{url}')"
341 )))
342 })?;
343
344 match response.status() {
345 StatusCode::OK => Ok(response),
346 StatusCode::PRECONDITION_FAILED => {
347 if self.discard_current_api_version().await.is_some()
348 && !self.api_versions.read().await.is_empty()
349 {
350 return self.get(url).await;
351 }
352
353 Err(self.handle_api_error(response.headers()).await)
354 }
355 StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
356 status_code if status_code.is_client_error() => {
357 Err(Self::remote_logical_error(response).await)
358 }
359 _ => Err(Self::remote_technical_error(response).await),
360 }
361 }
362
363 #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
364 #[cfg_attr(not(target_family = "wasm"), async_recursion)]
365 async fn post(&self, url: Url, json: &str) -> Result<Response, AggregatorClientError> {
366 debug!(self.logger, "POST url='{url}'"; "json" => json);
367 let request_builder = self.http_client.post(url.to_owned()).body(json.to_owned());
368 let current_api_version = self
369 .compute_current_api_version()
370 .await
371 .unwrap()
372 .to_string();
373 debug!(
374 self.logger,
375 "Prepare request with version: {current_api_version}"
376 );
377 let request_builder = request_builder
378 .header(MITHRIL_API_VERSION_HEADER, current_api_version)
379 .headers(self.http_headers.clone());
380
381 let response = request_builder.send().await.map_err(|e| {
382 AggregatorClientError::SubsystemError(
383 anyhow!(e).context("Error while POSTing data '{json}' to URL='{url}'."),
384 )
385 })?;
386
387 match response.status() {
388 StatusCode::OK | StatusCode::CREATED => Ok(response),
389 StatusCode::PRECONDITION_FAILED => {
390 if self.discard_current_api_version().await.is_some()
391 && !self.api_versions.read().await.is_empty()
392 {
393 return self.post(url, json).await;
394 }
395
396 Err(self.handle_api_error(response.headers()).await)
397 }
398 StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
399 status_code if status_code.is_client_error() => {
400 Err(Self::remote_logical_error(response).await)
401 }
402 _ => Err(Self::remote_technical_error(response).await),
403 }
404 }
405
406 fn get_url_for_route(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
407 self.aggregator_endpoint
408 .join(endpoint)
409 .with_context(|| {
410 format!(
411 "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
412 self.aggregator_endpoint
413 )
414 })
415 .map_err(AggregatorClientError::SubsystemError)
416 }
417
418 async fn handle_api_error(&self, response_header: &HeaderMap) -> AggregatorClientError {
420 if let Some(version) = response_header.get(MITHRIL_API_VERSION_HEADER) {
421 AggregatorClientError::ApiVersionMismatch(anyhow!(
422 "server version: '{}', signer version: '{}'",
423 version.to_str().unwrap(),
424 self.compute_current_api_version().await.unwrap()
425 ))
426 } else {
427 AggregatorClientError::ApiVersionMismatch(anyhow!(
428 "version precondition failed, sent version '{}'.",
429 self.compute_current_api_version().await.unwrap()
430 ))
431 }
432 }
433
434 fn not_found_error(url: Url) -> AggregatorClientError {
435 AggregatorClientError::RemoteServerLogical(anyhow!("Url='{url}' not found"))
436 }
437
438 async fn remote_logical_error(response: Response) -> AggregatorClientError {
439 let status_code = response.status();
440 let client_error = response
441 .json::<ClientError>()
442 .await
443 .unwrap_or(ClientError::new(
444 format!("Unhandled error {status_code}"),
445 "",
446 ));
447
448 AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"))
449 }
450
451 async fn remote_technical_error(response: Response) -> AggregatorClientError {
452 let status_code = response.status();
453 let server_error = response
454 .json::<ServerError>()
455 .await
456 .unwrap_or(ServerError::new(format!("Unhandled error {status_code}")));
457
458 AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"))
459 }
460}
461
462#[cfg_attr(target_family = "wasm", async_trait(?Send))]
463#[cfg_attr(not(target_family = "wasm"), async_trait)]
464impl AggregatorClient for AggregatorHTTPClient {
465 async fn get_content(
466 &self,
467 request: AggregatorRequest,
468 ) -> Result<String, AggregatorClientError> {
469 let response = self.get(self.get_url_for_route(&request.route())?).await?;
470 let content = format!("{response:?}");
471
472 response.text().await.map_err(|e| {
473 AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
474 "Could not find a JSON body in the response '{content}'."
475 )))
476 })
477 }
478
479 async fn post_content(
480 &self,
481 request: AggregatorRequest,
482 ) -> Result<String, AggregatorClientError> {
483 let response = self
484 .post(
485 self.get_url_for_route(&request.route())?,
486 &request.get_body().unwrap_or_default(),
487 )
488 .await?;
489
490 response.text().await.map_err(|e| {
491 AggregatorClientError::SubsystemError(
492 anyhow!(e).context("Could not find a text body in the response."),
493 )
494 })
495 }
496}
497
#[cfg(test)]
mod tests {
    use httpmock::MockServer;
    use reqwest::header::{HeaderName, HeaderValue};
    use std::collections::HashMap;

    use mithril_common::api_version::APIVersionProvider;
    use mithril_common::entities::{ClientError, ServerError};

    use crate::test_utils::TestLogger;

    use super::*;

    /// Compares two errors through their `Debug` representation: the `anyhow` errors wrapped
    /// by `AggregatorClientError` do not implement `PartialEq`.
    macro_rules! assert_error_eq {
        ($left:expr, $right:expr) => {
            assert_eq!(format!("{:?}", &$left), format!("{:?}", &$right),);
        };
    }

    fn setup_client(
        server_url: &str,
        api_versions: Vec<Version>,
        custom_headers: Option<HashMap<String, String>>,
    ) -> AggregatorHTTPClient {
        AggregatorHTTPClient::new(
            Url::parse(server_url).unwrap(),
            api_versions,
            TestLogger::stdout(),
            custom_headers,
        )
        .expect("building aggregator http client should not fail")
    }

    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
        let server = MockServer::start();
        let client = setup_client(
            &server.url(""),
            APIVersionProvider::compute_all_versions_sorted().unwrap(),
            None,
        );
        (server, client)
    }

    fn setup_server_and_client_with_custom_headers(
        custom_headers: HashMap<String, String>,
    ) -> (MockServer, AggregatorHTTPClient) {
        let server = MockServer::start();
        let client = setup_client(
            &server.url(""),
            APIVersionProvider::compute_all_versions_sorted().unwrap(),
            Some(custom_headers),
        );
        (server, client)
    }

    fn mithril_api_version_headers(version: &str) -> HeaderMap {
        HeaderMap::from_iter([(
            HeaderName::from_static(MITHRIL_API_VERSION_HEADER),
            HeaderValue::from_str(version).unwrap(),
        )])
    }

    #[test]
    fn always_append_trailing_slash_at_build() {
        for (expected, url) in [
            ("http://www.test.net/", "http://www.test.net/"),
            ("http://www.test.net/", "http://www.test.net"),
            (
                "http://www.test.net/aggregator/",
                "http://www.test.net/aggregator/",
            ),
            (
                "http://www.test.net/aggregator/",
                "http://www.test.net/aggregator",
            ),
        ] {
            let url = Url::parse(url).unwrap();
            let client = AggregatorHTTPClient::new(url, vec![], TestLogger::stdout(), None)
                .expect("building aggregator http client should not fail");

            assert_eq!(expected, client.aggregator_endpoint.as_str());
        }
    }

    #[test]
    fn deduce_routes_from_request() {
        assert_eq!(
            "certificate/abc".to_string(),
            AggregatorRequest::GetCertificate {
                hash: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "certificates".to_string(),
            AggregatorRequest::ListCertificates.route()
        );

        assert_eq!(
            "artifact/mithril-stake-distribution/abc".to_string(),
            AggregatorRequest::GetMithrilStakeDistribution {
                hash: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/mithril-stake-distributions".to_string(),
            AggregatorRequest::ListMithrilStakeDistributions.route()
        );

        assert_eq!(
            "artifact/snapshot/abc".to_string(),
            AggregatorRequest::GetSnapshot {
                digest: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/snapshots".to_string(),
            AggregatorRequest::ListSnapshots.route()
        );

        assert_eq!(
            "statistics/snapshot".to_string(),
            AggregatorRequest::IncrementSnapshotStatistic {
                snapshot: "abc".to_string()
            }
            .route()
        );

        #[cfg(feature = "unstable")]
        assert_eq!(
            "artifact/cardano-database/abc".to_string(),
            AggregatorRequest::GetCardanoDatabaseSnapshot {
                hash: "abc".to_string()
            }
            .route()
        );

        #[cfg(feature = "unstable")]
        assert_eq!(
            "artifact/cardano-database".to_string(),
            AggregatorRequest::ListCardanoDatabaseSnapshots.route()
        );

        #[cfg(feature = "unstable")]
        assert_eq!(
            "statistics/cardano-database/immutable-files-restored".to_string(),
            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
                number_of_immutables: "abc".to_string()
            }
            .route()
        );

        #[cfg(feature = "unstable")]
        assert_eq!(
            "statistics/cardano-database/ancillary-files-restored".to_string(),
            AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic.route()
        );

        #[cfg(feature = "unstable")]
        assert_eq!(
            "statistics/cardano-database/complete-restoration".to_string(),
            AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic.route()
        );

        #[cfg(feature = "unstable")]
        assert_eq!(
            "statistics/cardano-database/partial-restoration".to_string(),
            AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic.route()
        );

        assert_eq!(
            "proof/cardano-transaction?transaction_hashes=abc,def,ghi,jkl".to_string(),
            AggregatorRequest::GetTransactionsProofs {
                transactions_hashes: vec![
                    "abc".to_string(),
                    "def".to_string(),
                    "ghi".to_string(),
                    "jkl".to_string()
                ]
            }
            .route()
        );

        assert_eq!(
            "artifact/cardano-transaction/abc".to_string(),
            AggregatorRequest::GetCardanoTransactionSnapshot {
                hash: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/cardano-transactions".to_string(),
            AggregatorRequest::ListCardanoTransactionSnapshots.route()
        );

        assert_eq!(
            "artifact/cardano-stake-distribution/abc".to_string(),
            AggregatorRequest::GetCardanoStakeDistribution {
                hash: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/cardano-stake-distribution/epoch/123".to_string(),
            AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch: Epoch(123) }.route()
        );

        assert_eq!(
            "artifact/cardano-stake-distributions".to_string(),
            AggregatorRequest::ListCardanoStakeDistributions.route()
        );
    }

    #[test]
    fn deduce_body_from_request() {
        assert_eq!(
            Some("abc".to_string()),
            AggregatorRequest::IncrementSnapshotStatistic {
                snapshot: "abc".to_string()
            }
            .get_body()
        );

        assert_eq!(None, AggregatorRequest::ListCertificates.get_body());
    }

    #[tokio::test]
    async fn discarding_api_version_never_removes_the_last_one() {
        let first_version = Version::parse("1.0.0").unwrap();
        let second_version = Version::parse("2.0.0").unwrap();
        let client = setup_client(
            "http://www.test.net/",
            vec![first_version.clone(), second_version.clone()],
            None,
        );

        assert_eq!(
            Some(first_version),
            client.discard_current_api_version().await
        );
        assert_eq!(
            Some(second_version.clone()),
            client.compute_current_api_version().await
        );

        assert_eq!(None, client.discard_current_api_version().await);
        assert_eq!(
            Some(second_version),
            client.compute_current_api_version().await
        );
    }

    #[tokio::test]
    async fn test_client_handle_4xx_errors() {
        let client_error = ClientError::new("label", "message");

        let (aggregator, client) = setup_server_and_client();
        aggregator.mock(|_when, then| {
            then.status(StatusCode::IM_A_TEAPOT.as_u16())
                .json_body_obj(&client_error);
        });

        let expected_error = AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"));

        let get_content_error = client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(get_content_error, expected_error);

        let post_content_error = client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(post_content_error, expected_error);
    }

    #[tokio::test]
    async fn test_client_handle_404_not_found_error() {
        let client_error = ClientError::new("label", "message");

        let (aggregator, client) = setup_server_and_client();
        aggregator.mock(|_when, then| {
            then.status(StatusCode::NOT_FOUND.as_u16())
                .json_body_obj(&client_error);
        });

        let expected_error = AggregatorHTTPClient::not_found_error(
            Url::parse(&format!(
                "{}/{}",
                aggregator.base_url(),
                AggregatorRequest::ListCertificates.route()
            ))
            .unwrap(),
        );

        let get_content_error = client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(get_content_error, expected_error);

        let post_content_error = client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(post_content_error, expected_error);
    }

    #[tokio::test]
    async fn test_client_handle_5xx_errors() {
        let server_error = ServerError::new("message");

        let (aggregator, client) = setup_server_and_client();
        aggregator.mock(|_when, then| {
            then.status(StatusCode::INTERNAL_SERVER_ERROR.as_u16())
                .json_body_obj(&server_error);
        });

        let expected_error =
            AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"));

        let get_content_error = client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(get_content_error, expected_error);

        let post_content_error = client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(post_content_error, expected_error);
    }

    #[tokio::test]
    async fn test_client_handle_412_api_version_mismatch_with_version_in_response_header() {
        let version = "0.0.0";

        let (aggregator, client) = setup_server_and_client();
        aggregator.mock(|_when, then| {
            then.status(StatusCode::PRECONDITION_FAILED.as_u16())
                .header(MITHRIL_API_VERSION_HEADER, version);
        });

        let expected_error = client
            .handle_api_error(&mithril_api_version_headers(version))
            .await;

        let get_content_error = client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(get_content_error, expected_error);

        let post_content_error = client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(post_content_error, expected_error);
    }

    #[tokio::test]
    async fn test_client_handle_412_api_version_mismatch_without_version_in_response_header() {
        let (aggregator, client) = setup_server_and_client();
        aggregator.mock(|_when, then| {
            then.status(StatusCode::PRECONDITION_FAILED.as_u16());
        });

        let expected_error = client.handle_api_error(&HeaderMap::new()).await;

        let get_content_error = client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(get_content_error, expected_error);

        let post_content_error = client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(post_content_error, expected_error);
    }

    #[tokio::test]
    async fn test_client_can_fallback_to_a_second_version_when_412_api_version_mismatch() {
        let bad_version = "0.0.0";
        let good_version = "1.0.0";

        let aggregator = MockServer::start();
        let client = setup_client(
            &aggregator.url(""),
            vec![
                Version::parse(bad_version).unwrap(),
                Version::parse(good_version).unwrap(),
            ],
            None,
        );
        aggregator.mock(|when, then| {
            when.header(MITHRIL_API_VERSION_HEADER, bad_version);
            then.status(StatusCode::PRECONDITION_FAILED.as_u16())
                .header(MITHRIL_API_VERSION_HEADER, bad_version);
        });
        aggregator.mock(|when, then| {
            when.header(MITHRIL_API_VERSION_HEADER, good_version);
            then.status(StatusCode::OK.as_u16());
        });

        assert_eq!(
            client.compute_current_api_version().await,
            Some(Version::parse(bad_version).unwrap()),
            "Bad version should be tried first"
        );

        client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .expect("should have run with a fallback version");

        client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .expect("should have run with a fallback version");
    }

    #[tokio::test]
    async fn test_client_with_custom_headers() {
        let mut http_headers = HashMap::new();
        http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
        http_headers.insert("Another-Header".to_string(), "AnotherValue".to_string());
        let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
        aggregator.mock(|when, then| {
            when.header("Custom-Header", "CustomValue")
                .header("Another-Header", "AnotherValue");
            then.status(StatusCode::OK.as_u16()).body("ok");
        });

        client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .expect("GET request should succeed");

        client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .expect("POST request should succeed");
    }

    #[tokio::test]
    async fn test_client_sends_accept_encoding_header_with_correct_values() {
        let (aggregator, client) = setup_server_and_client();
        aggregator.mock(|when, then| {
            when.matches(|req| {
                let headers = req.headers.clone().expect("HTTP headers not found");
                let accept_encoding_header = headers
                    .iter()
                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
                    .expect("Accept-Encoding header not found");

                let header_value = accept_encoding_header.clone().1;
                ["gzip", "br", "deflate", "zstd"]
                    .iter()
                    .all(|&value| header_value.contains(value))
            });

            then.status(200).body("ok");
        });

        client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .expect("GET request should succeed with Accept-Encoding header");
    }

    #[tokio::test]
    async fn test_client_with_custom_headers_sends_accept_encoding_header_with_correct_values() {
        let mut http_headers = HashMap::new();
        http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
        let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
        aggregator.mock(|when, then| {
            when.matches(|req| {
                let headers = req.headers.clone().expect("HTTP headers not found");
                let accept_encoding_header = headers
                    .iter()
                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
                    .expect("Accept-Encoding header not found");

                let header_value = accept_encoding_header.clone().1;
                ["gzip", "br", "deflate", "zstd"]
                    .iter()
                    .all(|&value| header_value.contains(value))
            });

            then.status(200).body("ok");
        });

        client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .expect("GET request should succeed with Accept-Encoding header");
    }
}