1use anyhow::{Context, anyhow};
2use async_trait::async_trait;
3use reqwest::header::{self, HeaderValue};
4use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode, Url};
5
6use semver::Version;
7use slog::{Logger, debug, error, warn};
8use std::{io, sync::Arc, time::Duration};
9use thiserror::Error;
10
11use mithril_common::{
12 MITHRIL_AGGREGATOR_VERSION_HEADER, MITHRIL_API_VERSION_HEADER, StdError, StdResult,
13 api_version::APIVersionProvider,
14 certificate_chain::{CertificateRetriever, CertificateRetrieverError},
15 entities::{Certificate, ClientError, ServerError},
16 logging::LoggerExtensions,
17 messages::{
18 CertificateListMessage, CertificateMessage, EpochSettingsMessage, TryFromMessageAdapter,
19 },
20};
21
22use crate::entities::LeaderAggregatorEpochSettings;
23use crate::message_adapters::FromEpochSettingsAdapter;
24use crate::services::{LeaderAggregatorClient, RemoteCertificateRetriever};
25
26const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");
27
28const API_VERSION_MISMATCH_WARNING_MESSAGE: &str =
29 "OpenAPI version may be incompatible, please update your Mithril node to the latest version.";
30
31#[derive(Error, Debug)]
33pub enum AggregatorClientError {
34 #[error("remote server technical error")]
36 RemoteServerTechnical(#[source] StdError),
37
38 #[error("remote server logical error")]
40 RemoteServerLogical(#[source] StdError),
41
42 #[error("remote server unreachable")]
44 RemoteServerUnreachable(#[source] StdError),
45
46 #[error("unhandled status code: {0}, response text: {1}")]
48 UnhandledStatusCode(StatusCode, String),
49
50 #[error("json parsing failed")]
52 JsonParseFailed(#[source] StdError),
53
54 #[error("Input/Output error")]
56 IOError(#[from] io::Error),
57
58 #[error("HTTP client creation failed")]
60 HTTPClientCreation(#[source] StdError),
61
62 #[error("proxy creation failed")]
64 ProxyCreation(#[source] StdError),
65
66 #[error("adapter failed")]
68 Adapter(#[source] StdError),
69}
70
71impl AggregatorClientError {
72 pub async fn from_response(response: Response) -> Self {
78 let error_code = response.status();
79
80 if error_code.is_client_error() {
81 let root_cause = Self::get_root_cause(response).await;
82 Self::RemoteServerLogical(anyhow!(root_cause))
83 } else if error_code.is_server_error() {
84 let root_cause = Self::get_root_cause(response).await;
85 Self::RemoteServerTechnical(anyhow!(root_cause))
86 } else {
87 let response_text = response.text().await.unwrap_or_default();
88 Self::UnhandledStatusCode(error_code, response_text)
89 }
90 }
91
92 async fn get_root_cause(response: Response) -> String {
93 let error_code = response.status();
94 let canonical_reason = error_code.canonical_reason().unwrap_or_default().to_lowercase();
95 let is_json = response
96 .headers()
97 .get(header::CONTENT_TYPE)
98 .is_some_and(|ct| JSON_CONTENT_TYPE == ct);
99
100 if is_json {
101 let json_value: serde_json::Value = response.json().await.unwrap_or_default();
102
103 if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
104 format!(
105 "{}: {}: {}",
106 canonical_reason, client_error.label, client_error.message
107 )
108 } else if let Ok(server_error) =
109 serde_json::from_value::<ServerError>(json_value.clone())
110 {
111 format!("{}: {}", canonical_reason, server_error.message)
112 } else if json_value.is_null() {
113 canonical_reason.to_string()
114 } else {
115 format!("{canonical_reason}: {json_value}")
116 }
117 } else {
118 let response_text = response.text().await.unwrap_or_default();
119 format!("{canonical_reason}: {response_text}")
120 }
121 }
122}
123
124pub struct AggregatorHTTPClient {
126 aggregator_endpoint: Url,
127 relay_endpoint: Option<String>,
128 api_version_provider: Arc<APIVersionProvider>,
129 timeout_duration: Option<Duration>,
130 logger: Logger,
131}
132
133impl AggregatorHTTPClient {
134 pub fn new(
136 aggregator_endpoint: Url,
137 relay_endpoint: Option<String>,
138 api_version_provider: Arc<APIVersionProvider>,
139 timeout_duration: Option<Duration>,
140 logger: Logger,
141 ) -> Self {
142 let logger = logger.new_with_component_name::<Self>();
143 debug!(logger, "New AggregatorHTTPClient created");
144
145 let aggregator_endpoint = if aggregator_endpoint.as_str().ends_with('/') {
149 aggregator_endpoint
150 } else {
151 let mut url = aggregator_endpoint.clone();
152 url.set_path(&format!("{}/", aggregator_endpoint.path()));
153 url
154 };
155
156 Self {
157 aggregator_endpoint,
158 relay_endpoint,
159 api_version_provider,
160 timeout_duration,
161 logger,
162 }
163 }
164
165 fn join_aggregator_endpoint(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
166 self.aggregator_endpoint
167 .join(endpoint)
168 .with_context(|| {
169 format!(
170 "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
171 self.aggregator_endpoint
172 )
173 })
174 .map_err(AggregatorClientError::HTTPClientCreation)
175 }
176
177 fn prepare_http_client(&self) -> Result<Client, AggregatorClientError> {
178 let client = match &self.relay_endpoint {
179 Some(relay_endpoint) => Client::builder()
180 .proxy(
181 Proxy::all(relay_endpoint)
182 .map_err(|e| AggregatorClientError::ProxyCreation(anyhow!(e)))?,
183 )
184 .build()
185 .map_err(|e| AggregatorClientError::HTTPClientCreation(anyhow!(e)))?,
186 None => Client::new(),
187 };
188
189 Ok(client)
190 }
191
192 pub fn prepare_request_builder(&self, request_builder: RequestBuilder) -> RequestBuilder {
194 let request_builder = request_builder
195 .header(
196 MITHRIL_API_VERSION_HEADER,
197 self.api_version_provider
198 .compute_current_version()
199 .unwrap()
200 .to_string(),
201 )
202 .header(MITHRIL_AGGREGATOR_VERSION_HEADER, env!("CARGO_PKG_VERSION"));
203
204 if let Some(duration) = self.timeout_duration {
205 request_builder.timeout(duration)
206 } else {
207 request_builder
208 }
209 }
210
211 fn warn_if_api_version_mismatch(&self, response: &Response) {
213 let leader_version = response
214 .headers()
215 .get(MITHRIL_API_VERSION_HEADER)
216 .and_then(|v| v.to_str().ok())
217 .and_then(|s| Version::parse(s).ok());
218
219 let follower_version = self.api_version_provider.compute_current_version();
220
221 match (leader_version, follower_version) {
222 (Some(leader), Ok(follower)) if follower < leader => {
223 warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
224 "leader_aggregator_version" => %leader,
225 "aggregator_version" => %follower,
226 );
227 }
228 (Some(_), Err(error)) => {
229 error!(
230 self.logger,
231 "Failed to compute the current aggregator API version";
232 "error" => error.to_string()
233 );
234 }
235 _ => {}
236 }
237 }
238}
239
240impl AggregatorHTTPClient {
242 async fn epoch_settings(
243 &self,
244 ) -> Result<Option<LeaderAggregatorEpochSettings>, AggregatorClientError> {
245 debug!(self.logger, "Retrieve epoch settings");
246 let url = self.join_aggregator_endpoint("epoch-settings")?;
247 let response = self
248 .prepare_request_builder(self.prepare_http_client()?.get(url))
249 .send()
250 .await;
251
252 match response {
253 Ok(response) => match response.status() {
254 StatusCode::OK => {
255 self.warn_if_api_version_mismatch(&response);
256 match response.json::<EpochSettingsMessage>().await {
257 Ok(message) => {
258 let epoch_settings = FromEpochSettingsAdapter::try_adapt(message)
259 .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
260 Ok(Some(epoch_settings))
261 }
262 Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
263 }
264 }
265 _ => Err(AggregatorClientError::from_response(response).await),
266 },
267 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
268 }
269 }
270
271 async fn latest_certificates_list(
272 &self,
273 ) -> Result<CertificateListMessage, AggregatorClientError> {
274 debug!(self.logger, "Retrieve latest certificates list");
275 let url = self.join_aggregator_endpoint("certificates")?;
276 let response = self
277 .prepare_request_builder(self.prepare_http_client()?.get(url))
278 .send()
279 .await;
280
281 match response {
282 Ok(response) => match response.status() {
283 StatusCode::OK => {
284 self.warn_if_api_version_mismatch(&response);
285 match response.json::<CertificateListMessage>().await {
286 Ok(message) => Ok(message),
287 Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
288 }
289 }
290 _ => Err(AggregatorClientError::from_response(response).await),
291 },
292 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
293 }
294 }
295
296 async fn certificate_details(
297 &self,
298 certificate_hash: &str,
299 ) -> Result<Option<CertificateMessage>, AggregatorClientError> {
300 debug!(self.logger, "Retrieve certificate details"; "certificate_hash" => %certificate_hash);
301 let url = self.join_aggregator_endpoint(&format!("certificate/{certificate_hash}"))?;
302 let response = self
303 .prepare_request_builder(self.prepare_http_client()?.get(url))
304 .send()
305 .await;
306
307 match response {
308 Ok(response) => match response.status() {
309 StatusCode::OK => {
310 self.warn_if_api_version_mismatch(&response);
311 match response.json::<CertificateMessage>().await {
312 Ok(message) => Ok(Some(message)),
313 Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
314 }
315 }
316 StatusCode::NOT_FOUND => Ok(None),
317 _ => Err(AggregatorClientError::from_response(response).await),
318 },
319 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
320 }
321 }
322
323 async fn latest_genesis_certificate(
324 &self,
325 ) -> Result<Option<CertificateMessage>, AggregatorClientError> {
326 self.certificate_details("genesis").await
327 }
328}
329
330#[async_trait]
331impl LeaderAggregatorClient for AggregatorHTTPClient {
332 async fn retrieve_epoch_settings(&self) -> StdResult<Option<LeaderAggregatorEpochSettings>> {
333 let epoch_settings = self.epoch_settings().await?;
334 Ok(epoch_settings)
335 }
336}
337
338#[async_trait]
339impl CertificateRetriever for AggregatorHTTPClient {
340 async fn get_certificate_details(
341 &self,
342 certificate_hash: &str,
343 ) -> Result<Certificate, CertificateRetrieverError> {
344 let message = self
345 .certificate_details(certificate_hash)
346 .await
347 .with_context(|| {
348 format!("Failed to retrieve certificate with hash: '{certificate_hash}'")
349 })
350 .map_err(CertificateRetrieverError)?
351 .ok_or(CertificateRetrieverError(anyhow!(
352 "Certificate does not exist: '{certificate_hash}'"
353 )))?;
354
355 message.try_into().map_err(CertificateRetrieverError)
356 }
357}
358
359#[async_trait]
360impl RemoteCertificateRetriever for AggregatorHTTPClient {
361 async fn get_latest_certificate_details(&self) -> StdResult<Option<Certificate>> {
362 let latest_certificates_list = self.latest_certificates_list().await?;
363
364 match latest_certificates_list.first() {
365 None => Ok(None),
366 Some(latest_certificate_list_item) => {
367 let latest_certificate_message =
368 self.certificate_details(&latest_certificate_list_item.hash).await?;
369 latest_certificate_message.map(TryInto::try_into).transpose()
370 }
371 }
372 }
373
374 async fn get_genesis_certificate_details(&self) -> StdResult<Option<Certificate>> {
375 match self.latest_genesis_certificate().await? {
376 Some(message) => Ok(Some(message.try_into()?)),
377 None => Ok(None),
378 }
379 }
380}
381
#[cfg(test)]
pub(crate) mod dumb {
    use tokio::sync::RwLock;

    use mithril_common::test::double::Dummy;

    use super::*;

    /// In-memory [LeaderAggregatorClient] test double returning preset epoch settings.
    pub struct DumbAggregatorClient {
        epoch_settings: RwLock<Option<LeaderAggregatorEpochSettings>>,
    }

    impl Default for DumbAggregatorClient {
        /// Start with dummy epoch settings.
        fn default() -> Self {
            let initial_epoch_settings = Some(LeaderAggregatorEpochSettings::dummy());

            Self {
                epoch_settings: RwLock::new(initial_epoch_settings),
            }
        }
    }

    #[async_trait]
    impl LeaderAggregatorClient for DumbAggregatorClient {
        /// Return a copy of the stored epoch settings.
        async fn retrieve_epoch_settings(
            &self,
        ) -> StdResult<Option<LeaderAggregatorEpochSettings>> {
            let stored_settings = self.epoch_settings.read().await;

            Ok(Option::clone(&stored_settings))
        }
    }
}
416
#[cfg(test)]
mod tests {
    use http::response::Builder as HttpResponseBuilder;
    use httpmock::prelude::*;
    use reqwest::IntoUrl;
    use serde_json::json;

    use mithril_common::messages::CertificateListItemMessage;
    use mithril_common::test::double::{Dummy, DummyApiVersionDiscriminantSource};

    use crate::test::TestLogger;

    use super::*;

    /// Build a client targeting `server_url`, with no relay, no timeout and a dummy API version.
    fn setup_client<U: IntoUrl>(server_url: U) -> AggregatorHTTPClient {
        let discriminant_source = DummyApiVersionDiscriminantSource::default();
        let api_version_provider = APIVersionProvider::new(Arc::new(discriminant_source));

        AggregatorHTTPClient::new(
            server_url.into_url().unwrap(),
            None,
            Arc::new(api_version_provider),
            None,
            TestLogger::stdout(),
        )
    }

    /// Start a mock HTTP server and build a client pointing at it.
    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
        let server = MockServer::start();
        let aggregator_endpoint = server.url("");
        let client = setup_client(&aggregator_endpoint);

        (server, client)
    }

    /// Build a response with the given status and a plain text body (no content type).
    fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
        HttpResponseBuilder::new()
            .status(status_code)
            .body(body.into())
            .unwrap()
            .into()
    }

    /// Build a response with the given status and `body` serialized as JSON.
    fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
        HttpResponseBuilder::new()
            .status(status_code)
            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
            .body(serde_json::to_string(body).unwrap())
            .unwrap()
            .into()
    }

    macro_rules! assert_error_text_contains {
        ($error: expr, $expect_contains: expr) => {
            let error = &$error;
            assert!(
                error.contains($expect_contains),
                "Expected error message to contain '{}'\ngot '{error:?}'",
                $expect_contains,
            );
        };
    }

    #[test]
    fn test_new_appends_trailing_slash_to_aggregator_endpoint_path() {
        let client = setup_client("http://whatever/api");

        assert_eq!("http://whatever/api/", client.aggregator_endpoint.as_str());
        assert_eq!(
            "http://whatever/api/certificates",
            client.join_aggregator_endpoint("certificates").unwrap().as_str()
        );
    }

    #[tokio::test]
    async fn test_epoch_settings_ok_200() {
        let (server, client) = setup_server_and_client();
        let epoch_settings_expected = EpochSettingsMessage::dummy();
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.status(200).body(json!(epoch_settings_expected).to_string());
        });

        let epoch_settings = client.retrieve_epoch_settings().await.expect("unexpected error");

        assert_eq!(
            FromEpochSettingsAdapter::try_adapt(epoch_settings_expected).unwrap(),
            epoch_settings.unwrap()
        );
    }

    #[tokio::test]
    async fn test_epoch_settings_ko_500() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.status(500).body("an error occurred");
        });

        match client.epoch_settings().await.unwrap_err() {
            AggregatorClientError::RemoteServerTechnical(_) => (),
            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
        };
    }

    #[tokio::test]
    async fn test_epoch_settings_timeout() {
        let (server, mut client) = setup_server_and_client();
        client.timeout_duration = Some(Duration::from_millis(10));
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.delay(Duration::from_millis(100));
        });

        let error = client.epoch_settings().await.expect_err("epoch_settings should fail");

        assert!(
            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
            "unexpected error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn test_latest_certificates_list_ok_200() {
        let (server, client) = setup_server_and_client();
        let expected_list = vec![
            CertificateListItemMessage::dummy(),
            CertificateListItemMessage::dummy(),
        ];
        let _server_mock = server.mock(|when, then| {
            when.path("/certificates");
            then.status(200).body(json!(expected_list).to_string());
        });

        let fetched_list = client.latest_certificates_list().await.unwrap();

        assert_eq!(expected_list, fetched_list);
    }

    #[tokio::test]
    async fn test_latest_certificates_list_ko_500() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/certificates");
            then.status(500).body("an error occurred");
        });

        match client.latest_certificates_list().await.unwrap_err() {
            AggregatorClientError::RemoteServerTechnical(_) => (),
            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
        };
    }

    #[tokio::test]
    async fn test_latest_certificates_list_timeout() {
        let (server, mut client) = setup_server_and_client();
        client.timeout_duration = Some(Duration::from_millis(10));
        let _server_mock = server.mock(|when, then| {
            when.path("/certificates");
            then.delay(Duration::from_millis(100));
        });

        let error = client
            .latest_certificates_list()
            .await
            .expect_err("latest_certificates_list should fail");

        assert!(
            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
            "unexpected error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn test_certificates_details_ok_200() {
        let (server, client) = setup_server_and_client();
        let expected_message = CertificateMessage::dummy();
        let _server_mock = server.mock(|when, then| {
            when.path(format!("/certificate/{}", expected_message.hash));
            then.status(200).body(json!(expected_message).to_string());
        });

        let fetched_message = client.certificate_details(&expected_message.hash).await.unwrap();

        assert_eq!(Some(expected_message), fetched_message);
    }

    #[tokio::test]
    async fn test_certificates_details_ok_404() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/certificate/not-found");
            then.status(404);
        });

        let fetched_message = client.certificate_details("not-found").await.unwrap();

        assert_eq!(None, fetched_message);
    }

    #[tokio::test]
    async fn test_certificates_details_ko_500() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/certificate/whatever");
            then.status(500).body("an error occurred");
        });

        match client.certificate_details("whatever").await.unwrap_err() {
            AggregatorClientError::RemoteServerTechnical(_) => (),
            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
        };
    }

    #[tokio::test]
    async fn test_certificates_details_timeout() {
        let (server, mut client) = setup_server_and_client();
        client.timeout_duration = Some(Duration::from_millis(10));
        let _server_mock = server.mock(|when, then| {
            when.path("/certificate/whatever");
            then.delay(Duration::from_millis(100));
        });

        let error = client
            .certificate_details("whatever")
            .await
            .expect_err("certificate_details should fail");

        assert!(
            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
            "unexpected error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn test_latest_genesis_ok_200() {
        let (server, client) = setup_server_and_client();
        let genesis_message = CertificateMessage::dummy();
        let _server_mock = server.mock(|when, then| {
            when.path("/certificate/genesis");
            then.status(200).body(json!(genesis_message).to_string());
        });

        let fetched = client.latest_genesis_certificate().await.unwrap();

        assert_eq!(Some(genesis_message), fetched);
    }

    #[tokio::test]
    async fn test_latest_genesis_ok_404() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/certificate/genesis");
            then.status(404);
        });

        let fetched = client.latest_genesis_certificate().await.unwrap();

        assert_eq!(None, fetched);
    }

    #[tokio::test]
    async fn test_latest_genesis_ko_500() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/certificate/genesis");
            then.status(500).body("an error occurred");
        });

        let error = client.latest_genesis_certificate().await.unwrap_err();

        assert!(
            matches!(error, AggregatorClientError::RemoteServerTechnical(_)),
            "Expected Aggregator::RemoteServerTechnical error, got {error:?}"
        );
    }

    #[tokio::test]
    async fn test_latest_genesis_timeout() {
        let (server, mut client) = setup_server_and_client();
        client.timeout_duration = Some(Duration::from_millis(10));
        let _server_mock = server.mock(|when, then| {
            when.path("/certificate/genesis");
            then.delay(Duration::from_millis(100));
        });

        let error = client.latest_genesis_certificate().await.unwrap_err();

        assert!(
            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
            "unexpected error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn test_4xx_errors_are_handled_as_remote_server_logical() {
        let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::RemoteServerLogical(..)
            ),
            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_5xx_errors_are_handled_as_remote_server_technical() {
        let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::RemoteServerTechnical(..)
            ),
            "Expected error to be RemoteServerTechnical\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text()
     {
        let response = build_text_response(StatusCode::OK, "ok text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
            ),
            "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
        let error_text = "An error occurred; please try again later.";
        let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "expectation failed: An error occurred; please try again later."
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message()
     {
        let client_error = ClientError::new("label", "message");
        let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "bad request: label: message"
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_json_formatted_server_error_response_contains_error_message() {
        let server_error = ServerError::new("message");
        let response = build_json_response(StatusCode::INTERNAL_SERVER_ERROR, &server_error);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "internal server error: message"
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
        let response = build_json_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            &json!({ "second": "unknown", "first": "foreign" }),
        );

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            r#"internal server error: {"first":"foreign","second":"unknown"}"#
        );
    }

    #[tokio::test]
    async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
        let response = HttpResponseBuilder::new()
            .status(StatusCode::BAD_REQUEST)
            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
            .body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
            .unwrap()
            .into();

        let root_cause = AggregatorClientError::get_root_cause(response).await;

        assert_error_text_contains!(root_cause, "bad request");
        assert!(
            !root_cause.contains("bad request: "),
            "Expected error message should not contain additional information \ngot '{root_cause:?}'"
        );
    }

    mod warn_if_api_version_mismatch {
        use std::collections::HashMap;

        use mithril_common::test::api_version_extensions::ApiVersionProviderTestExtension;
        use mithril_common::test::logging::MemoryDrainForTestInspector;

        use super::*;

        /// Build a version provider whose default OpenAPI version is `version`.
        fn version_provider_with_open_api_version<V: Into<String>>(
            version: V,
        ) -> APIVersionProvider {
            let mut version_provider = version_provider_without_open_api_version();
            let mut open_api_versions = HashMap::new();
            open_api_versions.insert(
                "openapi.yaml".to_string(),
                Version::parse(&version.into()).unwrap(),
            );
            version_provider.update_open_api_versions(open_api_versions);

            version_provider
        }

        /// Build a version provider with no OpenAPI version at all.
        fn version_provider_without_open_api_version() -> APIVersionProvider {
            let mut version_provider =
                APIVersionProvider::new(Arc::new(DummyApiVersionDiscriminantSource::default()));
            version_provider.update_open_api_versions(HashMap::new());

            version_provider
        }

        fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
            key: K,
            value: V,
        ) -> Response {
            HttpResponseBuilder::new()
                .header(key.into(), value.into())
                .body("whatever")
                .unwrap()
                .into()
        }

        fn assert_api_version_warning_logged<L: Into<String>, A: Into<String>>(
            log_inspector: &MemoryDrainForTestInspector,
            leader_aggregator_version: L,
            aggregator_version: A,
        ) {
            assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
            assert!(log_inspector.contains_log(&format!(
                "leader_aggregator_version={}",
                leader_aggregator_version.into()
            )));
            assert!(
                log_inspector
                    .contains_log(&format!("aggregator_version={}", aggregator_version.into()))
            );
        }

        #[test]
        fn test_logs_warning_when_leader_aggregator_api_version_is_newer() {
            let leader_aggregator_version = "2.0.0";
            let aggregator_version = "1.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_with_open_api_version(aggregator_version);
            let mut client = setup_client("http://whatever");
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let response = build_fake_response_with_header(
                MITHRIL_API_VERSION_HEADER,
                leader_aggregator_version,
            );

            assert!(
                Version::parse(leader_aggregator_version).unwrap()
                    > Version::parse(aggregator_version).unwrap()
            );

            client.warn_if_api_version_mismatch(&response);

            assert_api_version_warning_logged(
                &log_inspector,
                leader_aggregator_version,
                aggregator_version,
            );
        }

        #[test]
        fn test_no_warning_logged_when_versions_match() {
            let version = "1.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_with_open_api_version(version);
            let mut client = setup_client("http://whatever");
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);

            client.warn_if_api_version_mismatch(&response);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_no_warning_logged_when_leader_aggregator_api_version_is_older() {
            let leader_aggregator_version = "1.0.0";
            let aggregator_version = "2.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_with_open_api_version(aggregator_version);
            let mut client = setup_client("http://whatever");
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let response = build_fake_response_with_header(
                MITHRIL_API_VERSION_HEADER,
                leader_aggregator_version,
            );

            assert!(
                Version::parse(leader_aggregator_version).unwrap()
                    < Version::parse(aggregator_version).unwrap()
            );

            client.warn_if_api_version_mismatch(&response);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_does_not_log_or_fail_when_header_is_missing() {
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client("http://whatever");
            client.logger = logger;
            let response =
                build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");

            client.warn_if_api_version_mismatch(&response);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_does_not_log_or_fail_when_header_is_not_a_version() {
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client("http://whatever");
            client.logger = logger;
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");

            client.warn_if_api_version_mismatch(&response);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_logs_error_when_aggregator_version_cannot_be_computed() {
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_without_open_api_version();
            let mut client = setup_client("http://whatever");
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");

            client.warn_if_api_version_mismatch(&response);

            assert!(
                log_inspector.contains_log("Failed to compute the current aggregator API version")
            );
            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[tokio::test]
        async fn test_epoch_settings_ok_200_log_warning_if_api_version_mismatch() {
            let leader_aggregator_version = "2.0.0";
            let aggregator_version = "1.0.0";
            let (server, mut client) = setup_server_and_client();
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_with_open_api_version(aggregator_version);
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let epoch_settings_expected = EpochSettingsMessage::dummy();
            let _server_mock = server.mock(|when, then| {
                when.path("/epoch-settings");
                then.status(200)
                    .body(json!(epoch_settings_expected).to_string())
                    .header(MITHRIL_API_VERSION_HEADER, leader_aggregator_version);
            });

            assert!(
                Version::parse(leader_aggregator_version).unwrap()
                    > Version::parse(aggregator_version).unwrap()
            );

            client.retrieve_epoch_settings().await.unwrap();

            assert_api_version_warning_logged(
                &log_inspector,
                leader_aggregator_version,
                aggregator_version,
            );
        }
    }

    mod remote_certificate_retriever {
        use mithril_common::test::double::fake_data;

        use super::*;

        #[tokio::test]
        async fn test_get_latest_certificate_details() {
            let (server, client) = setup_server_and_client();
            let expected_certificate = fake_data::certificate("expected");
            let latest_message: CertificateMessage =
                expected_certificate.clone().try_into().unwrap();
            let latest_certificates = vec![
                CertificateListItemMessage {
                    hash: expected_certificate.hash.clone(),
                    ..CertificateListItemMessage::dummy()
                },
                CertificateListItemMessage::dummy(),
                CertificateListItemMessage::dummy(),
            ];
            let _server_mock = server.mock(|when, then| {
                when.path("/certificates");
                then.status(200).body(json!(latest_certificates).to_string());
            });
            let _server_mock = server.mock(|when, then| {
                when.path(format!("/certificate/{}", latest_message.hash));
                then.status(200).body(json!(latest_message).to_string());
            });

            let fetched_certificate = client.get_latest_certificate_details().await.unwrap();

            assert_eq!(Some(expected_certificate), fetched_certificate);
        }

        #[tokio::test]
        async fn test_get_latest_genesis_certificate() {
            let (server, client) = setup_server_and_client();
            let genesis_message = CertificateMessage::dummy();
            let expected_genesis: Certificate = genesis_message.clone().try_into().unwrap();
            let _server_mock = server.mock(|when, then| {
                when.path("/certificate/genesis");
                then.status(200).body(json!(genesis_message).to_string());
            });

            let fetched = client.get_genesis_certificate_details().await.unwrap();

            assert_eq!(Some(expected_genesis), fetched);
        }
    }
}