1use anyhow::{Context, anyhow};
11use async_recursion::async_recursion;
12use async_trait::async_trait;
13use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
14use reqwest::{Response, StatusCode, Url};
15use semver::Version;
16use slog::{Logger, debug, error, warn};
17use std::collections::HashMap;
18use std::sync::Arc;
19use thiserror::Error;
20use tokio::sync::RwLock;
21
22use mithril_common::MITHRIL_API_VERSION_HEADER;
23use mithril_common::entities::{ClientError, ServerError};
24use mithril_common::logging::LoggerExtensions;
25use mithril_common::messages::CardanoDatabaseImmutableFilesRestoredMessage;
26
27use crate::common::Epoch;
28use crate::{MithrilError, MithrilResult};
29
/// Warning logged by `AggregatorHTTPClient::warn_if_api_version_mismatch` when the aggregator
/// advertises a newer API version than the one currently used by the client.
const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = concat!(
    "OpenAPI version may be incompatible, ",
    "please update Mithril client library to the latest version."
);
31
32#[derive(Error, Debug)]
34pub enum AggregatorClientError {
35 #[error("Internal error of the Aggregator")]
37 RemoteServerTechnical(#[source] MithrilError),
38
39 #[error("Invalid request to the Aggregator")]
41 RemoteServerLogical(#[source] MithrilError),
42
43 #[error("HTTP subsystem error")]
45 SubsystemError(#[source] MithrilError),
46}
47
48#[derive(Debug, Clone, Eq, PartialEq)]
50#[cfg_attr(test, derive(strum::EnumIter))]
51pub enum AggregatorRequest {
52 GetCertificate {
54 hash: String,
56 },
57
58 ListCertificates,
60
61 GetMithrilStakeDistribution {
63 hash: String,
65 },
66
67 ListMithrilStakeDistributions,
69
70 GetSnapshot {
72 digest: String,
74 },
75
76 ListSnapshots,
78
79 IncrementSnapshotStatistic {
81 snapshot: String,
83 },
84
85 GetCardanoDatabaseSnapshot {
87 hash: String,
89 },
90
91 ListCardanoDatabaseSnapshots,
93
94 IncrementCardanoDatabaseImmutablesRestoredStatistic {
96 number_of_immutables: u64,
98 },
99
100 IncrementCardanoDatabaseAncillaryStatistic,
102
103 IncrementCardanoDatabaseCompleteRestorationStatistic,
105
106 IncrementCardanoDatabasePartialRestorationStatistic,
108
109 GetTransactionsProofs {
111 transactions_hashes: Vec<String>,
113 },
114
115 GetCardanoTransactionSnapshot {
117 hash: String,
119 },
120
121 ListCardanoTransactionSnapshots,
123
124 GetCardanoStakeDistribution {
126 hash: String,
128 },
129
130 GetCardanoStakeDistributionByEpoch {
132 epoch: Epoch,
134 },
135
136 ListCardanoStakeDistributions,
138
139 Status,
141}
142
143impl AggregatorRequest {
144 pub fn route(&self) -> String {
146 match self {
147 AggregatorRequest::GetCertificate { hash } => {
148 format!("certificate/{hash}")
149 }
150 AggregatorRequest::ListCertificates => "certificates".to_string(),
151 AggregatorRequest::GetMithrilStakeDistribution { hash } => {
152 format!("artifact/mithril-stake-distribution/{hash}")
153 }
154 AggregatorRequest::ListMithrilStakeDistributions => {
155 "artifact/mithril-stake-distributions".to_string()
156 }
157 AggregatorRequest::GetSnapshot { digest } => {
158 format!("artifact/snapshot/{digest}")
159 }
160 AggregatorRequest::ListSnapshots => "artifact/snapshots".to_string(),
161 AggregatorRequest::IncrementSnapshotStatistic { snapshot: _ } => {
162 "statistics/snapshot".to_string()
163 }
164 AggregatorRequest::GetCardanoDatabaseSnapshot { hash } => {
165 format!("artifact/cardano-database/{hash}")
166 }
167 AggregatorRequest::ListCardanoDatabaseSnapshots => {
168 "artifact/cardano-database".to_string()
169 }
170 AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
171 number_of_immutables: _,
172 } => "statistics/cardano-database/immutable-files-restored".to_string(),
173 AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic => {
174 "statistics/cardano-database/ancillary-files-restored".to_string()
175 }
176 AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic => {
177 "statistics/cardano-database/complete-restoration".to_string()
178 }
179 AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic => {
180 "statistics/cardano-database/partial-restoration".to_string()
181 }
182 AggregatorRequest::GetTransactionsProofs {
183 transactions_hashes,
184 } => format!(
185 "proof/cardano-transaction?transaction_hashes={}",
186 transactions_hashes.join(",")
187 ),
188 AggregatorRequest::GetCardanoTransactionSnapshot { hash } => {
189 format!("artifact/cardano-transaction/{hash}")
190 }
191 AggregatorRequest::ListCardanoTransactionSnapshots => {
192 "artifact/cardano-transactions".to_string()
193 }
194 AggregatorRequest::GetCardanoStakeDistribution { hash } => {
195 format!("artifact/cardano-stake-distribution/{hash}")
196 }
197 AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch } => {
198 format!("artifact/cardano-stake-distribution/epoch/{epoch}")
199 }
200 AggregatorRequest::ListCardanoStakeDistributions => {
201 "artifact/cardano-stake-distributions".to_string()
202 }
203 AggregatorRequest::Status => "status".to_string(),
204 }
205 }
206
207 pub fn get_body(&self) -> Option<String> {
209 match self {
210 AggregatorRequest::IncrementSnapshotStatistic { snapshot } => {
211 Some(snapshot.to_string())
212 }
213 AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
214 number_of_immutables,
215 } => serde_json::to_string(&CardanoDatabaseImmutableFilesRestoredMessage {
216 nb_immutable_files: *number_of_immutables as u32,
217 })
218 .ok(),
219 _ => None,
220 }
221 }
222}
223
224#[cfg_attr(test, mockall::automock)]
226#[cfg_attr(target_family = "wasm", async_trait(?Send))]
227#[cfg_attr(not(target_family = "wasm"), async_trait)]
228pub trait AggregatorClient: Sync + Send {
229 async fn get_content(
231 &self,
232 request: AggregatorRequest,
233 ) -> Result<String, AggregatorClientError>;
234
235 async fn post_content(
237 &self,
238 request: AggregatorRequest,
239 ) -> Result<String, AggregatorClientError>;
240}
241
242pub struct AggregatorHTTPClient {
244 http_client: reqwest::Client,
245 aggregator_endpoint: Url,
246 api_versions: Arc<RwLock<Vec<Version>>>,
247 logger: Logger,
248 http_headers: HeaderMap,
249}
250
251impl AggregatorHTTPClient {
252 pub fn new(
254 aggregator_endpoint: Url,
255 api_versions: Vec<Version>,
256 logger: Logger,
257 custom_headers: Option<HashMap<String, String>>,
258 ) -> MithrilResult<Self> {
259 let http_client = reqwest::ClientBuilder::new()
260 .build()
261 .with_context(|| "Building http client for Aggregator client failed")?;
262
263 let aggregator_endpoint = if aggregator_endpoint.as_str().ends_with('/') {
267 aggregator_endpoint
268 } else {
269 let mut url = aggregator_endpoint.clone();
270 url.set_path(&format!("{}/", aggregator_endpoint.path()));
271 url
272 };
273
274 let mut http_headers = HeaderMap::new();
275 if let Some(headers) = custom_headers {
276 for (key, value) in headers.iter() {
277 http_headers.insert(
278 HeaderName::from_bytes(key.as_bytes())?,
279 HeaderValue::from_str(value)?,
280 );
281 }
282 }
283
284 Ok(Self {
285 http_client,
286 aggregator_endpoint,
287 api_versions: Arc::new(RwLock::new(api_versions)),
288 logger: logger.new_with_component_name::<Self>(),
289 http_headers,
290 })
291 }
292
293 async fn compute_current_api_version(&self) -> Option<Version> {
295 self.api_versions.read().await.first().cloned()
296 }
297
298 #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
300 #[cfg_attr(not(target_family = "wasm"), async_recursion)]
301 async fn get(&self, url: Url) -> Result<Response, AggregatorClientError> {
302 debug!(self.logger, "GET url='{url}'.");
303 let request_builder = self.http_client.get(url.clone());
304 let current_api_version = self.compute_current_api_version().await.unwrap().to_string();
305 debug!(
306 self.logger,
307 "Prepare request with version: {current_api_version}"
308 );
309 let request_builder = request_builder
310 .header(MITHRIL_API_VERSION_HEADER, current_api_version)
311 .headers(self.http_headers.clone());
312
313 let response = request_builder.send().await.map_err(|e| {
314 AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
315 "Cannot perform a GET against the Aggregator HTTP server (url='{url}')"
316 )))
317 })?;
318
319 match response.status() {
320 StatusCode::OK => {
321 self.warn_if_api_version_mismatch(&response).await;
322
323 Ok(response)
324 }
325 StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
326 status_code if status_code.is_client_error() => {
327 Err(Self::remote_logical_error(response).await)
328 }
329 _ => Err(Self::remote_technical_error(response).await),
330 }
331 }
332
333 #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
334 #[cfg_attr(not(target_family = "wasm"), async_recursion)]
335 async fn post(&self, url: Url, json: &str) -> Result<Response, AggregatorClientError> {
336 debug!(self.logger, "POST url='{url}'"; "json" => json);
337 let request_builder = self.http_client.post(url.to_owned()).body(json.to_owned());
338 let current_api_version = self.compute_current_api_version().await.unwrap().to_string();
339 debug!(
340 self.logger,
341 "Prepare request with version: {current_api_version}"
342 );
343 let request_builder = request_builder
344 .header(MITHRIL_API_VERSION_HEADER, current_api_version)
345 .headers(self.http_headers.clone());
346
347 let response = request_builder.send().await.map_err(|e| {
348 AggregatorClientError::SubsystemError(
349 anyhow!(e).context("Error while POSTing data '{json}' to URL='{url}'."),
350 )
351 })?;
352
353 match response.status() {
354 StatusCode::OK | StatusCode::CREATED => {
355 self.warn_if_api_version_mismatch(&response).await;
356
357 Ok(response)
358 }
359 StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
360 status_code if status_code.is_client_error() => {
361 Err(Self::remote_logical_error(response).await)
362 }
363 _ => Err(Self::remote_technical_error(response).await),
364 }
365 }
366
367 fn get_url_for_route(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
368 self.aggregator_endpoint
369 .join(endpoint)
370 .with_context(|| {
371 format!(
372 "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
373 self.aggregator_endpoint
374 )
375 })
376 .map_err(AggregatorClientError::SubsystemError)
377 }
378
379 fn not_found_error(url: Url) -> AggregatorClientError {
380 AggregatorClientError::RemoteServerLogical(anyhow!("Url='{url}' not found"))
381 }
382
383 async fn remote_logical_error(response: Response) -> AggregatorClientError {
384 let status_code = response.status();
385 let client_error = response.json::<ClientError>().await.unwrap_or(ClientError::new(
386 format!("Unhandled error {status_code}"),
387 "",
388 ));
389
390 AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"))
391 }
392
393 async fn remote_technical_error(response: Response) -> AggregatorClientError {
394 let status_code = response.status();
395 let server_error = response
396 .json::<ServerError>()
397 .await
398 .unwrap_or(ServerError::new(format!("Unhandled error {status_code}")));
399
400 AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"))
401 }
402
403 async fn warn_if_api_version_mismatch(&self, response: &Response) {
405 let aggregator_version = response
406 .headers()
407 .get(MITHRIL_API_VERSION_HEADER)
408 .and_then(|v| v.to_str().ok())
409 .and_then(|s| Version::parse(s).ok());
410
411 let client_version = self.compute_current_api_version().await;
412
413 match (aggregator_version, client_version) {
414 (Some(aggregator), Some(client)) if client < aggregator => {
415 warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
416 "aggregator_version" => %aggregator,
417 "client_version" => %client,
418 );
419 }
420 (Some(_), None) => {
421 error!(
422 self.logger,
423 "Failed to compute the current client API version"
424 );
425 }
426 _ => {}
427 }
428 }
429}
430
431#[cfg_attr(target_family = "wasm", async_trait(?Send))]
432#[cfg_attr(not(target_family = "wasm"), async_trait)]
433impl AggregatorClient for AggregatorHTTPClient {
434 async fn get_content(
435 &self,
436 request: AggregatorRequest,
437 ) -> Result<String, AggregatorClientError> {
438 let response = self.get(self.get_url_for_route(&request.route())?).await?;
439 let content = format!("{response:?}");
440
441 response.text().await.map_err(|e| {
442 AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
443 "Could not find a JSON body in the response '{content}'."
444 )))
445 })
446 }
447
448 async fn post_content(
449 &self,
450 request: AggregatorRequest,
451 ) -> Result<String, AggregatorClientError> {
452 let response = self
453 .post(
454 self.get_url_for_route(&request.route())?,
455 &request.get_body().unwrap_or_default(),
456 )
457 .await?;
458
459 response.text().await.map_err(|e| {
460 AggregatorClientError::SubsystemError(
461 anyhow!(e).context("Could not find a text body in the response."),
462 )
463 })
464 }
465}
466
#[cfg(test)]
mod tests {
    use httpmock::MockServer;
    use std::collections::HashMap;
    use strum::IntoEnumIterator;

    use mithril_common::api_version::APIVersionProvider;
    use mithril_common::entities::{ClientError, ServerError};

    use crate::test_utils::TestLogger;

    use super::*;

    /// Compares two errors through their `Debug` representation, since
    /// `AggregatorClientError` does not implement `PartialEq`.
    macro_rules! assert_error_eq {
        ($left:expr, $right:expr) => {
            assert_eq!(format!("{:?}", &$left), format!("{:?}", &$right));
        };
    }

    fn setup_client(
        server_url: &str,
        api_versions: Vec<Version>,
        custom_headers: Option<HashMap<String, String>>,
    ) -> AggregatorHTTPClient {
        AggregatorHTTPClient::new(
            Url::parse(server_url).unwrap(),
            api_versions,
            TestLogger::stdout(),
            custom_headers,
        )
        .expect("building aggregator http client should not fail")
    }

    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
        let server = MockServer::start();
        let client = setup_client(
            &server.url(""),
            APIVersionProvider::compute_all_versions_sorted(),
            None,
        );
        (server, client)
    }

    fn setup_server_and_client_with_custom_headers(
        custom_headers: HashMap<String, String>,
    ) -> (MockServer, AggregatorHTTPClient) {
        let server = MockServer::start();
        let client = setup_client(
            &server.url(""),
            APIVersionProvider::compute_all_versions_sorted(),
            Some(custom_headers),
        );
        (server, client)
    }

    #[test]
    fn always_append_trailing_slash_at_build() {
        for (expected, url) in [
            ("http://www.test.net/", "http://www.test.net/"),
            ("http://www.test.net/", "http://www.test.net"),
            (
                "http://www.test.net/aggregator/",
                "http://www.test.net/aggregator/",
            ),
            (
                "http://www.test.net/aggregator/",
                "http://www.test.net/aggregator",
            ),
        ] {
            let url = Url::parse(url).unwrap();
            let client = AggregatorHTTPClient::new(url, vec![], TestLogger::stdout(), None)
                .expect("building aggregator http client should not fail");

            assert_eq!(expected, client.aggregator_endpoint.as_str());
        }
    }

    #[test]
    fn deduce_routes_from_request() {
        assert_eq!(
            "certificate/abc".to_string(),
            AggregatorRequest::GetCertificate {
                hash: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "certificates".to_string(),
            AggregatorRequest::ListCertificates.route()
        );

        assert_eq!(
            "artifact/mithril-stake-distribution/abc".to_string(),
            AggregatorRequest::GetMithrilStakeDistribution {
                hash: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/mithril-stake-distributions".to_string(),
            AggregatorRequest::ListMithrilStakeDistributions.route()
        );

        assert_eq!(
            "artifact/snapshot/abc".to_string(),
            AggregatorRequest::GetSnapshot {
                digest: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/snapshots".to_string(),
            AggregatorRequest::ListSnapshots.route()
        );

        assert_eq!(
            "statistics/snapshot".to_string(),
            AggregatorRequest::IncrementSnapshotStatistic {
                snapshot: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/cardano-database/abc".to_string(),
            AggregatorRequest::GetCardanoDatabaseSnapshot {
                hash: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/cardano-database".to_string(),
            AggregatorRequest::ListCardanoDatabaseSnapshots.route()
        );

        assert_eq!(
            "statistics/cardano-database/immutable-files-restored".to_string(),
            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
                number_of_immutables: 58
            }
            .route()
        );

        assert_eq!(
            "statistics/cardano-database/ancillary-files-restored".to_string(),
            AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic.route()
        );

        assert_eq!(
            "statistics/cardano-database/complete-restoration".to_string(),
            AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic.route()
        );

        assert_eq!(
            "statistics/cardano-database/partial-restoration".to_string(),
            AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic.route()
        );

        assert_eq!(
            "proof/cardano-transaction?transaction_hashes=abc,def,ghi,jkl".to_string(),
            AggregatorRequest::GetTransactionsProofs {
                transactions_hashes: vec![
                    "abc".to_string(),
                    "def".to_string(),
                    "ghi".to_string(),
                    "jkl".to_string()
                ]
            }
            .route()
        );

        assert_eq!(
            "artifact/cardano-transaction/abc".to_string(),
            AggregatorRequest::GetCardanoTransactionSnapshot {
                hash: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/cardano-transactions".to_string(),
            AggregatorRequest::ListCardanoTransactionSnapshots.route()
        );

        assert_eq!(
            "artifact/cardano-stake-distribution/abc".to_string(),
            AggregatorRequest::GetCardanoStakeDistribution {
                hash: "abc".to_string()
            }
            .route()
        );

        assert_eq!(
            "artifact/cardano-stake-distribution/epoch/123".to_string(),
            AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch: Epoch(123) }.route()
        );

        assert_eq!(
            "artifact/cardano-stake-distributions".to_string(),
            AggregatorRequest::ListCardanoStakeDistributions.route()
        );

        assert_eq!("status".to_string(), AggregatorRequest::Status.route());
    }

    #[test]
    fn deduce_body_from_request() {
        fn that_should_not_have_body(req: &AggregatorRequest) -> bool {
            !matches!(
                req,
                AggregatorRequest::IncrementSnapshotStatistic { .. }
                    | AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic { .. }
            )
        }

        assert_eq!(
            Some(r#"{"key":"value"}"#.to_string()),
            AggregatorRequest::IncrementSnapshotStatistic {
                snapshot: r#"{"key":"value"}"#.to_string()
            }
            .get_body()
        );

        assert_eq!(
            Some(
                serde_json::to_string(&CardanoDatabaseImmutableFilesRestoredMessage {
                    nb_immutable_files: 432,
                })
                .unwrap()
            ),
            AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
                number_of_immutables: 432,
            }
            .get_body()
        );

        for req_that_should_not_have_body in
            AggregatorRequest::iter().filter(that_should_not_have_body)
        {
            assert_eq!(None, req_that_should_not_have_body.get_body());
        }
    }

    #[tokio::test]
    async fn test_client_handle_4xx_errors() {
        let client_error = ClientError::new("label", "message");

        let (aggregator, client) = setup_server_and_client();
        aggregator.mock(|_when, then| {
            then.status(StatusCode::IM_A_TEAPOT.as_u16())
                .json_body_obj(&client_error);
        });

        let expected_error = AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"));

        let get_content_error = client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(get_content_error, expected_error);

        let post_content_error = client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(post_content_error, expected_error);
    }

    #[tokio::test]
    async fn test_client_handle_404_not_found_error() {
        let client_error = ClientError::new("label", "message");

        let (aggregator, client) = setup_server_and_client();
        aggregator.mock(|_when, then| {
            then.status(StatusCode::NOT_FOUND.as_u16())
                .json_body_obj(&client_error);
        });

        let expected_error = AggregatorHTTPClient::not_found_error(
            Url::parse(&format!(
                "{}/{}",
                aggregator.base_url(),
                AggregatorRequest::ListCertificates.route()
            ))
            .unwrap(),
        );

        let get_content_error = client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(get_content_error, expected_error);

        let post_content_error = client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(post_content_error, expected_error);
    }

    #[tokio::test]
    async fn test_client_handle_5xx_errors() {
        let server_error = ServerError::new("message");

        let (aggregator, client) = setup_server_and_client();
        aggregator.mock(|_when, then| {
            then.status(StatusCode::INTERNAL_SERVER_ERROR.as_u16())
                .json_body_obj(&server_error);
        });

        let expected_error =
            AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"));

        let get_content_error = client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(get_content_error, expected_error);

        let post_content_error = client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .unwrap_err();
        assert_error_eq!(post_content_error, expected_error);
    }

    #[tokio::test]
    async fn test_client_with_custom_headers() {
        let mut http_headers = HashMap::new();
        http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
        http_headers.insert("Another-Header".to_string(), "AnotherValue".to_string());
        let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
        aggregator.mock(|when, then| {
            when.header("Custom-Header", "CustomValue")
                .header("Another-Header", "AnotherValue");
            then.status(StatusCode::OK.as_u16()).body("ok");
        });

        client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .expect("GET request should succeed");

        client
            .post_content(AggregatorRequest::ListCertificates)
            .await
            .expect("POST request should succeed");
    }

    #[tokio::test]
    async fn test_client_sends_accept_encoding_header_with_correct_values() {
        let (aggregator, client) = setup_server_and_client();
        aggregator.mock(|when, then| {
            when.matches(|req| {
                let headers = req.headers.clone().expect("HTTP headers not found");
                let accept_encoding_header = headers
                    .iter()
                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
                    .expect("Accept-Encoding header not found");

                let header_value = accept_encoding_header.clone().1;
                ["gzip", "br", "deflate", "zstd"]
                    .iter()
                    .all(|&value| header_value.contains(value))
            });

            then.status(200).body("ok");
        });

        client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .expect("GET request should succeed with Accept-Encoding header");
    }

    #[tokio::test]
    async fn test_client_with_custom_headers_sends_accept_encoding_header_with_correct_values() {
        let mut http_headers = HashMap::new();
        http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
        let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
        aggregator.mock(|when, then| {
            when.matches(|req| {
                let headers = req.headers.clone().expect("HTTP headers not found");
                let accept_encoding_header = headers
                    .iter()
                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
                    .expect("Accept-Encoding header not found");

                let header_value = accept_encoding_header.clone().1;
                ["gzip", "br", "deflate", "zstd"]
                    .iter()
                    .all(|&value| header_value.contains(value))
            });

            then.status(200).body("ok");
        });

        client
            .get_content(AggregatorRequest::ListCertificates)
            .await
            .expect("GET request should succeed with Accept-Encoding header");
    }

    mod warn_if_api_version_mismatch {
        use http::response::Builder as HttpResponseBuilder;

        use mithril_common::test_utils::MemoryDrainForTestInspector;

        use super::*;

        fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
            key: K,
            value: V,
        ) -> Response {
            HttpResponseBuilder::new()
                .header(key.into(), value.into())
                .body("whatever")
                .unwrap()
                .into()
        }

        fn assert_api_version_warning_logged<A: Into<String>, S: Into<String>>(
            log_inspector: &MemoryDrainForTestInspector,
            aggregator_version: A,
            client_version: S,
        ) {
            assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
            assert!(
                log_inspector
                    .contains_log(&format!("aggregator_version={}", aggregator_version.into()))
            );
            assert!(
                log_inspector.contains_log(&format!("client_version={}", client_version.into()))
            );
        }

        #[tokio::test]
        async fn test_logs_warning_when_aggregator_api_version_is_newer() {
            let aggregator_version = "2.0.0";
            let client_version = "1.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client(
                "http://whatever",
                vec![Version::parse(client_version).unwrap()],
                None,
            );
            client.logger = logger;
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);

            assert!(
                Version::parse(aggregator_version).unwrap()
                    > Version::parse(client_version).unwrap()
            );

            client.warn_if_api_version_mismatch(&response).await;

            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
        }

        #[tokio::test]
        async fn test_no_warning_logged_when_versions_match() {
            let version = "1.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client(
                "http://whatever",
                vec![Version::parse(version).unwrap()],
                None,
            );
            client.logger = logger;
            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);

            client.warn_if_api_version_mismatch(&response).await;

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[tokio::test]
        async fn test_no_warning_logged_when_aggregator_api_version_is_older() {
            let aggregator_version = "1.0.0";
            let client_version = "2.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client(
                "http://whatever",
                vec![Version::parse(client_version).unwrap()],
                None,
            );
            client.logger = logger;
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);

            assert!(
                Version::parse(aggregator_version).unwrap()
                    < Version::parse(client_version).unwrap()
            );

            client.warn_if_api_version_mismatch(&response).await;

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[tokio::test]
        async fn test_does_not_log_or_fail_when_header_is_missing() {
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client(
                "http://whatever",
                APIVersionProvider::compute_all_versions_sorted(),
                None,
            );
            client.logger = logger;
            let response =
                build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");

            client.warn_if_api_version_mismatch(&response).await;

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[tokio::test]
        async fn test_does_not_log_or_fail_when_header_is_not_a_version() {
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client(
                "http://whatever",
                APIVersionProvider::compute_all_versions_sorted(),
                None,
            );
            client.logger = logger;
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");

            client.warn_if_api_version_mismatch(&response).await;

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[tokio::test]
        async fn test_logs_error_when_client_version_cannot_be_computed() {
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client("http://whatever", vec![], None);
            client.logger = logger;
            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");

            client.warn_if_api_version_mismatch(&response).await;

            assert!(log_inspector.contains_log("Failed to compute the current client API version"));
            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[tokio::test]
        async fn test_client_get_log_warning_if_api_version_mismatch() {
            let aggregator_version = "2.0.0";
            let client_version = "1.0.0";
            let (server, mut client) = setup_server_and_client();
            let (logger, log_inspector) = TestLogger::memory();
            client.api_versions =
                Arc::new(RwLock::new(vec![Version::parse(client_version).unwrap()]));
            client.logger = logger;
            server.mock(|_, then| {
                then.status(StatusCode::OK.as_u16())
                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
            });

            assert!(
                Version::parse(aggregator_version).unwrap()
                    > Version::parse(client_version).unwrap()
            );

            client.get(Url::parse(&server.base_url()).unwrap()).await.unwrap();

            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
        }

        #[tokio::test]
        async fn test_client_post_log_warning_if_api_version_mismatch() {
            let aggregator_version = "2.0.0";
            let client_version = "1.0.0";
            let (server, mut client) = setup_server_and_client();
            let (logger, log_inspector) = TestLogger::memory();
            client.api_versions =
                Arc::new(RwLock::new(vec![Version::parse(client_version).unwrap()]));
            client.logger = logger;
            server.mock(|_, then| {
                then.status(StatusCode::OK.as_u16())
                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
            });

            assert!(
                Version::parse(aggregator_version).unwrap()
                    > Version::parse(client_version).unwrap()
            );

            client
                .post(Url::parse(&server.base_url()).unwrap(), "whatever")
                .await
                .unwrap();

            assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
        }
    }
}