1use anyhow::anyhow;
2use async_trait::async_trait;
3use mithril_common::messages::TryFromMessageAdapter;
4use reqwest::header::{self, HeaderValue};
5use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode};
6use semver::Version;
7use slog::{debug, error, warn, Logger};
8use std::{io, sync::Arc, time::Duration};
9use thiserror::Error;
10
11use mithril_common::{
12 api_version::APIVersionProvider,
13 entities::{ClientError, ServerError},
14 logging::LoggerExtensions,
15 messages::EpochSettingsMessage,
16 StdError, MITHRIL_AGGREGATOR_VERSION_HEADER, MITHRIL_API_VERSION_HEADER,
17};
18
19use crate::entities::LeaderAggregatorEpochSettings;
20use crate::message_adapters::FromEpochSettingsAdapter;
21
/// `Content-Type` value identifying a JSON body; error responses carrying it are decoded as JSON
/// to extract a readable root cause.
const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");

/// Warning logged when the leader aggregator advertises a newer API version than the one
/// computed locally.
const API_VERSION_MISMATCH_WARNING_MESSAGE: &str =
    "OpenAPI version may be incompatible, please update your Mithril node to the latest version.";
26
/// Errors that can occur while communicating with an aggregator through an [`AggregatorClient`].
#[derive(Error, Debug)]
pub enum AggregatorClientError {
    /// The remote server answered with a `5xx` (server error) status code.
    #[error("remote server technical error")]
    RemoteServerTechnical(#[source] StdError),

    /// The remote server answered with a `4xx` (client error) status code.
    #[error("remote server logical error")]
    RemoteServerLogical(#[source] StdError),

    /// The request could not be completed (e.g. connection failure or timeout).
    #[error("remote server unreachable")]
    RemoteServerUnreachable(#[source] StdError),

    /// The remote server answered with a status code that is neither expected nor a
    /// `4xx`/`5xx` error; carries the status code and the raw response text.
    #[error("unhandled status code: {0}, response text: {1}")]
    UnhandledStatusCode(StatusCode, String),

    /// The response body could not be deserialized into the expected message.
    #[error("json parsing failed")]
    JsonParseFailed(#[source] StdError),

    /// An underlying I/O error.
    #[error("Input/Output error")]
    IOError(#[from] io::Error),

    /// The HTTP client could not be built.
    #[error("HTTP client creation failed")]
    HTTPClientCreation(#[source] StdError),

    /// The relay endpoint could not be turned into an HTTP proxy.
    #[error("proxy creation failed")]
    ProxyCreation(#[source] StdError),

    /// The received message could not be converted into its domain entity.
    #[error("adapter failed")]
    Adapter(#[source] StdError),
}
66
67impl AggregatorClientError {
68 pub async fn from_response(response: Response) -> Self {
74 let error_code = response.status();
75
76 if error_code.is_client_error() {
77 let root_cause = Self::get_root_cause(response).await;
78 Self::RemoteServerLogical(anyhow!(root_cause))
79 } else if error_code.is_server_error() {
80 let root_cause = Self::get_root_cause(response).await;
81 Self::RemoteServerTechnical(anyhow!(root_cause))
82 } else {
83 let response_text = response.text().await.unwrap_or_default();
84 Self::UnhandledStatusCode(error_code, response_text)
85 }
86 }
87
88 async fn get_root_cause(response: Response) -> String {
89 let error_code = response.status();
90 let canonical_reason = error_code
91 .canonical_reason()
92 .unwrap_or_default()
93 .to_lowercase();
94 let is_json = response
95 .headers()
96 .get(header::CONTENT_TYPE)
97 .is_some_and(|ct| JSON_CONTENT_TYPE == ct);
98
99 if is_json {
100 let json_value: serde_json::Value = response.json().await.unwrap_or_default();
101
102 if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
103 format!(
104 "{}: {}: {}",
105 canonical_reason, client_error.label, client_error.message
106 )
107 } else if let Ok(server_error) =
108 serde_json::from_value::<ServerError>(json_value.clone())
109 {
110 format!("{}: {}", canonical_reason, server_error.message)
111 } else if json_value.is_null() {
112 canonical_reason.to_string()
113 } else {
114 format!("{canonical_reason}: {json_value}")
115 }
116 } else {
117 let response_text = response.text().await.unwrap_or_default();
118 format!("{canonical_reason}: {response_text}")
119 }
120 }
121}
122
/// Client used to retrieve data from an aggregator (the leader aggregator).
#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait AggregatorClient: Sync + Send {
    /// Retrieve the epoch settings from the aggregator.
    ///
    /// Returns an [`AggregatorClientError`] if the request fails, the server answers with an
    /// error, or the response cannot be decoded.
    async fn retrieve_epoch_settings(
        &self,
    ) -> Result<Option<LeaderAggregatorEpochSettings>, AggregatorClientError>;
}
132
/// HTTP implementation of [`AggregatorClient`].
pub struct AggregatorHTTPClient {
    /// Base URL of the aggregator; route paths (e.g. `/epoch-settings`) are appended to it.
    aggregator_endpoint: String,
    /// Optional relay through which every request is proxied.
    relay_endpoint: Option<String>,
    /// Computes the API version sent with requests and compared with the one in responses.
    api_version_provider: Arc<APIVersionProvider>,
    /// Optional timeout applied to each request.
    timeout_duration: Option<Duration>,
    logger: Logger,
}
141
142impl AggregatorHTTPClient {
143 pub fn new(
145 aggregator_endpoint: String,
146 relay_endpoint: Option<String>,
147 api_version_provider: Arc<APIVersionProvider>,
148 timeout_duration: Option<Duration>,
149 logger: Logger,
150 ) -> Self {
151 let logger = logger.new_with_component_name::<Self>();
152 debug!(logger, "New AggregatorHTTPClient created");
153 Self {
154 aggregator_endpoint,
155 relay_endpoint,
156 api_version_provider,
157 timeout_duration,
158 logger,
159 }
160 }
161
162 fn prepare_http_client(&self) -> Result<Client, AggregatorClientError> {
163 let client = match &self.relay_endpoint {
164 Some(relay_endpoint) => Client::builder()
165 .proxy(
166 Proxy::all(relay_endpoint)
167 .map_err(|e| AggregatorClientError::ProxyCreation(anyhow!(e)))?,
168 )
169 .build()
170 .map_err(|e| AggregatorClientError::HTTPClientCreation(anyhow!(e)))?,
171 None => Client::new(),
172 };
173
174 Ok(client)
175 }
176
177 pub fn prepare_request_builder(&self, request_builder: RequestBuilder) -> RequestBuilder {
179 let request_builder = request_builder
180 .header(
181 MITHRIL_API_VERSION_HEADER,
182 self.api_version_provider
183 .compute_current_version()
184 .unwrap()
185 .to_string(),
186 )
187 .header(MITHRIL_AGGREGATOR_VERSION_HEADER, env!("CARGO_PKG_VERSION"));
188
189 if let Some(duration) = self.timeout_duration {
190 request_builder.timeout(duration)
191 } else {
192 request_builder
193 }
194 }
195
196 fn warn_if_api_version_mismatch(&self, response: &Response) {
198 let leader_version = response
199 .headers()
200 .get(MITHRIL_API_VERSION_HEADER)
201 .and_then(|v| v.to_str().ok())
202 .and_then(|s| Version::parse(s).ok());
203
204 let follower_version = self.api_version_provider.compute_current_version();
205
206 match (leader_version, follower_version) {
207 (Some(leader), Ok(follower)) if follower < leader => {
208 warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
209 "leader_aggregator_version" => %leader,
210 "aggregator_version" => %follower,
211 );
212 }
213 (Some(_), Err(error)) => {
214 error!(
215 self.logger,
216 "Failed to compute the current aggregator API version";
217 "error" => error.to_string()
218 );
219 }
220 _ => {}
221 }
222 }
223}
224
225#[async_trait]
226impl AggregatorClient for AggregatorHTTPClient {
227 async fn retrieve_epoch_settings(
228 &self,
229 ) -> Result<Option<LeaderAggregatorEpochSettings>, AggregatorClientError> {
230 debug!(self.logger, "Retrieve epoch settings");
231 let url = format!("{}/epoch-settings", self.aggregator_endpoint);
232 let response = self
233 .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
234 .send()
235 .await;
236
237 match response {
238 Ok(response) => match response.status() {
239 StatusCode::OK => {
240 self.warn_if_api_version_mismatch(&response);
241 match response.json::<EpochSettingsMessage>().await {
242 Ok(message) => {
243 let epoch_settings = FromEpochSettingsAdapter::try_adapt(message)
244 .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
245 Ok(Some(epoch_settings))
246 }
247 Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
248 }
249 }
250 _ => Err(AggregatorClientError::from_response(response).await),
251 },
252 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
253 }
254 }
255}
256
#[cfg(test)]
pub(crate) mod dumb {
    use tokio::sync::RwLock;

    use super::*;

    /// Test double of [`AggregatorClient`] that serves preconfigured epoch settings without any
    /// network access.
    pub struct DumbAggregatorClient {
        /// Value handed back by `retrieve_epoch_settings`.
        epoch_settings: RwLock<Option<LeaderAggregatorEpochSettings>>,
    }

    impl Default for DumbAggregatorClient {
        /// Serve the dummy [`LeaderAggregatorEpochSettings`].
        fn default() -> Self {
            let dummy_settings = LeaderAggregatorEpochSettings::dummy();
            Self {
                epoch_settings: RwLock::new(Some(dummy_settings)),
            }
        }
    }

    #[async_trait]
    impl AggregatorClient for DumbAggregatorClient {
        /// Return a copy of the stored epoch settings; never fails.
        async fn retrieve_epoch_settings(
            &self,
        ) -> Result<Option<LeaderAggregatorEpochSettings>, AggregatorClientError> {
            Ok(self.epoch_settings.read().await.clone())
        }
    }
}
289
#[cfg(test)]
mod tests {
    use http::response::Builder as HttpResponseBuilder;
    use httpmock::prelude::*;
    use serde_json::json;

    use mithril_common::api_version::DummyApiVersionDiscriminantSource;

    use crate::test_tools::TestLogger;

    use super::*;

    /// Build a client targeting `server_url` with a default API version provider.
    fn setup_client<U: Into<String>>(server_url: U) -> AggregatorHTTPClient {
        let discriminant_source = DummyApiVersionDiscriminantSource::default();
        let api_version_provider = APIVersionProvider::new(Arc::new(discriminant_source));

        AggregatorHTTPClient::new(
            server_url.into(),
            None,
            Arc::new(api_version_provider),
            None,
            TestLogger::stdout(),
        )
    }

    /// Start a mock server and build a client targeting it.
    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
        let server = MockServer::start();
        let aggregator_endpoint = server.url("");
        let client = setup_client(&aggregator_endpoint);

        (server, client)
    }

    fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
        HttpResponseBuilder::new()
            .status(status_code)
            .body(body.into())
            .unwrap()
            .into()
    }

    fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
        HttpResponseBuilder::new()
            .status(status_code)
            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
            .body(serde_json::to_string(&body).unwrap())
            .unwrap()
            .into()
    }

    macro_rules! assert_error_text_contains {
        ($error: expr, $expect_contains: expr) => {
            let error = &$error;
            assert!(
                error.contains($expect_contains),
                "Expected error message to contain '{}'\ngot '{error:?}'",
                $expect_contains,
            );
        };
    }

    #[tokio::test]
    async fn test_epoch_settings_ok_200() {
        let (server, client) = setup_server_and_client();
        let epoch_settings_expected = EpochSettingsMessage::dummy();
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.status(200)
                .body(json!(epoch_settings_expected).to_string());
        });

        let epoch_settings = client.retrieve_epoch_settings().await;
        epoch_settings.as_ref().expect("unexpected error");
        assert_eq!(
            FromEpochSettingsAdapter::try_adapt(epoch_settings_expected).unwrap(),
            epoch_settings.unwrap().unwrap()
        );
    }

    #[tokio::test]
    async fn test_epoch_settings_ko_500() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.status(500).body("an error occurred");
        });

        match client.retrieve_epoch_settings().await.unwrap_err() {
            AggregatorClientError::RemoteServerTechnical(_) => (),
            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
        };
    }

    #[tokio::test]
    async fn test_epoch_settings_timeout() {
        let (server, mut client) = setup_server_and_client();
        client.timeout_duration = Some(Duration::from_millis(10));
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.delay(Duration::from_millis(100));
        });

        let error = client
            .retrieve_epoch_settings()
            .await
            .expect_err("retrieve_epoch_settings should fail");

        assert!(
            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
            "unexpected error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn test_4xx_errors_are_handled_as_remote_server_logical() {
        let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::RemoteServerLogical(..)
            ),
            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_5xx_errors_are_handled_as_remote_server_technical() {
        let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::RemoteServerTechnical(..)
            ),
            "Expected error to be RemoteServerTechnical\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text(
    ) {
        let response = build_text_response(StatusCode::OK, "ok text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
            ),
            "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
        let error_text = "An error occurred; please try again later.";
        let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "expectation failed: An error occurred; please try again later."
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message(
    ) {
        let client_error = ClientError::new("label", "message");
        let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "bad request: label: message"
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_json_formatted_server_error_response_contains_error_message() {
        let server_error = ServerError::new("message");
        let response = build_json_response(StatusCode::BAD_REQUEST, &server_error);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "bad request: message"
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
        let response = build_json_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            &json!({ "second": "unknown", "first": "foreign" }),
        );

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            r#"internal server error: {"first":"foreign","second":"unknown"}"#
        );
    }

    #[tokio::test]
    async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
        let response = HttpResponseBuilder::new()
            .status(StatusCode::BAD_REQUEST)
            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
            .body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
            .unwrap()
            .into();

        let root_cause = AggregatorClientError::get_root_cause(response).await;

        assert_error_text_contains!(root_cause, "bad request");
        assert!(
            !root_cause.contains("bad request: "),
            "Expected error message should not contain additional information \ngot '{root_cause:?}'"
        );
    }

    mod warn_if_api_version_mismatch {
        use std::collections::HashMap;

        use mithril_common::test_utils::MemoryDrainForTestInspector;

        use super::*;

        fn version_provider_with_open_api_version<V: Into<String>>(
            version: V,
        ) -> APIVersionProvider {
            let mut version_provider = version_provider_without_open_api_version();
            let mut open_api_versions = HashMap::new();
            open_api_versions.insert(
                "openapi.yaml".to_string(),
                Version::parse(&version.into()).unwrap(),
            );
            version_provider.update_open_api_versions(open_api_versions);

            version_provider
        }

        fn version_provider_without_open_api_version() -> APIVersionProvider {
            let mut version_provider =
                APIVersionProvider::new(Arc::new(DummyApiVersionDiscriminantSource::default()));
            version_provider.update_open_api_versions(HashMap::new());

            version_provider
        }

        fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
            key: K,
            value: V,
        ) -> Response {
            HttpResponseBuilder::new()
                .header(key.into(), value.into())
                .body("whatever")
                .unwrap()
                .into()
        }

        fn assert_api_version_warning_logged<L: Into<String>, A: Into<String>>(
            log_inspector: &MemoryDrainForTestInspector,
            leader_aggregator_version: L,
            aggregator_version: A,
        ) {
            assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
            assert!(log_inspector.contains_log(&format!(
                "leader_aggregator_version={}",
                leader_aggregator_version.into()
            )));
            assert!(log_inspector
                .contains_log(&format!("aggregator_version={}", aggregator_version.into())));
        }

        #[test]
        fn test_logs_warning_when_leader_aggregator_api_version_is_newer() {
            let leader_aggregator_version = "2.0.0";
            let aggregator_version = "1.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_with_open_api_version(aggregator_version);
            let mut client = setup_client("whatever");
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let response = build_fake_response_with_header(
                MITHRIL_API_VERSION_HEADER,
                leader_aggregator_version,
            );

            assert!(
                Version::parse(leader_aggregator_version).unwrap()
                    > Version::parse(aggregator_version).unwrap()
            );

            client.warn_if_api_version_mismatch(&response);

            assert_api_version_warning_logged(
                &log_inspector,
                leader_aggregator_version,
                aggregator_version,
            );
        }

        #[test]
        fn test_no_warning_logged_when_versions_match() {
            let version = "1.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_with_open_api_version(version);
            let mut client = setup_client("whatever");
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);

            client.warn_if_api_version_mismatch(&response);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_no_warning_logged_when_leader_aggregator_api_version_is_older() {
            let leader_aggregator_version = "1.0.0";
            let aggregator_version = "2.0.0";
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_with_open_api_version(aggregator_version);
            let mut client = setup_client("whatever");
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let response = build_fake_response_with_header(
                MITHRIL_API_VERSION_HEADER,
                leader_aggregator_version,
            );

            assert!(
                Version::parse(leader_aggregator_version).unwrap()
                    < Version::parse(aggregator_version).unwrap()
            );

            client.warn_if_api_version_mismatch(&response);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_does_not_log_or_fail_when_header_is_missing() {
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client("whatever");
            client.logger = logger;
            let response =
                build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");

            client.warn_if_api_version_mismatch(&response);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_does_not_log_or_fail_when_header_is_not_a_version() {
            let (logger, log_inspector) = TestLogger::memory();
            let mut client = setup_client("whatever");
            client.logger = logger;
            let response =
                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");

            client.warn_if_api_version_mismatch(&response);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
        }

        #[test]
        fn test_logs_error_when_aggregator_version_cannot_be_computed() {
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_without_open_api_version();
            let mut client = setup_client("whatever");
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");

            client.warn_if_api_version_mismatch(&response);

            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
            assert!(
                log_inspector.contains_log("Failed to compute the current aggregator API version"),
                "Expected an error to be logged when the aggregator API version cannot be computed"
            );
        }

        #[tokio::test]
        async fn test_epoch_settings_ok_200_log_warning_if_api_version_mismatch() {
            let leader_aggregator_version = "2.0.0";
            let aggregator_version = "1.0.0";
            let (server, mut client) = setup_server_and_client();
            let (logger, log_inspector) = TestLogger::memory();
            let version_provider = version_provider_with_open_api_version(aggregator_version);
            client.api_version_provider = Arc::new(version_provider);
            client.logger = logger;
            let epoch_settings_expected = EpochSettingsMessage::dummy();
            let _server_mock = server.mock(|when, then| {
                when.path("/epoch-settings");
                then.status(200)
                    .body(json!(epoch_settings_expected).to_string())
                    .header(MITHRIL_API_VERSION_HEADER, leader_aggregator_version);
            });

            assert!(
                Version::parse(leader_aggregator_version).unwrap()
                    > Version::parse(aggregator_version).unwrap()
            );

            client.retrieve_epoch_settings().await.unwrap();

            assert_api_version_warning_logged(
                &log_inspector,
                leader_aggregator_version,
                aggregator_version,
            );
        }
    }
}