1use anyhow::{Context, anyhow};
2use async_trait::async_trait;
3use reqwest::header::{self, HeaderValue};
4use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode, Url};
5
6use semver::Version;
7use slog::{Logger, debug, error, warn};
8use std::{io, sync::Arc, time::Duration};
9use thiserror::Error;
10
11use mithril_common::{
12 MITHRIL_AGGREGATOR_VERSION_HEADER, MITHRIL_API_VERSION_HEADER, StdError, StdResult,
13 api_version::APIVersionProvider,
14 certificate_chain::{CertificateRetriever, CertificateRetrieverError},
15 entities::{Certificate, ClientError, ServerError},
16 logging::LoggerExtensions,
17 messages::{
18 CertificateListMessage, CertificateMessage, EpochSettingsMessage, TryFromMessageAdapter,
19 },
20};
21
22use crate::entities::LeaderAggregatorEpochSettings;
23use crate::message_adapters::FromEpochSettingsAdapter;
24use crate::services::{LeaderAggregatorClient, RemoteCertificateRetriever};
25
/// `Content-Type` value identifying JSON bodies; used when extracting the root cause of an
/// error response (and by tests to build JSON responses).
const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");

/// Warning logged when the leader aggregator advertises a newer API version than the one
/// computed by this node.
const API_VERSION_MISMATCH_WARNING_MESSAGE: &str =
    "OpenAPI version may be incompatible, please update your Mithril node to the latest version.";
30
/// Errors raised while communicating with the leader aggregator over HTTP.
#[derive(Error, Debug)]
pub enum AggregatorClientError {
    /// The remote server answered with a `5xx` status; the source holds the extracted root cause.
    #[error("remote server technical error")]
    RemoteServerTechnical(#[source] StdError),

    /// The remote server answered with a `4xx` status; the source holds the extracted root cause.
    #[error("remote server logical error")]
    RemoteServerLogical(#[source] StdError),

    /// The request could not be sent or no response was received (this includes timeouts).
    #[error("remote server unreachable")]
    RemoteServerUnreachable(#[source] StdError),

    /// The response status is neither expected nor a `4xx`/`5xx`; carries the status and the
    /// raw response text.
    #[error("unhandled status code: {0}, response text: {1}")]
    UnhandledStatusCode(StatusCode, String),

    /// The body of a successful response could not be deserialized.
    #[error("json parsing failed")]
    JsonParseFailed(#[source] StdError),

    /// Wraps an [`io::Error`] (convertible through `From`).
    #[error("Input/Output error")]
    IOError(#[from] io::Error),

    /// The HTTP client could not be built, or the request url could not be computed.
    #[error("HTTP client creation failed")]
    HTTPClientCreation(#[source] StdError),

    /// The relay endpoint could not be turned into a proxy.
    #[error("proxy creation failed")]
    ProxyCreation(#[source] StdError),

    /// A received message could not be converted into its domain entity.
    #[error("adapter failed")]
    Adapter(#[source] StdError),
}
70
71impl AggregatorClientError {
72 pub async fn from_response(response: Response) -> Self {
78 let error_code = response.status();
79
80 if error_code.is_client_error() {
81 let root_cause = Self::get_root_cause(response).await;
82 Self::RemoteServerLogical(anyhow!(root_cause))
83 } else if error_code.is_server_error() {
84 let root_cause = Self::get_root_cause(response).await;
85 Self::RemoteServerTechnical(anyhow!(root_cause))
86 } else {
87 let response_text = response.text().await.unwrap_or_default();
88 Self::UnhandledStatusCode(error_code, response_text)
89 }
90 }
91
92 async fn get_root_cause(response: Response) -> String {
93 let error_code = response.status();
94 let canonical_reason = error_code.canonical_reason().unwrap_or_default().to_lowercase();
95 let is_json = response
96 .headers()
97 .get(header::CONTENT_TYPE)
98 .is_some_and(|ct| JSON_CONTENT_TYPE == ct);
99
100 if is_json {
101 let json_value: serde_json::Value = response.json().await.unwrap_or_default();
102
103 if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
104 format!(
105 "{}: {}: {}",
106 canonical_reason, client_error.label, client_error.message
107 )
108 } else if let Ok(server_error) =
109 serde_json::from_value::<ServerError>(json_value.clone())
110 {
111 format!("{}: {}", canonical_reason, server_error.message)
112 } else if json_value.is_null() {
113 canonical_reason.to_string()
114 } else {
115 format!("{canonical_reason}: {json_value}")
116 }
117 } else {
118 let response_text = response.text().await.unwrap_or_default();
119 format!("{canonical_reason}: {response_text}")
120 }
121 }
122}
123
/// HTTP client used to query a leader aggregator (epoch settings and certificates).
pub struct AggregatorHTTPClient {
    /// Base url of the leader aggregator; `new` appends a trailing `/` when it is missing.
    aggregator_endpoint: Url,
    /// Optional relay url; when set, every request is proxied through it.
    relay_endpoint: Option<String>,
    /// Computes this node's API version, sent in request headers and compared with the leader's.
    api_version_provider: Arc<APIVersionProvider>,
    /// Optional per-request timeout.
    timeout_duration: Option<Duration>,
    logger: Logger,
}
132
133impl AggregatorHTTPClient {
134 pub fn new(
136 aggregator_endpoint: Url,
137 relay_endpoint: Option<String>,
138 api_version_provider: Arc<APIVersionProvider>,
139 timeout_duration: Option<Duration>,
140 logger: Logger,
141 ) -> Self {
142 let logger = logger.new_with_component_name::<Self>();
143 debug!(logger, "New AggregatorHTTPClient created");
144
145 let aggregator_endpoint = if aggregator_endpoint.as_str().ends_with('/') {
149 aggregator_endpoint
150 } else {
151 let mut url = aggregator_endpoint.clone();
152 url.set_path(&format!("{}/", aggregator_endpoint.path()));
153 url
154 };
155
156 Self {
157 aggregator_endpoint,
158 relay_endpoint,
159 api_version_provider,
160 timeout_duration,
161 logger,
162 }
163 }
164
165 fn join_aggregator_endpoint(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
166 self.aggregator_endpoint
167 .join(endpoint)
168 .with_context(|| {
169 format!(
170 "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
171 self.aggregator_endpoint
172 )
173 })
174 .map_err(AggregatorClientError::HTTPClientCreation)
175 }
176
177 fn prepare_http_client(&self) -> Result<Client, AggregatorClientError> {
178 let client = match &self.relay_endpoint {
179 Some(relay_endpoint) => Client::builder()
180 .proxy(
181 Proxy::all(relay_endpoint)
182 .map_err(|e| AggregatorClientError::ProxyCreation(anyhow!(e)))?,
183 )
184 .build()
185 .map_err(|e| AggregatorClientError::HTTPClientCreation(anyhow!(e)))?,
186 None => Client::new(),
187 };
188
189 Ok(client)
190 }
191
192 pub fn prepare_request_builder(&self, request_builder: RequestBuilder) -> RequestBuilder {
194 let request_builder = request_builder
195 .header(
196 MITHRIL_API_VERSION_HEADER,
197 self.api_version_provider
198 .compute_current_version()
199 .unwrap()
200 .to_string(),
201 )
202 .header(MITHRIL_AGGREGATOR_VERSION_HEADER, env!("CARGO_PKG_VERSION"));
203
204 if let Some(duration) = self.timeout_duration {
205 request_builder.timeout(duration)
206 } else {
207 request_builder
208 }
209 }
210
211 fn warn_if_api_version_mismatch(&self, response: &Response) {
213 let leader_version = response
214 .headers()
215 .get(MITHRIL_API_VERSION_HEADER)
216 .and_then(|v| v.to_str().ok())
217 .and_then(|s| Version::parse(s).ok());
218
219 let follower_version = self.api_version_provider.compute_current_version();
220
221 match (leader_version, follower_version) {
222 (Some(leader), Ok(follower)) if follower < leader => {
223 warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
224 "leader_aggregator_version" => %leader,
225 "aggregator_version" => %follower,
226 );
227 }
228 (Some(_), Err(error)) => {
229 error!(
230 self.logger,
231 "Failed to compute the current aggregator API version";
232 "error" => error.to_string()
233 );
234 }
235 _ => {}
236 }
237 }
238}
239
240impl AggregatorHTTPClient {
242 async fn epoch_settings(
243 &self,
244 ) -> Result<Option<LeaderAggregatorEpochSettings>, AggregatorClientError> {
245 debug!(self.logger, "Retrieve epoch settings");
246 let url = self.join_aggregator_endpoint("epoch-settings")?;
247 let response = self
248 .prepare_request_builder(self.prepare_http_client()?.get(url))
249 .send()
250 .await;
251
252 match response {
253 Ok(response) => match response.status() {
254 StatusCode::OK => {
255 self.warn_if_api_version_mismatch(&response);
256 match response.json::<EpochSettingsMessage>().await {
257 Ok(message) => {
258 let epoch_settings = FromEpochSettingsAdapter::try_adapt(message)
259 .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
260 Ok(Some(epoch_settings))
261 }
262 Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
263 }
264 }
265 _ => Err(AggregatorClientError::from_response(response).await),
266 },
267 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
268 }
269 }
270
271 async fn latest_certificates_list(
272 &self,
273 ) -> Result<CertificateListMessage, AggregatorClientError> {
274 debug!(self.logger, "Retrieve latest certificates list");
275 let url = self.join_aggregator_endpoint("certificates")?;
276 let response = self
277 .prepare_request_builder(self.prepare_http_client()?.get(url))
278 .send()
279 .await;
280
281 match response {
282 Ok(response) => match response.status() {
283 StatusCode::OK => {
284 self.warn_if_api_version_mismatch(&response);
285 match response.json::<CertificateListMessage>().await {
286 Ok(message) => Ok(message),
287 Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
288 }
289 }
290 _ => Err(AggregatorClientError::from_response(response).await),
291 },
292 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
293 }
294 }
295
296 async fn certificate_details(
297 &self,
298 certificate_hash: &str,
299 ) -> Result<Option<CertificateMessage>, AggregatorClientError> {
300 debug!(self.logger, "Retrieve certificate details"; "certificate_hash" => %certificate_hash);
301 let url = self.join_aggregator_endpoint(&format!("certificate/{certificate_hash}"))?;
302 let response = self
303 .prepare_request_builder(self.prepare_http_client()?.get(url))
304 .send()
305 .await;
306
307 match response {
308 Ok(response) => match response.status() {
309 StatusCode::OK => {
310 self.warn_if_api_version_mismatch(&response);
311 match response.json::<CertificateMessage>().await {
312 Ok(message) => Ok(Some(message)),
313 Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
314 }
315 }
316 StatusCode::NOT_FOUND => Ok(None),
317 _ => Err(AggregatorClientError::from_response(response).await),
318 },
319 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
320 }
321 }
322
323 async fn latest_genesis_certificate(
324 &self,
325 ) -> Result<Option<CertificateMessage>, AggregatorClientError> {
326 self.certificate_details("genesis").await
327 }
328}
329
330#[async_trait]
331impl LeaderAggregatorClient for AggregatorHTTPClient {
332 async fn retrieve_epoch_settings(&self) -> StdResult<Option<LeaderAggregatorEpochSettings>> {
333 let epoch_settings = self.epoch_settings().await?;
334 Ok(epoch_settings)
335 }
336}
337
338#[async_trait]
339impl CertificateRetriever for AggregatorHTTPClient {
340 async fn get_certificate_details(
341 &self,
342 certificate_hash: &str,
343 ) -> Result<Certificate, CertificateRetrieverError> {
344 let message = self
345 .certificate_details(certificate_hash)
346 .await
347 .with_context(|| {
348 format!("Failed to retrieve certificate with hash: '{certificate_hash}'")
349 })
350 .map_err(CertificateRetrieverError)?
351 .ok_or(CertificateRetrieverError(anyhow!(
352 "Certificate does not exist: '{certificate_hash}'"
353 )))?;
354
355 message.try_into().map_err(CertificateRetrieverError)
356 }
357}
358
359#[async_trait]
360impl RemoteCertificateRetriever for AggregatorHTTPClient {
361 async fn get_latest_certificate_details(&self) -> StdResult<Option<Certificate>> {
362 let latest_certificates_list = self.latest_certificates_list().await?;
363
364 match latest_certificates_list.first() {
365 None => Ok(None),
366 Some(latest_certificate_list_item) => {
367 let latest_certificate_message =
368 self.certificate_details(&latest_certificate_list_item.hash).await?;
369 latest_certificate_message.map(TryInto::try_into).transpose()
370 }
371 }
372 }
373
374 async fn get_genesis_certificate_details(&self) -> StdResult<Option<Certificate>> {
375 match self.latest_genesis_certificate().await? {
376 Some(message) => Ok(Some(message.try_into()?)),
377 None => Ok(None),
378 }
379 }
380}
381
#[cfg(test)]
mod tests {
    use http::response::Builder as HttpResponseBuilder;
    use httpmock::prelude::*;
    use reqwest::IntoUrl;
    use serde_json::json;

    use mithril_common::messages::CertificateListItemMessage;
    use mithril_common::test::double::{Dummy, DummyApiVersionDiscriminantSource};

    use crate::test::TestLogger;

    use super::*;

    /// Build a client for `server_url` with no relay, no timeout and a stdout logger.
    fn setup_client<U: IntoUrl>(server_url: U) -> AggregatorHTTPClient {
        let discriminant_source = DummyApiVersionDiscriminantSource::default();
        let api_version_provider = APIVersionProvider::new(Arc::new(discriminant_source));

        AggregatorHTTPClient::new(
            server_url.into_url().unwrap(),
            None,
            Arc::new(api_version_provider),
            None,
            TestLogger::stdout(),
        )
    }

    /// Start a mock HTTP server and build a client targeting it.
    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
        let server = MockServer::start();
        let aggregator_endpoint = server.url("");
        let client = setup_client(&aggregator_endpoint);

        (server, client)
    }

    /// Build a response with the given status and a plain text body (no `Content-Type`).
    fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
        HttpResponseBuilder::new()
            .status(status_code)
            .body(body.into())
            .unwrap()
            .into()
    }

    /// Build a response with the given status, a JSON `Content-Type` and `body` serialized as JSON.
    fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
        HttpResponseBuilder::new()
            .status(status_code)
            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
            .body(serde_json::to_string(&body).unwrap())
            .unwrap()
            .into()
    }

    macro_rules! assert_error_text_contains {
        ($error: expr, $expect_contains: expr) => {
            let error = &$error;
            assert!(
                error.contains($expect_contains),
                "Expected error message to contain '{}'\ngot '{error:?}'",
                $expect_contains,
            );
        };
    }

    #[tokio::test]
    async fn test_epoch_settings_ok_200() {
        let (server, client) = setup_server_and_client();
        let epoch_settings_expected = EpochSettingsMessage::dummy();
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.status(200).body(json!(epoch_settings_expected).to_string());
        });

        let epoch_settings = client.retrieve_epoch_settings().await;
        epoch_settings.as_ref().expect("unexpected error");
        assert_eq!(
            FromEpochSettingsAdapter::try_adapt(epoch_settings_expected).unwrap(),
            epoch_settings.unwrap().unwrap()
        );
    }

    #[tokio::test]
    async fn test_epoch_settings_ko_500() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.status(500).body("an error occurred");
        });

        match client.epoch_settings().await.unwrap_err() {
            AggregatorClientError::RemoteServerTechnical(_) => (),
            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
        };
    }

    #[tokio::test]
    async fn test_epoch_settings_timeout() {
        let (server, mut client) = setup_server_and_client();
        client.timeout_duration = Some(Duration::from_millis(10));
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.delay(Duration::from_millis(100));
        });

        let error = client
            .epoch_settings()
            .await
            .expect_err("epoch_settings should fail");

        assert!(
            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
            "unexpected error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn test_latest_certificates_list_ok_200() {
        let (server, client) = setup_server_and_client();
        let expected_list = vec![
            CertificateListItemMessage::dummy(),
            CertificateListItemMessage::dummy(),
        ];
        let _server_mock = server.mock(|when, then| {
            when.path("/certificates");
            then.status(200).body(json!(expected_list).to_string());
        });

        let fetched_list = client.latest_certificates_list().await.unwrap();

        assert_eq!(expected_list, fetched_list);
    }

    #[tokio::test]
    async fn test_latest_certificates_list_ko_500() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/certificates");
            then.status(500).body("an error occurred");
        });

        match client.latest_certificates_list().await.unwrap_err() {
            AggregatorClientError::RemoteServerTechnical(_) => (),
            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
        };
    }

    #[tokio::test]
    async fn test_latest_certificates_list_timeout() {
        let (server, mut client) = setup_server_and_client();
        client.timeout_duration = Some(Duration::from_millis(10));
        let _server_mock = server.mock(|when, then| {
            when.path("/certificates");
            then.delay(Duration::from_millis(100));
        });

        let error = client
            .latest_certificates_list()
            .await
            .expect_err("latest_certificates_list should fail");

        assert!(
            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
            "unexpected error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn test_certificates_details_ok_200() {
        let (server, client) = setup_server_and_client();
        let expected_message = CertificateMessage::dummy();
        let _server_mock = server.mock(|when, then| {
            when.path(format!("/certificate/{}", expected_message.hash));
            then.status(200).body(json!(expected_message).to_string());
        });

        let fetched_message = client.certificate_details(&expected_message.hash).await.unwrap();

        assert_eq!(Some(expected_message), fetched_message);
    }

    #[tokio::test]
    async fn test_certificates_details_ok_404() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/certificate/not-found");
            then.status(404);
        });

        // Query the mocked route itself so the 404 comes from the mock, not from an
        // unmatched route.
        let fetched_message = client.certificate_details("not-found").await.unwrap();

        assert_eq!(None, fetched_message);
    }

    #[tokio::test]
    async fn test_certificates_details_ko_500() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/certificate/whatever");
            then.status(500).body("an error occurred");
        });

        match client.certificate_details("whatever").await.unwrap_err() {
            AggregatorClientError::RemoteServerTechnical(_) => (),
            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
        };
    }

    #[tokio::test]
    async fn test_certificates_details_timeout() {
        let (server, mut client) = setup_server_and_client();
        client.timeout_duration = Some(Duration::from_millis(10));
        let _server_mock = server.mock(|when, then| {
            when.path("/certificate/whatever");
            then.delay(Duration::from_millis(100));
        });

        let error = client
            .certificate_details("whatever")
            .await
            .expect_err("certificate_details should fail");

        assert!(
            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
            "unexpected error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn test_latest_genesis_ok_200() {
        let (server, client) = setup_server_and_client();
        let genesis_message = CertificateMessage::dummy();
        let _server_mock = server.mock(|when, then| {
            when.path("/certificate/genesis");
            then.status(200).body(json!(genesis_message).to_string());
        });

        let fetched = client.latest_genesis_certificate().await.unwrap();

        assert_eq!(Some(genesis_message), fetched);
    }

    #[tokio::test]
    async fn test_latest_genesis_ok_404() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/certificate/genesis");
            then.status(404);
        });

        let fetched = client.latest_genesis_certificate().await.unwrap();

        assert_eq!(None, fetched);
    }

    #[tokio::test]
    async fn test_latest_genesis_ko_500() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/certificate/genesis");
            then.status(500).body("an error occurred");
        });

        let error = client.latest_genesis_certificate().await.unwrap_err();

        assert!(
            matches!(error, AggregatorClientError::RemoteServerTechnical(_)),
            "Expected Aggregator::RemoteServerTechnical error, got {error:?}"
        );
    }

    #[tokio::test]
    async fn test_latest_genesis_timeout() {
        let (server, mut client) = setup_server_and_client();
        client.timeout_duration = Some(Duration::from_millis(10));
        let _server_mock = server.mock(|when, then| {
            when.path("/certificate/genesis");
            then.delay(Duration::from_millis(100));
        });

        let error = client.latest_genesis_certificate().await.unwrap_err();

        assert!(
            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
            "unexpected error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn test_4xx_errors_are_handled_as_remote_server_logical() {
        let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::RemoteServerLogical(..)
            ),
            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_5xx_errors_are_handled_as_remote_server_technical() {
        let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::RemoteServerTechnical(..)
            ),
            "Expected error to be RemoteServerTechnical\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text()
    {
        let response = build_text_response(StatusCode::OK, "ok text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
            ),
            "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
        let error_text = "An error occurred; please try again later.";
        let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "expectation failed: An error occurred; please try again later."
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message()
    {
        let client_error = ClientError::new("label", "message");
        let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "bad request: label: message"
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_json_formatted_server_error_response_contains_error_label_and_message()
    {
        let server_error = ServerError::new("message");
        let response = build_json_response(StatusCode::BAD_REQUEST, &server_error);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "bad request: message"
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
        let response = build_json_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            &json!({ "second": "unknown", "first": "foreign" }),
        );

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            r#"internal server error: {"first":"foreign","second":"unknown"}"#
        );
    }

    #[tokio::test]
    async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
        let response = HttpResponseBuilder::new()
            .status(StatusCode::BAD_REQUEST)
            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
            .body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
            .unwrap()
            .into();

        let root_cause = AggregatorClientError::get_root_cause(response).await;

        assert_error_text_contains!(root_cause, "bad request");
        assert!(
            !root_cause.contains("bad request: "),
            "Expected error message should not contain additional information \ngot '{root_cause:?}'"
        );
    }

    mod warn_if_api_version_mismatch {
        use std::collections::HashMap;

        use mithril_common::test::api_version_extensions::ApiVersionProviderTestExtension;
        use mithril_common::test::logging::MemoryDrainForTestInspector;

        use super::*;

        /// Build a provider whose default OpenAPI spec (`openapi.yaml`) has the given version.
        fn version_provider_with_open_api_version<V: Into<String>>(
            version: V,
        ) -> APIVersionProvider {
            let mut version_provider = version_provider_without_open_api_version();
            let mut open_api_versions = HashMap::new();
            open_api_versions.insert(
                "openapi.yaml".to_string(),
                Version::parse(&version.into()).unwrap(),
            );
            version_provider.update_open_api_versions(open_api_versions);

            version_provider
        }

        /// Build a provider without any known OpenAPI version.
        fn version_provider_without_open_api_version() -> APIVersionProvider {
            let mut version_provider =
                APIVersionProvider::new(Arc::new(DummyApiVersionDiscriminantSource::default()));
            version_provider.update_open_api_versions(HashMap::new());

            version_provider
        }

        /// Build a response carrying a single header.
        fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
            key: K,
            value: V,
        ) -> Response {
            HttpResponseBuilder::new()
                .header(key.into(), value.into())
                .body("whatever")
                .unwrap()
                .into()
        }

        /// Assert that the mismatch warning was logged along with both versions.
        fn assert_api_version_warning_logged<L: Into<String>, A: Into<String>>(
            log_inspector: &MemoryDrainForTestInspector,
            leader_aggregator_version: L,
            aggregator_version: A,
        ) {
            assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
            assert!(log_inspector.contains_log(&format!(
                "leader_aggregator_version={}",
                leader_aggregator_version.into()
            )));
            assert!(
                log_inspector
                    .contains_log(&format!("aggregator_version={}", aggregator_version.into()))
            );
        }

        #[test]
        fn test_logs_warning_when_leader_aggregator_api_version_is_newer() {
            let leader_aggregator_version = "2.0.0";
            let aggregator_version = "1.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_with_open_api_version(aggregator_version);
            let mut client = setup_client("http://whatever");
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let response = build_fake_response_with_header(
                MITHRIL_API_VERSION_HEADER,
                leader_aggregator_version,
            );

            assert!(
                Version::parse(leader_aggregator_version).unwrap()
                    > Version::parse(aggregator_version).unwrap()
            );

            client.warn_if_api_version_mismatch(&response);

            assert_api_version_warning_logged(
                &log_inspector,
                leader_aggregator_version,
                aggregator_version,
            );
        }

        #[test]
        fn test_no_warning_logged_when_versions_match() {
            let version = "1.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_with_open_api_version(version);
            let mut client = setup_client("http://whatever");
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);

            client.warn_if_api_version_mismatch(&response);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_no_warning_logged_when_leader_aggregator_api_version_is_older() {
            let leader_aggregator_version = "1.0.0";
            let aggregator_version = "2.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_with_open_api_version(aggregator_version);
            let mut client = setup_client("http://whatever");
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let response = build_fake_response_with_header(
                MITHRIL_API_VERSION_HEADER,
                leader_aggregator_version,
            );

            assert!(
                Version::parse(leader_aggregator_version).unwrap()
                    < Version::parse(aggregator_version).unwrap()
            );

            client.warn_if_api_version_mismatch(&response);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_does_not_log_or_fail_when_header_is_missing() {
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client("http://whatever");
            client.logger = logger;
            let response =
                build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");

            client.warn_if_api_version_mismatch(&response);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_does_not_log_or_fail_when_header_is_not_a_version() {
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client("http://whatever");
            client.logger = logger;
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");

            client.warn_if_api_version_mismatch(&response);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_logs_error_when_aggregator_version_cannot_be_computed() {
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_without_open_api_version();
            let mut client = setup_client("http://whatever");
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");

            client.warn_if_api_version_mismatch(&response);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
            assert!(
                log_inspector
                    .contains_log("Failed to compute the current aggregator API version")
            );
        }

        #[tokio::test]
        async fn test_epoch_settings_ok_200_log_warning_if_api_version_mismatch() {
            let leader_aggregator_version = "2.0.0";
            let aggregator_version = "1.0.0";
            let (server, mut client) = setup_server_and_client();
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_with_open_api_version(aggregator_version);
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let epoch_settings_expected = EpochSettingsMessage::dummy();
            let _server_mock = server.mock(|when, then| {
                when.path("/epoch-settings");
                then.status(200)
                    .body(json!(epoch_settings_expected).to_string())
                    .header(MITHRIL_API_VERSION_HEADER, leader_aggregator_version);
            });

            assert!(
                Version::parse(leader_aggregator_version).unwrap()
                    > Version::parse(aggregator_version).unwrap()
            );

            client.retrieve_epoch_settings().await.unwrap();

            assert_api_version_warning_logged(
                &log_inspector,
                leader_aggregator_version,
                aggregator_version,
            );
        }
    }

    mod remote_certificate_retriever {
        use mithril_common::test::double::fake_data;

        use super::*;

        #[tokio::test]
        async fn test_get_latest_certificate_details() {
            let (server, client) = setup_server_and_client();
            let expected_certificate = fake_data::certificate("expected");
            let latest_message: CertificateMessage =
                expected_certificate.clone().try_into().unwrap();
            let latest_certificates = vec![
                CertificateListItemMessage {
                    hash: expected_certificate.hash.clone(),
                    ..CertificateListItemMessage::dummy()
                },
                CertificateListItemMessage::dummy(),
                CertificateListItemMessage::dummy(),
            ];
            let _server_mock = server.mock(|when, then| {
                when.path("/certificates");
                then.status(200).body(json!(latest_certificates).to_string());
            });
            let _server_mock = server.mock(|when, then| {
                when.path(format!("/certificate/{}", latest_message.hash));
                then.status(200).body(json!(latest_message).to_string());
            });

            let fetched_certificate = client.get_latest_certificate_details().await.unwrap();

            assert_eq!(Some(expected_certificate), fetched_certificate);
        }

        #[tokio::test]
        async fn test_get_latest_genesis_certificate() {
            let (server, client) = setup_server_and_client();
            let genesis_message = CertificateMessage::dummy();
            let expected_genesis: Certificate = genesis_message.clone().try_into().unwrap();
            let _server_mock = server.mock(|when, then| {
                when.path("/certificate/genesis");
                then.status(200).body(json!(genesis_message).to_string());
            });

            let fetched = client.get_genesis_certificate_details().await.unwrap();

            assert_eq!(Some(expected_genesis), fetched);
        }
    }
}