1use anyhow::{anyhow, Context};
11use async_recursion::async_recursion;
12use async_trait::async_trait;
13use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
14use reqwest::{Response, StatusCode, Url};
15use semver::Version;
16use slog::{debug, error, warn, Logger};
17use std::collections::HashMap;
18use std::sync::Arc;
19use thiserror::Error;
20use tokio::sync::RwLock;
21
22use mithril_common::entities::{ClientError, ServerError};
23use mithril_common::logging::LoggerExtensions;
24use mithril_common::messages::CardanoDatabaseImmutableFilesRestoredMessage;
25use mithril_common::MITHRIL_API_VERSION_HEADER;
26
27use crate::common::Epoch;
28use crate::{MithrilError, MithrilResult};
29
/// Message of the warning logged when the API version advertised by the aggregator is
/// newer than the API version used by this client.
const API_VERSION_MISMATCH_WARNING_MESSAGE: &str =
    "OpenAPI version may be incompatible, please update Mithril client library to the latest version.";
32
/// Errors returned by an [`AggregatorClient`].
#[derive(Error, Debug)]
pub enum AggregatorClientError {
    /// The aggregator answered with a status that is neither a success nor a client error
    /// (built from the `ServerError` payload of the response when it can be decoded).
    #[error("Internal error of the Aggregator")]
    RemoteServerTechnical(#[source] MithrilError),

    /// The aggregator rejected the request with a client error status (4xx), including
    /// "not found" responses.
    #[error("Invalid request to the Aggregator")]
    RemoteServerLogical(#[source] MithrilError),

    /// The request could not be built or sent, or the response body could not be read.
    #[error("HTTP subsystem error")]
    SubsystemError(#[source] MithrilError),
}
48
/// Requests that can be sent to the aggregator.
///
/// Each variant maps to a route (see `AggregatorRequest::route`) and, for some of them,
/// to a request body (see `AggregatorRequest::get_body`).
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(test, derive(strum::EnumIter))]
pub enum AggregatorRequest {
    /// Targets the `certificate/{hash}` route.
    GetCertificate {
        /// Hash interpolated into the route.
        hash: String,
    },

    /// Targets the `certificates` route.
    ListCertificates,

    /// Targets the `artifact/mithril-stake-distribution/{hash}` route.
    GetMithrilStakeDistribution {
        /// Hash interpolated into the route.
        hash: String,
    },

    /// Targets the `artifact/mithril-stake-distributions` route.
    ListMithrilStakeDistributions,

    /// Targets the `artifact/snapshot/{digest}` route.
    GetSnapshot {
        /// Digest interpolated into the route.
        digest: String,
    },

    /// Targets the `artifact/snapshots` route.
    ListSnapshots,

    /// Targets the `statistics/snapshot` route.
    IncrementSnapshotStatistic {
        /// Content sent verbatim as the request body.
        snapshot: String,
    },

    /// Targets the `artifact/cardano-database/{hash}` route.
    GetCardanoDatabaseSnapshot {
        /// Hash interpolated into the route.
        hash: String,
    },

    /// Targets the `artifact/cardano-database` route.
    ListCardanoDatabaseSnapshots,

    /// Targets the `statistics/cardano-database/immutable-files-restored` route.
    IncrementCardanoDatabaseImmutablesRestoredStatistic {
        /// Number of immutable files, serialized in the request body.
        number_of_immutables: u64,
    },

    /// Targets the `statistics/cardano-database/ancillary-files-restored` route.
    IncrementCardanoDatabaseAncillaryStatistic,

    /// Targets the `statistics/cardano-database/complete-restoration` route.
    IncrementCardanoDatabaseCompleteRestorationStatistic,

    /// Targets the `statistics/cardano-database/partial-restoration` route.
    IncrementCardanoDatabasePartialRestorationStatistic,

    /// Targets the `proof/cardano-transaction` route.
    GetTransactionsProofs {
        /// Transaction hashes, sent comma-separated in the `transaction_hashes` query parameter.
        transactions_hashes: Vec<String>,
    },

    /// Targets the `artifact/cardano-transaction/{hash}` route.
    GetCardanoTransactionSnapshot {
        /// Hash interpolated into the route.
        hash: String,
    },

    /// Targets the `artifact/cardano-transactions` route.
    ListCardanoTransactionSnapshots,

    /// Targets the `artifact/cardano-stake-distribution/{hash}` route.
    GetCardanoStakeDistribution {
        /// Hash interpolated into the route.
        hash: String,
    },

    /// Targets the `artifact/cardano-stake-distribution/epoch/{epoch}` route.
    GetCardanoStakeDistributionByEpoch {
        /// Epoch interpolated into the route.
        epoch: Epoch,
    },

    /// Targets the `artifact/cardano-stake-distributions` route.
    ListCardanoStakeDistributions,
}
140
141impl AggregatorRequest {
142 pub fn route(&self) -> String {
144 match self {
145 AggregatorRequest::GetCertificate { hash } => {
146 format!("certificate/{hash}")
147 }
148 AggregatorRequest::ListCertificates => "certificates".to_string(),
149 AggregatorRequest::GetMithrilStakeDistribution { hash } => {
150 format!("artifact/mithril-stake-distribution/{hash}")
151 }
152 AggregatorRequest::ListMithrilStakeDistributions => {
153 "artifact/mithril-stake-distributions".to_string()
154 }
155 AggregatorRequest::GetSnapshot { digest } => {
156 format!("artifact/snapshot/{digest}")
157 }
158 AggregatorRequest::ListSnapshots => "artifact/snapshots".to_string(),
159 AggregatorRequest::IncrementSnapshotStatistic { snapshot: _ } => {
160 "statistics/snapshot".to_string()
161 }
162 AggregatorRequest::GetCardanoDatabaseSnapshot { hash } => {
163 format!("artifact/cardano-database/{hash}")
164 }
165 AggregatorRequest::ListCardanoDatabaseSnapshots => {
166 "artifact/cardano-database".to_string()
167 }
168 AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
169 number_of_immutables: _,
170 } => "statistics/cardano-database/immutable-files-restored".to_string(),
171 AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic => {
172 "statistics/cardano-database/ancillary-files-restored".to_string()
173 }
174 AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic => {
175 "statistics/cardano-database/complete-restoration".to_string()
176 }
177 AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic => {
178 "statistics/cardano-database/partial-restoration".to_string()
179 }
180 AggregatorRequest::GetTransactionsProofs {
181 transactions_hashes,
182 } => format!(
183 "proof/cardano-transaction?transaction_hashes={}",
184 transactions_hashes.join(",")
185 ),
186 AggregatorRequest::GetCardanoTransactionSnapshot { hash } => {
187 format!("artifact/cardano-transaction/{hash}")
188 }
189 AggregatorRequest::ListCardanoTransactionSnapshots => {
190 "artifact/cardano-transactions".to_string()
191 }
192 AggregatorRequest::GetCardanoStakeDistribution { hash } => {
193 format!("artifact/cardano-stake-distribution/{hash}")
194 }
195 AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch } => {
196 format!("artifact/cardano-stake-distribution/epoch/{epoch}")
197 }
198 AggregatorRequest::ListCardanoStakeDistributions => {
199 "artifact/cardano-stake-distributions".to_string()
200 }
201 }
202 }
203
204 pub fn get_body(&self) -> Option<String> {
206 match self {
207 AggregatorRequest::IncrementSnapshotStatistic { snapshot } => {
208 Some(snapshot.to_string())
209 }
210 AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
211 number_of_immutables,
212 } => serde_json::to_string(&CardanoDatabaseImmutableFilesRestoredMessage {
213 nb_immutable_files: *number_of_immutables as u32,
214 })
215 .ok(),
216 _ => None,
217 }
218 }
219}
220
/// API that defines a client for the Aggregator.
///
/// A mock implementation is generated in test builds through `mockall::automock`.
#[cfg_attr(test, mockall::automock)]
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
pub trait AggregatorClient: Sync + Send {
    /// Retrieves the content associated with the given request.
    ///
    /// The HTTP implementation in this file issues a GET on the request route and returns
    /// the response body as text.
    async fn get_content(
        &self,
        request: AggregatorRequest,
    ) -> Result<String, AggregatorClientError>;

    /// Posts the given request to the aggregator.
    ///
    /// The HTTP implementation in this file issues a POST on the request route, with the
    /// request body if any, and returns the response body as text.
    async fn post_content(
        &self,
        request: AggregatorRequest,
    ) -> Result<String, AggregatorClientError>;
}
238
/// HTTP implementation of an [`AggregatorClient`], responsible for sending requests to an
/// aggregator and classifying its responses.
pub struct AggregatorHTTPClient {
    // Underlying HTTP client used to send every request.
    http_client: reqwest::Client,
    // Base url of the aggregator; the constructor makes sure its path ends with a '/'.
    aggregator_endpoint: Url,
    // API versions known to the client; the first one is sent as the current version.
    api_versions: Arc<RwLock<Vec<Version>>>,
    logger: Logger,
    // Custom headers added to every request.
    http_headers: HeaderMap,
}
247
248impl AggregatorHTTPClient {
249 pub fn new(
251 aggregator_endpoint: Url,
252 api_versions: Vec<Version>,
253 logger: Logger,
254 custom_headers: Option<HashMap<String, String>>,
255 ) -> MithrilResult<Self> {
256 let http_client = reqwest::ClientBuilder::new()
257 .build()
258 .with_context(|| "Building http client for Aggregator client failed")?;
259
260 let aggregator_endpoint = if aggregator_endpoint.as_str().ends_with('/') {
264 aggregator_endpoint
265 } else {
266 let mut url = aggregator_endpoint.clone();
267 url.set_path(&format!("{}/", aggregator_endpoint.path()));
268 url
269 };
270
271 let mut http_headers = HeaderMap::new();
272 if let Some(headers) = custom_headers {
273 for (key, value) in headers.iter() {
274 http_headers.insert(
275 HeaderName::from_bytes(key.as_bytes())?,
276 HeaderValue::from_str(value)?,
277 );
278 }
279 }
280
281 Ok(Self {
282 http_client,
283 aggregator_endpoint,
284 api_versions: Arc::new(RwLock::new(api_versions)),
285 logger: logger.new_with_component_name::<Self>(),
286 http_headers,
287 })
288 }
289
290 async fn compute_current_api_version(&self) -> Option<Version> {
292 self.api_versions.read().await.first().cloned()
293 }
294
295 #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
297 #[cfg_attr(not(target_family = "wasm"), async_recursion)]
298 async fn get(&self, url: Url) -> Result<Response, AggregatorClientError> {
299 debug!(self.logger, "GET url='{url}'.");
300 let request_builder = self.http_client.get(url.clone());
301 let current_api_version = self
302 .compute_current_api_version()
303 .await
304 .unwrap()
305 .to_string();
306 debug!(
307 self.logger,
308 "Prepare request with version: {current_api_version}"
309 );
310 let request_builder = request_builder
311 .header(MITHRIL_API_VERSION_HEADER, current_api_version)
312 .headers(self.http_headers.clone());
313
314 let response = request_builder.send().await.map_err(|e| {
315 AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
316 "Cannot perform a GET against the Aggregator HTTP server (url='{url}')"
317 )))
318 })?;
319
320 match response.status() {
321 StatusCode::OK => {
322 self.warn_if_api_version_mismatch(&response).await;
323
324 Ok(response)
325 }
326 StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
327 status_code if status_code.is_client_error() => {
328 Err(Self::remote_logical_error(response).await)
329 }
330 _ => Err(Self::remote_technical_error(response).await),
331 }
332 }
333
334 #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
335 #[cfg_attr(not(target_family = "wasm"), async_recursion)]
336 async fn post(&self, url: Url, json: &str) -> Result<Response, AggregatorClientError> {
337 debug!(self.logger, "POST url='{url}'"; "json" => json);
338 let request_builder = self.http_client.post(url.to_owned()).body(json.to_owned());
339 let current_api_version = self
340 .compute_current_api_version()
341 .await
342 .unwrap()
343 .to_string();
344 debug!(
345 self.logger,
346 "Prepare request with version: {current_api_version}"
347 );
348 let request_builder = request_builder
349 .header(MITHRIL_API_VERSION_HEADER, current_api_version)
350 .headers(self.http_headers.clone());
351
352 let response = request_builder.send().await.map_err(|e| {
353 AggregatorClientError::SubsystemError(
354 anyhow!(e).context("Error while POSTing data '{json}' to URL='{url}'."),
355 )
356 })?;
357
358 match response.status() {
359 StatusCode::OK | StatusCode::CREATED => {
360 self.warn_if_api_version_mismatch(&response).await;
361
362 Ok(response)
363 }
364 StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
365 status_code if status_code.is_client_error() => {
366 Err(Self::remote_logical_error(response).await)
367 }
368 _ => Err(Self::remote_technical_error(response).await),
369 }
370 }
371
372 fn get_url_for_route(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
373 self.aggregator_endpoint
374 .join(endpoint)
375 .with_context(|| {
376 format!(
377 "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
378 self.aggregator_endpoint
379 )
380 })
381 .map_err(AggregatorClientError::SubsystemError)
382 }
383
384 fn not_found_error(url: Url) -> AggregatorClientError {
385 AggregatorClientError::RemoteServerLogical(anyhow!("Url='{url}' not found"))
386 }
387
388 async fn remote_logical_error(response: Response) -> AggregatorClientError {
389 let status_code = response.status();
390 let client_error = response
391 .json::<ClientError>()
392 .await
393 .unwrap_or(ClientError::new(
394 format!("Unhandled error {status_code}"),
395 "",
396 ));
397
398 AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"))
399 }
400
401 async fn remote_technical_error(response: Response) -> AggregatorClientError {
402 let status_code = response.status();
403 let server_error = response
404 .json::<ServerError>()
405 .await
406 .unwrap_or(ServerError::new(format!("Unhandled error {status_code}")));
407
408 AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"))
409 }
410
411 async fn warn_if_api_version_mismatch(&self, response: &Response) {
413 let aggregator_version = response
414 .headers()
415 .get(MITHRIL_API_VERSION_HEADER)
416 .and_then(|v| v.to_str().ok())
417 .and_then(|s| Version::parse(s).ok());
418
419 let client_version = self.compute_current_api_version().await;
420
421 match (aggregator_version, client_version) {
422 (Some(aggregator), Some(client)) if client < aggregator => {
423 warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
424 "aggregator_version" => %aggregator,
425 "client_version" => %client,
426 );
427 }
428 (Some(_), None) => {
429 error!(
430 self.logger,
431 "Failed to compute the current client API version"
432 );
433 }
434 _ => {}
435 }
436 }
437}
438
439#[cfg_attr(target_family = "wasm", async_trait(?Send))]
440#[cfg_attr(not(target_family = "wasm"), async_trait)]
441impl AggregatorClient for AggregatorHTTPClient {
442 async fn get_content(
443 &self,
444 request: AggregatorRequest,
445 ) -> Result<String, AggregatorClientError> {
446 let response = self.get(self.get_url_for_route(&request.route())?).await?;
447 let content = format!("{response:?}");
448
449 response.text().await.map_err(|e| {
450 AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
451 "Could not find a JSON body in the response '{content}'."
452 )))
453 })
454 }
455
456 async fn post_content(
457 &self,
458 request: AggregatorRequest,
459 ) -> Result<String, AggregatorClientError> {
460 let response = self
461 .post(
462 self.get_url_for_route(&request.route())?,
463 &request.get_body().unwrap_or_default(),
464 )
465 .await?;
466
467 response.text().await.map_err(|e| {
468 AggregatorClientError::SubsystemError(
469 anyhow!(e).context("Could not find a text body in the response."),
470 )
471 })
472 }
473}
474
#[cfg(test)]
mod tests {
    use httpmock::MockServer;
    use std::collections::HashMap;
    use strum::IntoEnumIterator;

    use mithril_common::api_version::APIVersionProvider;
    use mithril_common::entities::{ClientError, ServerError};

    use crate::test_utils::TestLogger;

    use super::*;

    // Errors do not implement `PartialEq`: compare their debug representations instead.
    macro_rules! assert_error_eq {
        ($left:expr, $right:expr) => {
            assert_eq!(format!("{:?}", &$left), format!("{:?}", &$right),);
        };
    }

    /// Builds a client targeting `server_url` with the given API versions and headers.
    fn setup_client(
        server_url: &str,
        api_versions: Vec<Version>,
        custom_headers: Option<HashMap<String, String>>,
    ) -> AggregatorHTTPClient {
        AggregatorHTTPClient::new(
            Url::parse(server_url).unwrap(),
            api_versions,
            TestLogger::stdout(),
            custom_headers,
        )
        .expect("building aggregator http client should not fail")
    }

    /// Starts a mock server and builds a client, without custom headers, targeting it.
    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
        let server = MockServer::start();
        let client = setup_client(
            &server.url(""),
            APIVersionProvider::compute_all_versions_sorted(),
            None,
        );
        (server, client)
    }

    /// Starts a mock server and builds a client, with the given custom headers, targeting it.
    fn setup_server_and_client_with_custom_headers(
        custom_headers: HashMap<String, String>,
    ) -> (MockServer, AggregatorHTTPClient) {
        let server = MockServer::start();
        let client = setup_client(
            &server.url(""),
            APIVersionProvider::compute_all_versions_sorted(),
            Some(custom_headers),
        );
        (server, client)
    }

    #[test]
    fn always_append_trailing_slash_at_build() {
        for (expected, url) in [
            ("http://www.test.net/", "http://www.test.net/"),
            ("http://www.test.net/", "http://www.test.net"),
            (
                "http://www.test.net/aggregator/",
                "http://www.test.net/aggregator/",
            ),
            (
                "http://www.test.net/aggregator/",
                "http://www.test.net/aggregator",
            ),
        ] {
            let url = Url::parse(url).unwrap();
            let client = AggregatorHTTPClient::new(url, vec![], TestLogger::stdout(), None)
                .expect("building aggregator http client should not fail");

            assert_eq!(expected, client.aggregator_endpoint.as_str());
        }
    }

    #[test]
    fn deduce_routes_from_request() {
        assert_eq!(
            "certificate/abc".to_string(),
            AggregatorRequest::GetCertificate {
                hash: "abc".to_string()
            }
            .route()
        );

        // This route was previously left unchecked while the next one was asserted twice.
        assert_eq!(
            "certificates".to_string(),
            AggregatorRequest::ListCertificates.route()
        );

        assert_eq!(
            "artifact/mithril-stake-distribution/abc".to_string(),
            AggregatorRequest::GetMithrilStakeDistribution {
                hash: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/mithril-stake-distributions".to_string(),
            AggregatorRequest::ListMithrilStakeDistributions.route()
        );

        assert_eq!(
            "artifact/snapshot/abc".to_string(),
            AggregatorRequest::GetSnapshot {
                digest: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/snapshots".to_string(),
            AggregatorRequest::ListSnapshots.route()
        );

        assert_eq!(
            "statistics/snapshot".to_string(),
            AggregatorRequest::IncrementSnapshotStatistic {
                snapshot: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/cardano-database/abc".to_string(),
            AggregatorRequest::GetCardanoDatabaseSnapshot {
                hash: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/cardano-database".to_string(),
            AggregatorRequest::ListCardanoDatabaseSnapshots.route()
        );

        assert_eq!(
            "statistics/cardano-database/immutable-files-restored".to_string(),
            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
                number_of_immutables: 58
            }
            .route()
        );

        assert_eq!(
            "statistics/cardano-database/ancillary-files-restored".to_string(),
            AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic.route()
        );

        assert_eq!(
            "statistics/cardano-database/complete-restoration".to_string(),
            AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic.route()
        );

        assert_eq!(
            "statistics/cardano-database/partial-restoration".to_string(),
            AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic.route()
        );

        assert_eq!(
            "proof/cardano-transaction?transaction_hashes=abc,def,ghi,jkl".to_string(),
            AggregatorRequest::GetTransactionsProofs {
                transactions_hashes: vec![
                    "abc".to_string(),
                    "def".to_string(),
                    "ghi".to_string(),
                    "jkl".to_string()
                ]
            }
            .route()
        );

        assert_eq!(
            "artifact/cardano-transaction/abc".to_string(),
            AggregatorRequest::GetCardanoTransactionSnapshot {
                hash: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/cardano-transactions".to_string(),
            AggregatorRequest::ListCardanoTransactionSnapshots.route()
        );

        assert_eq!(
            "artifact/cardano-stake-distribution/abc".to_string(),
            AggregatorRequest::GetCardanoStakeDistribution {
                hash: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/cardano-stake-distribution/epoch/123".to_string(),
            AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch: Epoch(123) }.route()
        );

        assert_eq!(
            "artifact/cardano-stake-distributions".to_string(),
            AggregatorRequest::ListCardanoStakeDistributions.route()
        );
    }

    #[test]
    fn deduce_body_from_request() {
        fn that_should_not_have_body(req: &AggregatorRequest) -> bool {
            !matches!(
                req,
                AggregatorRequest::IncrementSnapshotStatistic { .. }
                    | AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic { .. }
            )
        }

        assert_eq!(
            Some(r#"{"key":"value"}"#.to_string()),
            AggregatorRequest::IncrementSnapshotStatistic {
                snapshot: r#"{"key":"value"}"#.to_string()
            }
            .get_body()
        );

        assert_eq!(
            Some(
                serde_json::to_string(&CardanoDatabaseImmutableFilesRestoredMessage {
                    nb_immutable_files: 432,
                })
                .unwrap()
            ),
            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
                number_of_immutables: 432,
            }
            .get_body()
        );

        for req_that_should_not_have_body in
            AggregatorRequest::iter().filter(that_should_not_have_body)
        {
            assert_eq!(None, req_that_should_not_have_body.get_body());
        }
    }

    #[tokio::test]
    async fn test_client_handle_4xx_errors() {
        let client_error = ClientError::new("label", "message");

        let (aggregator, client) = setup_server_and_client();
        aggregator.mock(|_when, then| {
            then.status(StatusCode::IM_A_TEAPOT.as_u16())
                .json_body_obj(&client_error);
        });

        let expected_error = AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"));

        let get_content_error = client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(get_content_error, expected_error);

        let post_content_error = client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(post_content_error, expected_error);
    }

    #[tokio::test]
    async fn test_client_handle_404_not_found_error() {
        let client_error = ClientError::new("label", "message");

        let (aggregator, client) = setup_server_and_client();
        aggregator.mock(|_when, then| {
            then.status(StatusCode::NOT_FOUND.as_u16())
                .json_body_obj(&client_error);
        });

        let expected_error = AggregatorHTTPClient::not_found_error(
            Url::parse(&format!(
                "{}/{}",
                aggregator.base_url(),
                AggregatorRequest::ListCertificates.route()
            ))
            .unwrap(),
        );

        let get_content_error = client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(get_content_error, expected_error);

        let post_content_error = client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(post_content_error, expected_error);
    }

    #[tokio::test]
    async fn test_client_handle_5xx_errors() {
        let server_error = ServerError::new("message");

        let (aggregator, client) = setup_server_and_client();
        aggregator.mock(|_when, then| {
            then.status(StatusCode::INTERNAL_SERVER_ERROR.as_u16())
                .json_body_obj(&server_error);
        });

        let expected_error =
            AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"));

        let get_content_error = client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(get_content_error, expected_error);

        let post_content_error = client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(post_content_error, expected_error);
    }

    #[tokio::test]
    async fn test_client_with_custom_headers() {
        let mut http_headers = HashMap::new();
        http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
        http_headers.insert("Another-Header".to_string(), "AnotherValue".to_string());
        let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
        aggregator.mock(|when, then| {
            when.header("Custom-Header", "CustomValue")
                .header("Another-Header", "AnotherValue");
            then.status(StatusCode::OK.as_u16()).body("ok");
        });

        client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .expect("GET request should succeed");

        client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .expect("POST request should succeed");
    }

    #[tokio::test]
    async fn test_client_sends_accept_encoding_header_with_correct_values() {
        let (aggregator, client) = setup_server_and_client();
        aggregator.mock(|when, then| {
            when.matches(|req| {
                let headers = req.headers.clone().expect("HTTP headers not found");
                let accept_encoding_header = headers
                    .iter()
                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
                    .expect("Accept-Encoding header not found");

                let header_value = accept_encoding_header.clone().1;
                ["gzip", "br", "deflate", "zstd"]
                    .iter()
                    .all(|&value| header_value.contains(value))
            });

            then.status(200).body("ok");
        });

        client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .expect("GET request should succeed with Accept-Encoding header");
    }

    #[tokio::test]
    async fn test_client_with_custom_headers_sends_accept_encoding_header_with_correct_values() {
        let mut http_headers = HashMap::new();
        http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
        let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
        aggregator.mock(|when, then| {
            when.matches(|req| {
                let headers = req.headers.clone().expect("HTTP headers not found");
                let accept_encoding_header = headers
                    .iter()
                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
                    .expect("Accept-Encoding header not found");

                let header_value = accept_encoding_header.clone().1;
                ["gzip", "br", "deflate", "zstd"]
                    .iter()
                    .all(|&value| header_value.contains(value))
            });

            then.status(200).body("ok");
        });

        client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .expect("GET request should succeed with Accept-Encoding header");
    }

    mod warn_if_api_version_mismatch {
        use http::response::Builder as HttpResponseBuilder;

        use mithril_common::test_utils::MemoryDrainForTestInspector;

        use super::*;

        /// Builds a response holding only the given header, without any server involved.
        fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
            key: K,
            value: V,
        ) -> Response {
            HttpResponseBuilder::new()
                .header(key.into(), value.into())
                .body("whatever")
                .unwrap()
                .into()
        }

        /// Asserts that the version mismatch warning was logged with both versions.
        fn assert_api_version_warning_logged<A: Into<String>, S: Into<String>>(
            log_inspector: &MemoryDrainForTestInspector,
            aggregator_version: A,
            client_version: S,
        ) {
            assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
            assert!(log_inspector
                .contains_log(&format!("aggregator_version={}", aggregator_version.into())));
            assert!(
                log_inspector.contains_log(&format!("client_version={}", client_version.into()))
            );
        }

        #[tokio::test]
        async fn test_logs_warning_when_aggregator_api_version_is_newer() {
            let aggregator_version = "2.0.0";
            let client_version = "1.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client(
                "http://whatever",
                vec![Version::parse(client_version).unwrap()],
                None,
            );
            client.logger = logger;
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);

            assert!(
                Version::parse(aggregator_version).unwrap()
                    > Version::parse(client_version).unwrap()
            );

            client.warn_if_api_version_mismatch(&response).await;

            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
        }

        #[tokio::test]
        async fn test_no_warning_logged_when_versions_match() {
            let version = "1.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client(
                "http://whatever",
                vec![Version::parse(version).unwrap()],
                None,
            );
            client.logger = logger;
            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);

            client.warn_if_api_version_mismatch(&response).await;

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[tokio::test]
        async fn test_no_warning_logged_when_aggregator_api_version_is_older() {
            let aggregator_version = "1.0.0";
            let client_version = "2.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client(
                "http://whatever",
                vec![Version::parse(client_version).unwrap()],
                None,
            );
            client.logger = logger;
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);

            assert!(
                Version::parse(aggregator_version).unwrap()
                    < Version::parse(client_version).unwrap()
            );

            client.warn_if_api_version_mismatch(&response).await;

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[tokio::test]
        async fn test_does_not_log_or_fail_when_header_is_missing() {
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client(
                "http://whatever",
                APIVersionProvider::compute_all_versions_sorted(),
                None,
            );
            client.logger = logger;
            let response =
                build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");

            client.warn_if_api_version_mismatch(&response).await;

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[tokio::test]
        async fn test_does_not_log_or_fail_when_header_is_not_a_version() {
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client(
                "http://whatever",
                APIVersionProvider::compute_all_versions_sorted(),
                None,
            );
            client.logger = logger;
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");

            client.warn_if_api_version_mismatch(&response).await;

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[tokio::test]
        async fn test_logs_error_when_client_version_cannot_be_computed() {
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client("http://whatever", vec![], None);
            client.logger = logger;
            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");

            client.warn_if_api_version_mismatch(&response).await;

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
            // The test name promises an error log: check it instead of only checking the
            // absence of the mismatch warning.
            assert!(
                log_inspector.contains_log("Failed to compute the current client API version")
            );
        }

        #[tokio::test]
        async fn test_client_get_log_warning_if_api_version_mismatch() {
            let aggregator_version = "2.0.0";
            let client_version = "1.0.0";
            let (server, mut client) = setup_server_and_client();
            let (logger, log_inspector) = TestLogger::memory();
            client.api_versions =
                Arc::new(RwLock::new(vec![Version::parse(client_version).unwrap()]));
            client.logger = logger;
            server.mock(|_, then| {
                then.status(StatusCode::OK.as_u16())
                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
            });

            assert!(
                Version::parse(aggregator_version).unwrap()
                    > Version::parse(client_version).unwrap()
            );

            client
                .get(Url::parse(&server.base_url()).unwrap())
                .await
                .unwrap();

            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
        }

        #[tokio::test]
        async fn test_client_post_log_warning_if_api_version_mismatch() {
            let aggregator_version = "2.0.0";
            let client_version = "1.0.0";
            let (server, mut client) = setup_server_and_client();
            let (logger, log_inspector) = TestLogger::memory();
            client.api_versions =
                Arc::new(RwLock::new(vec![Version::parse(client_version).unwrap()]));
            client.logger = logger;
            server.mock(|_, then| {
                then.status(StatusCode::OK.as_u16())
                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
            });

            assert!(
                Version::parse(aggregator_version).unwrap()
                    > Version::parse(client_version).unwrap()
            );

            client
                .post(Url::parse(&server.base_url()).unwrap(), "whatever")
                .await
                .unwrap();

            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
        }
    }
}