1use anyhow::{Context, anyhow};
11use async_recursion::async_recursion;
12use async_trait::async_trait;
13use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
14use reqwest::{Response, StatusCode, Url};
15use semver::Version;
16use slog::{Logger, debug, error, warn};
17use std::collections::HashMap;
18use std::sync::Arc;
19use thiserror::Error;
20use tokio::sync::RwLock;
21
22use mithril_common::MITHRIL_API_VERSION_HEADER;
23use mithril_common::entities::{ClientError, ServerError};
24use mithril_common::logging::LoggerExtensions;
25use mithril_common::messages::CardanoDatabaseImmutableFilesRestoredMessage;
26
27use crate::common::Epoch;
28use crate::{MithrilError, MithrilResult};
29
/// Message of the warning logged by `AggregatorHTTPClient::warn_if_api_version_mismatch` when the
/// API version advertised by the aggregator is greater than the one used by the client.
const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = "OpenAPI version may be incompatible, please update Mithril client library to the latest version.";
31
/// Errors returned by the aggregator client.
#[derive(Error, Debug)]
pub enum AggregatorClientError {
    /// The aggregator answered with a status that is neither a success nor a 4xx client error
    /// (built by `AggregatorHTTPClient::remote_technical_error`).
    #[error("Internal error of the Aggregator")]
    RemoteServerTechnical(#[source] MithrilError),

    /// The aggregator rejected the request: `404 Not Found` or any other 4xx status.
    #[error("Invalid request to the Aggregator")]
    RemoteServerLogical(#[source] MithrilError),

    /// The request could not be built or sent, or the response body could not be read.
    #[error("HTTP subsystem error")]
    SubsystemError(#[source] MithrilError),
}
47
/// Requests that can be sent to the aggregator.
///
/// Each variant maps to a route (see `AggregatorRequest::route`) and, for some of them,
/// to a request body (see `AggregatorRequest::get_body`).
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(test, derive(strum::EnumIter))]
pub enum AggregatorRequest {
    /// Request targeting the `certificate/{hash}` route.
    GetCertificate {
        /// Value interpolated as the last segment of the route.
        hash: String,
    },

    /// Request targeting the `certificates` route.
    ListCertificates,

    /// Request targeting the `artifact/mithril-stake-distribution/{hash}` route.
    GetMithrilStakeDistribution {
        /// Value interpolated as the last segment of the route.
        hash: String,
    },

    /// Request targeting the `artifact/mithril-stake-distributions` route.
    ListMithrilStakeDistributions,

    /// Request targeting the `artifact/snapshot/{digest}` route.
    GetSnapshot {
        /// Value interpolated as the last segment of the route.
        digest: String,
    },

    /// Request targeting the `artifact/snapshots` route.
    ListSnapshots,

    /// Request targeting the `statistics/snapshot` route.
    IncrementSnapshotStatistic {
        /// Content sent verbatim as the request body.
        snapshot: String,
    },

    /// Request targeting the `artifact/cardano-database/{hash}` route.
    GetCardanoDatabaseSnapshot {
        /// Value interpolated as the last segment of the route.
        hash: String,
    },

    /// Request targeting the `artifact/cardano-database` route.
    ListCardanoDatabaseSnapshots,

    /// Request targeting the `artifact/cardano-database/epoch/{epoch}` route.
    ListCardanoDatabaseSnapshotByEpoch {
        /// Epoch interpolated as the last segment of the route.
        epoch: Epoch,
    },

    /// Request targeting the `artifact/cardano-database/epoch/latest` route,
    /// or `artifact/cardano-database/epoch/latest-{offset}` when an offset is given.
    ListCardanoDatabaseSnapshotForLatestEpoch {
        /// Optional offset appended to `latest` as a `-{offset}` suffix.
        offset: Option<u64>,
    },

    /// Request targeting the `statistics/cardano-database/immutable-files-restored` route.
    IncrementCardanoDatabaseImmutablesRestoredStatistic {
        /// Count serialized in the request body as the `nb_immutable_files` field of a
        /// `CardanoDatabaseImmutableFilesRestoredMessage`.
        number_of_immutables: u64,
    },

    /// Request targeting the `statistics/cardano-database/ancillary-files-restored` route.
    IncrementCardanoDatabaseAncillaryStatistic,

    /// Request targeting the `statistics/cardano-database/complete-restoration` route.
    IncrementCardanoDatabaseCompleteRestorationStatistic,

    /// Request targeting the `statistics/cardano-database/partial-restoration` route.
    IncrementCardanoDatabasePartialRestorationStatistic,

    /// Request targeting the `proof/cardano-transaction` route.
    GetTransactionsProofs {
        /// Hashes joined with `,` and passed in the `transaction_hashes` query parameter.
        transactions_hashes: Vec<String>,
    },

    /// Request targeting the `artifact/cardano-transaction/{hash}` route.
    GetCardanoTransactionSnapshot {
        /// Value interpolated as the last segment of the route.
        hash: String,
    },

    /// Request targeting the `artifact/cardano-transactions` route.
    ListCardanoTransactionSnapshots,

    /// Request targeting the `artifact/cardano-stake-distribution/{hash}` route.
    GetCardanoStakeDistribution {
        /// Value interpolated as the last segment of the route.
        hash: String,
    },

    /// Request targeting the `artifact/cardano-stake-distribution/epoch/{epoch}` route.
    GetCardanoStakeDistributionByEpoch {
        /// Epoch interpolated as the last segment of the route.
        epoch: Epoch,
    },

    /// Request targeting the `artifact/cardano-stake-distribution/epoch/latest` route,
    /// or `artifact/cardano-stake-distribution/epoch/latest-{offset}` when an offset is given.
    GetCardanoStakeDistributionForLatestEpoch {
        /// Optional offset appended to `latest` as a `-{offset}` suffix.
        offset: Option<u64>,
    },

    /// Request targeting the `artifact/cardano-stake-distributions` route.
    ListCardanoStakeDistributions,

    /// Request targeting the `status` route.
    Status,
}
160
161impl AggregatorRequest {
162 pub fn route(&self) -> String {
164 match self {
165 AggregatorRequest::GetCertificate { hash } => {
166 format!("certificate/{hash}")
167 }
168 AggregatorRequest::ListCertificates => "certificates".to_string(),
169 AggregatorRequest::GetMithrilStakeDistribution { hash } => {
170 format!("artifact/mithril-stake-distribution/{hash}")
171 }
172 AggregatorRequest::ListMithrilStakeDistributions => {
173 "artifact/mithril-stake-distributions".to_string()
174 }
175 AggregatorRequest::GetSnapshot { digest } => {
176 format!("artifact/snapshot/{digest}")
177 }
178 AggregatorRequest::ListSnapshots => "artifact/snapshots".to_string(),
179 AggregatorRequest::IncrementSnapshotStatistic { snapshot: _ } => {
180 "statistics/snapshot".to_string()
181 }
182 AggregatorRequest::GetCardanoDatabaseSnapshot { hash } => {
183 format!("artifact/cardano-database/{hash}")
184 }
185 AggregatorRequest::ListCardanoDatabaseSnapshots => {
186 "artifact/cardano-database".to_string()
187 }
188 AggregatorRequest::ListCardanoDatabaseSnapshotByEpoch { epoch } => {
189 format!("artifact/cardano-database/epoch/{epoch}")
190 }
191 AggregatorRequest::ListCardanoDatabaseSnapshotForLatestEpoch { offset } => {
192 format!(
193 "artifact/cardano-database/epoch/latest{}",
194 offset.map(|o| format!("-{o}")).unwrap_or_default()
195 )
196 }
197 AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
198 number_of_immutables: _,
199 } => "statistics/cardano-database/immutable-files-restored".to_string(),
200 AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic => {
201 "statistics/cardano-database/ancillary-files-restored".to_string()
202 }
203 AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic => {
204 "statistics/cardano-database/complete-restoration".to_string()
205 }
206 AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic => {
207 "statistics/cardano-database/partial-restoration".to_string()
208 }
209 AggregatorRequest::GetTransactionsProofs {
210 transactions_hashes,
211 } => format!(
212 "proof/cardano-transaction?transaction_hashes={}",
213 transactions_hashes.join(",")
214 ),
215 AggregatorRequest::GetCardanoTransactionSnapshot { hash } => {
216 format!("artifact/cardano-transaction/{hash}")
217 }
218 AggregatorRequest::ListCardanoTransactionSnapshots => {
219 "artifact/cardano-transactions".to_string()
220 }
221 AggregatorRequest::GetCardanoStakeDistribution { hash } => {
222 format!("artifact/cardano-stake-distribution/{hash}")
223 }
224 AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch } => {
225 format!("artifact/cardano-stake-distribution/epoch/{epoch}")
226 }
227 AggregatorRequest::GetCardanoStakeDistributionForLatestEpoch { offset } => {
228 format!(
229 "artifact/cardano-stake-distribution/epoch/latest{}",
230 offset.map(|o| format!("-{o}")).unwrap_or_default()
231 )
232 }
233 AggregatorRequest::ListCardanoStakeDistributions => {
234 "artifact/cardano-stake-distributions".to_string()
235 }
236 AggregatorRequest::Status => "status".to_string(),
237 }
238 }
239
240 pub fn get_body(&self) -> Option<String> {
242 match self {
243 AggregatorRequest::IncrementSnapshotStatistic { snapshot } => {
244 Some(snapshot.to_string())
245 }
246 AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
247 number_of_immutables,
248 } => serde_json::to_string(&CardanoDatabaseImmutableFilesRestoredMessage {
249 nb_immutable_files: *number_of_immutables as u32,
250 })
251 .ok(),
252 _ => None,
253 }
254 }
255}
256
/// Client able to send [`AggregatorRequest`]s to an aggregator.
///
/// A mock (`MockAggregatorClient`) is generated by `mockall` in test builds.
#[cfg_attr(test, mockall::automock)]
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
pub trait AggregatorClient: Sync + Send {
    /// Send the given request and return the content of the response as a string.
    ///
    /// `AggregatorHTTPClient` implements this with a HTTP GET on the request route.
    async fn get_content(
        &self,
        request: AggregatorRequest,
    ) -> Result<String, AggregatorClientError>;

    /// Send the given request, with its body if any, and return the content of the
    /// response as a string.
    ///
    /// `AggregatorHTTPClient` implements this with a HTTP POST on the request route.
    async fn post_content(
        &self,
        request: AggregatorRequest,
    ) -> Result<String, AggregatorClientError>;
}
274
/// [`AggregatorClient`] implementation that talks to the aggregator over HTTP using `reqwest`.
pub struct AggregatorHTTPClient {
    // Underlying HTTP client used to send every request.
    http_client: reqwest::Client,
    // Root url of the aggregator; request routes are joined to it.
    aggregator_endpoint: Url,
    // API versions known by the client; the first one is sent in the API version header.
    api_versions: Arc<RwLock<Vec<Version>>>,
    logger: Logger,
    // Custom headers added to every GET and POST request.
    http_headers: HeaderMap,
}
283
284impl AggregatorHTTPClient {
285 pub fn new(
287 aggregator_endpoint: Url,
288 api_versions: Vec<Version>,
289 logger: Logger,
290 custom_headers: Option<HashMap<String, String>>,
291 ) -> MithrilResult<Self> {
292 let http_client = reqwest::ClientBuilder::new()
293 .build()
294 .with_context(|| "Building http client for Aggregator client failed")?;
295
296 let aggregator_endpoint = if aggregator_endpoint.as_str().ends_with('/') {
300 aggregator_endpoint
301 } else {
302 let mut url = aggregator_endpoint.clone();
303 url.set_path(&format!("{}/", aggregator_endpoint.path()));
304 url
305 };
306
307 let mut http_headers = HeaderMap::new();
308 if let Some(headers) = custom_headers {
309 for (key, value) in headers.iter() {
310 http_headers.insert(
311 HeaderName::from_bytes(key.as_bytes())?,
312 HeaderValue::from_str(value)?,
313 );
314 }
315 }
316
317 Ok(Self {
318 http_client,
319 aggregator_endpoint,
320 api_versions: Arc::new(RwLock::new(api_versions)),
321 logger: logger.new_with_component_name::<Self>(),
322 http_headers,
323 })
324 }
325
326 async fn compute_current_api_version(&self) -> Option<Version> {
328 self.api_versions.read().await.first().cloned()
329 }
330
331 #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
333 #[cfg_attr(not(target_family = "wasm"), async_recursion)]
334 async fn get(&self, url: Url) -> Result<Response, AggregatorClientError> {
335 debug!(self.logger, "GET url='{url}'.");
336 let request_builder = self.http_client.get(url.clone());
337 let current_api_version = self.compute_current_api_version().await.unwrap().to_string();
338 debug!(
339 self.logger,
340 "Prepare request with version: {current_api_version}"
341 );
342 let request_builder = request_builder
343 .header(MITHRIL_API_VERSION_HEADER, current_api_version)
344 .headers(self.http_headers.clone());
345
346 let response = request_builder.send().await.map_err(|e| {
347 AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
348 "Cannot perform a GET against the Aggregator HTTP server (url='{url}')"
349 )))
350 })?;
351
352 match response.status() {
353 StatusCode::OK => {
354 self.warn_if_api_version_mismatch(&response).await;
355
356 Ok(response)
357 }
358 StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
359 status_code if status_code.is_client_error() => {
360 Err(Self::remote_logical_error(response).await)
361 }
362 _ => Err(Self::remote_technical_error(response).await),
363 }
364 }
365
366 #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
367 #[cfg_attr(not(target_family = "wasm"), async_recursion)]
368 async fn post(&self, url: Url, json: &str) -> Result<Response, AggregatorClientError> {
369 debug!(self.logger, "POST url='{url}'"; "json" => json);
370 let request_builder = self.http_client.post(url.to_owned()).body(json.to_owned());
371 let current_api_version = self.compute_current_api_version().await.unwrap().to_string();
372 debug!(
373 self.logger,
374 "Prepare request with version: {current_api_version}"
375 );
376 let request_builder = request_builder
377 .header(MITHRIL_API_VERSION_HEADER, current_api_version)
378 .headers(self.http_headers.clone());
379
380 let response = request_builder.send().await.map_err(|e| {
381 AggregatorClientError::SubsystemError(
382 anyhow!(e).context("Error while POSTing data '{json}' to URL='{url}'."),
383 )
384 })?;
385
386 match response.status() {
387 StatusCode::OK | StatusCode::CREATED => {
388 self.warn_if_api_version_mismatch(&response).await;
389
390 Ok(response)
391 }
392 StatusCode::NOT_FOUND => Err(Self::not_found_error(url)),
393 status_code if status_code.is_client_error() => {
394 Err(Self::remote_logical_error(response).await)
395 }
396 _ => Err(Self::remote_technical_error(response).await),
397 }
398 }
399
400 fn get_url_for_route(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
401 self.aggregator_endpoint
402 .join(endpoint)
403 .with_context(|| {
404 format!(
405 "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
406 self.aggregator_endpoint
407 )
408 })
409 .map_err(AggregatorClientError::SubsystemError)
410 }
411
412 fn not_found_error(url: Url) -> AggregatorClientError {
413 AggregatorClientError::RemoteServerLogical(anyhow!("Url='{url}' not found"))
414 }
415
416 async fn remote_logical_error(response: Response) -> AggregatorClientError {
417 let status_code = response.status();
418 let client_error = response.json::<ClientError>().await.unwrap_or(ClientError::new(
419 format!("Unhandled error {status_code}"),
420 "",
421 ));
422
423 AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"))
424 }
425
426 async fn remote_technical_error(response: Response) -> AggregatorClientError {
427 let status_code = response.status();
428 let server_error = response
429 .json::<ServerError>()
430 .await
431 .unwrap_or(ServerError::new(format!("Unhandled error {status_code}")));
432
433 AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"))
434 }
435
436 async fn warn_if_api_version_mismatch(&self, response: &Response) {
438 let aggregator_version = response
439 .headers()
440 .get(MITHRIL_API_VERSION_HEADER)
441 .and_then(|v| v.to_str().ok())
442 .and_then(|s| Version::parse(s).ok());
443
444 let client_version = self.compute_current_api_version().await;
445
446 match (aggregator_version, client_version) {
447 (Some(aggregator), Some(client)) if client < aggregator => {
448 warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
449 "aggregator_version" => %aggregator,
450 "client_version" => %client,
451 );
452 }
453 (Some(_), None) => {
454 error!(
455 self.logger,
456 "Failed to compute the current client API version"
457 );
458 }
459 _ => {}
460 }
461 }
462}
463
464#[cfg_attr(target_family = "wasm", async_trait(?Send))]
465#[cfg_attr(not(target_family = "wasm"), async_trait)]
466impl AggregatorClient for AggregatorHTTPClient {
467 async fn get_content(
468 &self,
469 request: AggregatorRequest,
470 ) -> Result<String, AggregatorClientError> {
471 let response = self.get(self.get_url_for_route(&request.route())?).await?;
472 let content = format!("{response:?}");
473
474 response.text().await.map_err(|e| {
475 AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
476 "Could not find a JSON body in the response '{content}'."
477 )))
478 })
479 }
480
481 async fn post_content(
482 &self,
483 request: AggregatorRequest,
484 ) -> Result<String, AggregatorClientError> {
485 let response = self
486 .post(
487 self.get_url_for_route(&request.route())?,
488 &request.get_body().unwrap_or_default(),
489 )
490 .await?;
491
492 response.text().await.map_err(|e| {
493 AggregatorClientError::SubsystemError(
494 anyhow!(e).context("Could not find a text body in the response."),
495 )
496 })
497 }
498}
499
500#[cfg(test)]
501mod tests {
502 use httpmock::MockServer;
503 use std::collections::HashMap;
504 use strum::IntoEnumIterator;
505
506 use mithril_common::api_version::APIVersionProvider;
507 use mithril_common::entities::{ClientError, ServerError};
508
509 use crate::test_utils::TestLogger;
510
511 use super::*;
512
    // Asserts that two values have the same `Debug` rendering, which allows comparing
    // errors without requiring them to implement `PartialEq`.
    macro_rules! assert_error_eq {
        ($left:expr, $right:expr) => {
            assert_eq!(format!("{:?}", &$left), format!("{:?}", &$right),);
        };
    }
518
519 fn setup_client(
520 server_url: &str,
521 api_versions: Vec<Version>,
522 custom_headers: Option<HashMap<String, String>>,
523 ) -> AggregatorHTTPClient {
524 AggregatorHTTPClient::new(
525 Url::parse(server_url).unwrap(),
526 api_versions,
527 TestLogger::stdout(),
528 custom_headers,
529 )
530 .expect("building aggregator http client should not fail")
531 }
532
533 fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
534 let server = MockServer::start();
535 let client = setup_client(
536 &server.url(""),
537 APIVersionProvider::compute_all_versions_sorted(),
538 None,
539 );
540 (server, client)
541 }
542
543 fn setup_server_and_client_with_custom_headers(
544 custom_headers: HashMap<String, String>,
545 ) -> (MockServer, AggregatorHTTPClient) {
546 let server = MockServer::start();
547 let client = setup_client(
548 &server.url(""),
549 APIVersionProvider::compute_all_versions_sorted(),
550 Some(custom_headers),
551 );
552 (server, client)
553 }
554
555 #[test]
556 fn always_append_trailing_slash_at_build() {
557 for (expected, url) in [
558 ("http://www.test.net/", "http://www.test.net/"),
559 ("http://www.test.net/", "http://www.test.net"),
560 (
561 "http://www.test.net/aggregator/",
562 "http://www.test.net/aggregator/",
563 ),
564 (
565 "http://www.test.net/aggregator/",
566 "http://www.test.net/aggregator",
567 ),
568 ] {
569 let url = Url::parse(url).unwrap();
570 let client = AggregatorHTTPClient::new(url, vec![], TestLogger::stdout(), None)
571 .expect("building aggregator http client should not fail");
572
573 assert_eq!(expected, client.aggregator_endpoint.as_str());
574 }
575 }
576
577 #[test]
578 fn deduce_routes_from_request() {
579 assert_eq!(
580 "certificate/abc".to_string(),
581 AggregatorRequest::GetCertificate {
582 hash: "abc".to_string()
583 }
584 .route()
585 );
586
587 assert_eq!(
588 "artifact/mithril-stake-distribution/abc".to_string(),
589 AggregatorRequest::GetMithrilStakeDistribution {
590 hash: "abc".to_string()
591 }
592 .route()
593 );
594
595 assert_eq!(
596 "artifact/mithril-stake-distribution/abc".to_string(),
597 AggregatorRequest::GetMithrilStakeDistribution {
598 hash: "abc".to_string()
599 }
600 .route()
601 );
602
603 assert_eq!(
604 "artifact/mithril-stake-distributions".to_string(),
605 AggregatorRequest::ListMithrilStakeDistributions.route()
606 );
607
608 assert_eq!(
609 "artifact/snapshot/abc".to_string(),
610 AggregatorRequest::GetSnapshot {
611 digest: "abc".to_string()
612 }
613 .route()
614 );
615
616 assert_eq!(
617 "artifact/snapshots".to_string(),
618 AggregatorRequest::ListSnapshots.route()
619 );
620
621 assert_eq!(
622 "statistics/snapshot".to_string(),
623 AggregatorRequest::IncrementSnapshotStatistic {
624 snapshot: "abc".to_string()
625 }
626 .route()
627 );
628
629 assert_eq!(
630 "artifact/cardano-database/abc".to_string(),
631 AggregatorRequest::GetCardanoDatabaseSnapshot {
632 hash: "abc".to_string()
633 }
634 .route()
635 );
636
637 assert_eq!(
638 "artifact/cardano-database".to_string(),
639 AggregatorRequest::ListCardanoDatabaseSnapshots.route()
640 );
641
642 assert_eq!(
643 "artifact/cardano-database/epoch/5".to_string(),
644 AggregatorRequest::ListCardanoDatabaseSnapshotByEpoch { epoch: Epoch(5) }.route()
645 );
646
647 assert_eq!(
648 "artifact/cardano-database/epoch/latest".to_string(),
649 AggregatorRequest::ListCardanoDatabaseSnapshotForLatestEpoch { offset: None }.route()
650 );
651
652 assert_eq!(
653 "artifact/cardano-database/epoch/latest-6".to_string(),
654 AggregatorRequest::ListCardanoDatabaseSnapshotForLatestEpoch { offset: Some(6) }
655 .route()
656 );
657
658 assert_eq!(
659 "statistics/cardano-database/immutable-files-restored".to_string(),
660 AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
661 number_of_immutables: 58
662 }
663 .route()
664 );
665
666 assert_eq!(
667 "statistics/cardano-database/ancillary-files-restored".to_string(),
668 AggregatorRequest::IncrementCardanoDatabaseAncillaryStatistic.route()
669 );
670
671 assert_eq!(
672 "statistics/cardano-database/complete-restoration".to_string(),
673 AggregatorRequest::IncrementCardanoDatabaseCompleteRestorationStatistic.route()
674 );
675
676 assert_eq!(
677 "statistics/cardano-database/partial-restoration".to_string(),
678 AggregatorRequest::IncrementCardanoDatabasePartialRestorationStatistic.route()
679 );
680
681 assert_eq!(
682 "proof/cardano-transaction?transaction_hashes=abc,def,ghi,jkl".to_string(),
683 AggregatorRequest::GetTransactionsProofs {
684 transactions_hashes: vec![
685 "abc".to_string(),
686 "def".to_string(),
687 "ghi".to_string(),
688 "jkl".to_string()
689 ]
690 }
691 .route()
692 );
693
694 assert_eq!(
695 "artifact/cardano-transaction/abc".to_string(),
696 AggregatorRequest::GetCardanoTransactionSnapshot {
697 hash: "abc".to_string()
698 }
699 .route()
700 );
701
702 assert_eq!(
703 "artifact/cardano-transactions".to_string(),
704 AggregatorRequest::ListCardanoTransactionSnapshots.route()
705 );
706
707 assert_eq!(
708 "artifact/cardano-stake-distribution/abc".to_string(),
709 AggregatorRequest::GetCardanoStakeDistribution {
710 hash: "abc".to_string()
711 }
712 .route()
713 );
714
715 assert_eq!(
716 "artifact/cardano-stake-distribution/epoch/123".to_string(),
717 AggregatorRequest::GetCardanoStakeDistributionByEpoch { epoch: Epoch(123) }.route()
718 );
719
720 assert_eq!(
721 "artifact/cardano-stake-distribution/epoch/latest".to_string(),
722 AggregatorRequest::GetCardanoStakeDistributionForLatestEpoch { offset: None }.route()
723 );
724
725 assert_eq!(
726 "artifact/cardano-stake-distribution/epoch/latest-8".to_string(),
727 AggregatorRequest::GetCardanoStakeDistributionForLatestEpoch { offset: Some(8) }
728 .route()
729 );
730
731 assert_eq!(
732 "artifact/cardano-stake-distributions".to_string(),
733 AggregatorRequest::ListCardanoStakeDistributions.route()
734 );
735
736 assert_eq!("status".to_string(), AggregatorRequest::Status.route());
737 }
738
739 #[test]
740 fn deduce_body_from_request() {
741 fn that_should_not_have_body(req: &AggregatorRequest) -> bool {
742 !matches!(
743 req,
744 AggregatorRequest::IncrementSnapshotStatistic { .. }
745 | AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic { .. }
746 )
747 }
748
749 assert_eq!(
750 Some(r#"{"key":"value"}"#.to_string()),
751 AggregatorRequest::IncrementSnapshotStatistic {
752 snapshot: r#"{"key":"value"}"#.to_string()
753 }
754 .get_body()
755 );
756
757 assert_eq!(
758 Some(
759 serde_json::to_string(&CardanoDatabaseImmutableFilesRestoredMessage {
760 nb_immutable_files: 432,
761 })
762 .unwrap()
763 ),
764 AggregatorRequest::IncrementCardanoDatabaseImmutablesRestoredStatistic {
765 number_of_immutables: 432,
766 }
767 .get_body()
768 );
769
770 for req_that_should_not_have_body in
771 AggregatorRequest::iter().filter(that_should_not_have_body)
772 {
773 assert_eq!(None, req_that_should_not_have_body.get_body());
774 }
775 }
776
777 #[tokio::test]
778 async fn test_client_handle_4xx_errors() {
779 let client_error = ClientError::new("label", "message");
780
781 let (aggregator, client) = setup_server_and_client();
782 aggregator.mock(|_when, then| {
783 then.status(StatusCode::IM_A_TEAPOT.as_u16())
784 .json_body_obj(&client_error);
785 });
786
787 let expected_error = AggregatorClientError::RemoteServerLogical(anyhow!("{client_error}"));
788
789 let get_content_error = client
790 .get_content(AggregatorRequest::ListCertificates)
791 .await
792 .unwrap_err();
793 assert_error_eq!(get_content_error, expected_error);
794
795 let post_content_error = client
796 .post_content(AggregatorRequest::ListCertificates)
797 .await
798 .unwrap_err();
799 assert_error_eq!(post_content_error, expected_error);
800 }
801
802 #[tokio::test]
803 async fn test_client_handle_404_not_found_error() {
804 let client_error = ClientError::new("label", "message");
805
806 let (aggregator, client) = setup_server_and_client();
807 aggregator.mock(|_when, then| {
808 then.status(StatusCode::NOT_FOUND.as_u16())
809 .json_body_obj(&client_error);
810 });
811
812 let expected_error = AggregatorHTTPClient::not_found_error(
813 Url::parse(&format!(
814 "{}/{}",
815 aggregator.base_url(),
816 AggregatorRequest::ListCertificates.route()
817 ))
818 .unwrap(),
819 );
820
821 let get_content_error = client
822 .get_content(AggregatorRequest::ListCertificates)
823 .await
824 .unwrap_err();
825 assert_error_eq!(get_content_error, expected_error);
826
827 let post_content_error = client
828 .post_content(AggregatorRequest::ListCertificates)
829 .await
830 .unwrap_err();
831 assert_error_eq!(post_content_error, expected_error);
832 }
833
834 #[tokio::test]
835 async fn test_client_handle_5xx_errors() {
836 let server_error = ServerError::new("message");
837
838 let (aggregator, client) = setup_server_and_client();
839 aggregator.mock(|_when, then| {
840 then.status(StatusCode::INTERNAL_SERVER_ERROR.as_u16())
841 .json_body_obj(&server_error);
842 });
843
844 let expected_error =
845 AggregatorClientError::RemoteServerTechnical(anyhow!("{server_error}"));
846
847 let get_content_error = client
848 .get_content(AggregatorRequest::ListCertificates)
849 .await
850 .unwrap_err();
851 assert_error_eq!(get_content_error, expected_error);
852
853 let post_content_error = client
854 .post_content(AggregatorRequest::ListCertificates)
855 .await
856 .unwrap_err();
857 assert_error_eq!(post_content_error, expected_error);
858 }
859
860 #[tokio::test]
861 async fn test_client_with_custom_headers() {
862 let mut http_headers = HashMap::new();
863 http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
864 http_headers.insert("Another-Header".to_string(), "AnotherValue".to_string());
865 let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
866 aggregator.mock(|when, then| {
867 when.header("Custom-Header", "CustomValue")
868 .header("Another-Header", "AnotherValue");
869 then.status(StatusCode::OK.as_u16()).body("ok");
870 });
871
872 client
873 .get_content(AggregatorRequest::ListCertificates)
874 .await
875 .expect("GET request should succeed");
876
877 client
878 .post_content(AggregatorRequest::ListCertificates)
879 .await
880 .expect("GET request should succeed");
881 }
882
883 #[tokio::test]
884 async fn test_client_sends_accept_encoding_header_with_correct_values() {
885 let (aggregator, client) = setup_server_and_client();
886 aggregator.mock(|when, then| {
887 when.matches(|req| {
888 let headers = req.headers();
889 let accept_encoding_header = headers
890 .get("accept-encoding")
891 .expect("Accept-Encoding header not found");
892
893 ["gzip", "br", "deflate", "zstd"].iter().all(|&encoding| {
894 accept_encoding_header.to_str().is_ok_and(|h| h.contains(encoding))
895 })
896 });
897
898 then.status(200).body("ok");
899 });
900
901 client
902 .get_content(AggregatorRequest::ListCertificates)
903 .await
904 .expect("GET request should succeed with Accept-Encoding header");
905 }
906
907 #[tokio::test]
908 async fn test_client_with_custom_headers_sends_accept_encoding_header_with_correct_values() {
909 let mut http_headers = HashMap::new();
910 http_headers.insert("Custom-Header".to_string(), "CustomValue".to_string());
911 let (aggregator, client) = setup_server_and_client_with_custom_headers(http_headers);
912 aggregator.mock(|when, then| {
913 when.matches(|req| {
914 let headers = req.headers();
915 let accept_encoding_header = headers
916 .get("accept-encoding")
917 .expect("Accept-Encoding header not found");
918
919 ["gzip", "br", "deflate", "zstd"].iter().all(|&encoding| {
920 accept_encoding_header.to_str().is_ok_and(|h| h.contains(encoding))
921 })
922 });
923
924 then.status(200).body("ok");
925 });
926
927 client
928 .get_content(AggregatorRequest::ListCertificates)
929 .await
930 .expect("GET request should succeed with Accept-Encoding header");
931 }
932
933 mod warn_if_api_version_mismatch {
934 use http::response::Builder as HttpResponseBuilder;
935
936 use mithril_common::test::logging::MemoryDrainForTestInspector;
937
938 use super::*;
939
940 fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
941 key: K,
942 value: V,
943 ) -> Response {
944 HttpResponseBuilder::new()
945 .header(key.into(), value.into())
946 .body("whatever")
947 .unwrap()
948 .into()
949 }
950
951 fn assert_api_version_warning_logged<A: Into<String>, S: Into<String>>(
952 log_inspector: &MemoryDrainForTestInspector,
953 aggregator_version: A,
954 client_version: S,
955 ) {
956 assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
957 assert!(
958 log_inspector
959 .contains_log(&format!("aggregator_version={}", aggregator_version.into()))
960 );
961 assert!(
962 log_inspector.contains_log(&format!("client_version={}", client_version.into()))
963 );
964 }
965
966 #[tokio::test]
967 async fn test_logs_warning_when_aggregator_api_version_is_newer() {
968 let aggregator_version = "2.0.0";
969 let client_version = "1.0.0";
970 let (logger, log_inspector) = TestLogger::memory();
971 let mut client = setup_client(
972 "http://whatever",
973 vec![Version::parse(client_version).unwrap()],
974 None,
975 );
976 client.logger = logger;
977 let response =
978 build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
979
980 assert!(
981 Version::parse(aggregator_version).unwrap()
982 > Version::parse(client_version).unwrap()
983 );
984
985 client.warn_if_api_version_mismatch(&response).await;
986
987 assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
988 }
989
990 #[tokio::test]
991 async fn test_no_warning_logged_when_versions_match() {
992 let version = "1.0.0";
993 let (logger, log_inspector) = TestLogger::memory();
994 let mut client = setup_client(
995 "http://whatever",
996 vec![Version::parse(version).unwrap()],
997 None,
998 );
999 client.logger = logger;
1000 let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);
1001
1002 client.warn_if_api_version_mismatch(&response).await;
1003
1004 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1005 }
1006
1007 #[tokio::test]
1008 async fn test_no_warning_logged_when_aggregator_api_version_is_older() {
1009 let aggregator_version = "1.0.0";
1010 let client_version = "2.0.0";
1011 let (logger, log_inspector) = TestLogger::memory();
1012 let mut client = setup_client(
1013 "http://whatever",
1014 vec![Version::parse(client_version).unwrap()],
1015 None,
1016 );
1017 client.logger = logger;
1018 let response =
1019 build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1020
1021 assert!(
1022 Version::parse(aggregator_version).unwrap()
1023 < Version::parse(client_version).unwrap()
1024 );
1025
1026 client.warn_if_api_version_mismatch(&response).await;
1027
1028 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1029 }
1030
1031 #[tokio::test]
1032 async fn test_does_not_log_or_fail_when_header_is_missing() {
1033 let (logger, log_inspector) = TestLogger::memory();
1034 let mut client = setup_client(
1035 "http://whatever",
1036 APIVersionProvider::compute_all_versions_sorted(),
1037 None,
1038 );
1039 client.logger = logger;
1040 let response =
1041 build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");
1042
1043 client.warn_if_api_version_mismatch(&response).await;
1044
1045 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1046 }
1047
1048 #[tokio::test]
1049 async fn test_does_not_log_or_fail_when_header_is_not_a_version() {
1050 let (logger, log_inspector) = TestLogger::memory();
1051 let mut client = setup_client(
1052 "http://whatever",
1053 APIVersionProvider::compute_all_versions_sorted(),
1054 None,
1055 );
1056 client.logger = logger;
1057 let response =
1058 build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");
1059
1060 client.warn_if_api_version_mismatch(&response).await;
1061
1062 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1063 }
1064
1065 #[tokio::test]
1066 async fn test_logs_error_when_client_version_cannot_be_computed() {
1067 let (logger, log_inspector) = TestLogger::memory();
1068 let mut client = setup_client("http://whatever", vec![], None);
1069 client.logger = logger;
1070 let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");
1071
1072 client.warn_if_api_version_mismatch(&response).await;
1073
1074 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1075 }
1076
1077 #[tokio::test]
1078 async fn test_client_get_log_warning_if_api_version_mismatch() {
1079 let aggregator_version = "2.0.0";
1080 let client_version = "1.0.0";
1081 let (server, mut client) = setup_server_and_client();
1082 let (logger, log_inspector) = TestLogger::memory();
1083 client.api_versions =
1084 Arc::new(RwLock::new(vec![Version::parse(client_version).unwrap()]));
1085 client.logger = logger;
1086 server.mock(|_, then| {
1087 then.status(StatusCode::OK.as_u16())
1088 .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1089 });
1090
1091 assert!(
1092 Version::parse(aggregator_version).unwrap()
1093 > Version::parse(client_version).unwrap()
1094 );
1095
1096 client.get(Url::parse(&server.base_url()).unwrap()).await.unwrap();
1097
1098 assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
1099 }
1100
1101 #[tokio::test]
1102 async fn test_client_post_log_warning_if_api_version_mismatch() {
1103 let aggregator_version = "2.0.0";
1104 let client_version = "1.0.0";
1105 let (server, mut client) = setup_server_and_client();
1106 let (logger, log_inspector) = TestLogger::memory();
1107 client.api_versions =
1108 Arc::new(RwLock::new(vec![Version::parse(client_version).unwrap()]));
1109 client.logger = logger;
1110 server.mock(|_, then| {
1111 then.status(StatusCode::OK.as_u16())
1112 .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1113 });
1114
1115 assert!(
1116 Version::parse(aggregator_version).unwrap()
1117 > Version::parse(client_version).unwrap()
1118 );
1119
1120 client
1121 .post(Url::parse(&server.base_url()).unwrap(), "whatever")
1122 .await
1123 .unwrap();
1124
1125 assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version);
1126 }
1127 }
1128}