mithril_aggregator/services/
aggregator_client.rs1use anyhow::anyhow;
2use async_trait::async_trait;
3use mithril_common::messages::TryFromMessageAdapter;
4use reqwest::header::{self, HeaderValue};
5use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode};
6use slog::{debug, error, Logger};
7use std::{io, sync::Arc, time::Duration};
8use thiserror::Error;
9
10use mithril_common::{
11 api_version::APIVersionProvider,
12 entities::{ClientError, ServerError},
13 logging::LoggerExtensions,
14 messages::EpochSettingsMessage,
15 StdError, MITHRIL_AGGREGATOR_VERSION_HEADER, MITHRIL_API_VERSION_HEADER,
16};
17
18use crate::entities::LeaderAggregatorEpochSettings;
19use crate::message_adapters::FromEpochSettingsAdapter;
20
21const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");
22
23#[derive(Error, Debug)]
25pub enum AggregatorClientError {
26 #[error("remote server technical error")]
28 RemoteServerTechnical(#[source] StdError),
29
30 #[error("remote server logical error")]
32 RemoteServerLogical(#[source] StdError),
33
34 #[error("remote server unreachable")]
36 RemoteServerUnreachable(#[source] StdError),
37
38 #[error("unhandled status code: {0}, response text: {1}")]
40 UnhandledStatusCode(StatusCode, String),
41
42 #[error("json parsing failed")]
44 JsonParseFailed(#[source] StdError),
45
46 #[error("Input/Output error")]
48 IOError(#[from] io::Error),
49
50 #[error("HTTP API version mismatch")]
52 ApiVersionMismatch(#[source] StdError),
53
54 #[error("HTTP client creation failed")]
56 HTTPClientCreation(#[source] StdError),
57
58 #[error("proxy creation failed")]
60 ProxyCreation(#[source] StdError),
61
62 #[error("adapter failed")]
64 Adapter(#[source] StdError),
65}
66
#[cfg(test)]
impl AggregatorClientError {
    /// Tells whether this error is an [AggregatorClientError::ApiVersionMismatch].
    pub(crate) fn is_api_version_mismatch(&self) -> bool {
        matches!(self, AggregatorClientError::ApiVersionMismatch(..))
    }
}
74
75impl AggregatorClientError {
76 pub async fn from_response(response: Response) -> Self {
82 let error_code = response.status();
83
84 if error_code.is_client_error() {
85 let root_cause = Self::get_root_cause(response).await;
86 Self::RemoteServerLogical(anyhow!(root_cause))
87 } else if error_code.is_server_error() {
88 let root_cause = Self::get_root_cause(response).await;
89 Self::RemoteServerTechnical(anyhow!(root_cause))
90 } else {
91 let response_text = response.text().await.unwrap_or_default();
92 Self::UnhandledStatusCode(error_code, response_text)
93 }
94 }
95
96 async fn get_root_cause(response: Response) -> String {
97 let error_code = response.status();
98 let canonical_reason = error_code
99 .canonical_reason()
100 .unwrap_or_default()
101 .to_lowercase();
102 let is_json = response
103 .headers()
104 .get(header::CONTENT_TYPE)
105 .is_some_and(|ct| JSON_CONTENT_TYPE == ct);
106
107 if is_json {
108 let json_value: serde_json::Value = response.json().await.unwrap_or_default();
109
110 if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
111 format!(
112 "{}: {}: {}",
113 canonical_reason, client_error.label, client_error.message
114 )
115 } else if let Ok(server_error) =
116 serde_json::from_value::<ServerError>(json_value.clone())
117 {
118 format!("{}: {}", canonical_reason, server_error.message)
119 } else if json_value.is_null() {
120 canonical_reason.to_string()
121 } else {
122 format!("{}: {}", canonical_reason, json_value)
123 }
124 } else {
125 let response_text = response.text().await.unwrap_or_default();
126 format!("{}: {}", canonical_reason, response_text)
127 }
128 }
129}
130
131#[cfg_attr(test, mockall::automock)]
133#[async_trait]
134pub trait AggregatorClient: Sync + Send {
135 async fn retrieve_epoch_settings(
137 &self,
138 ) -> Result<Option<LeaderAggregatorEpochSettings>, AggregatorClientError>;
139}
140
141pub struct AggregatorHTTPClient {
143 aggregator_endpoint: String,
144 relay_endpoint: Option<String>,
145 api_version_provider: Arc<APIVersionProvider>,
146 timeout_duration: Option<Duration>,
147 logger: Logger,
148}
149
150impl AggregatorHTTPClient {
151 pub fn new(
153 aggregator_endpoint: String,
154 relay_endpoint: Option<String>,
155 api_version_provider: Arc<APIVersionProvider>,
156 timeout_duration: Option<Duration>,
157 logger: Logger,
158 ) -> Self {
159 let logger = logger.new_with_component_name::<Self>();
160 debug!(logger, "New AggregatorHTTPClient created");
161 Self {
162 aggregator_endpoint,
163 relay_endpoint,
164 api_version_provider,
165 timeout_duration,
166 logger,
167 }
168 }
169
170 fn prepare_http_client(&self) -> Result<Client, AggregatorClientError> {
171 let client = match &self.relay_endpoint {
172 Some(relay_endpoint) => Client::builder()
173 .proxy(
174 Proxy::all(relay_endpoint)
175 .map_err(|e| AggregatorClientError::ProxyCreation(anyhow!(e)))?,
176 )
177 .build()
178 .map_err(|e| AggregatorClientError::HTTPClientCreation(anyhow!(e)))?,
179 None => Client::new(),
180 };
181
182 Ok(client)
183 }
184
185 pub fn prepare_request_builder(&self, request_builder: RequestBuilder) -> RequestBuilder {
187 let request_builder = request_builder
188 .header(
189 MITHRIL_API_VERSION_HEADER,
190 self.api_version_provider
191 .compute_current_version()
192 .unwrap()
193 .to_string(),
194 )
195 .header(MITHRIL_AGGREGATOR_VERSION_HEADER, env!("CARGO_PKG_VERSION"));
196
197 if let Some(duration) = self.timeout_duration {
198 request_builder.timeout(duration)
199 } else {
200 request_builder
201 }
202 }
203
204 fn handle_api_error(&self, response: &Response) -> AggregatorClientError {
206 if let Some(version) = response.headers().get(MITHRIL_API_VERSION_HEADER) {
207 AggregatorClientError::ApiVersionMismatch(anyhow!(
208 "server version: '{}', signer version: '{}'",
209 version.to_str().unwrap(),
210 self.api_version_provider.compute_current_version().unwrap()
211 ))
212 } else {
213 AggregatorClientError::ApiVersionMismatch(anyhow!(
214 "version precondition failed, sent version '{}'.",
215 self.api_version_provider.compute_current_version().unwrap()
216 ))
217 }
218 }
219}
220
221#[async_trait]
222impl AggregatorClient for AggregatorHTTPClient {
223 async fn retrieve_epoch_settings(
224 &self,
225 ) -> Result<Option<LeaderAggregatorEpochSettings>, AggregatorClientError> {
226 debug!(self.logger, "Retrieve epoch settings");
227 let url = format!("{}/epoch-settings", self.aggregator_endpoint);
228 let response = self
229 .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
230 .send()
231 .await;
232
233 match response {
234 Ok(response) => match response.status() {
235 StatusCode::OK => match response.json::<EpochSettingsMessage>().await {
236 Ok(message) => {
237 let epoch_settings = FromEpochSettingsAdapter::try_adapt(message)
238 .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
239 Ok(Some(epoch_settings))
240 }
241 Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
242 },
243 StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
244 _ => Err(AggregatorClientError::from_response(response).await),
245 },
246 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
247 }
248 }
249}
250
#[cfg(test)]
pub(crate) mod dumb {
    use tokio::sync::RwLock;

    use super::*;

    /// Test double of [AggregatorClient] that answers with stored epoch settings
    /// without any network access.
    pub struct DumbAggregatorClient {
        epoch_settings: RwLock<Option<LeaderAggregatorEpochSettings>>,
    }

    impl Default for DumbAggregatorClient {
        /// Creates a client whose stored epoch settings are the dummy ones.
        fn default() -> Self {
            let dummy_epoch_settings = LeaderAggregatorEpochSettings::dummy();

            Self {
                epoch_settings: RwLock::new(Some(dummy_epoch_settings)),
            }
        }
    }

    #[async_trait]
    impl AggregatorClient for DumbAggregatorClient {
        /// Returns a copy of the stored epoch settings; never fails.
        async fn retrieve_epoch_settings(
            &self,
        ) -> Result<Option<LeaderAggregatorEpochSettings>, AggregatorClientError> {
            Ok(self.epoch_settings.read().await.clone())
        }
    }
}
283
#[cfg(test)]
mod tests {
    use http::response::Builder as HttpResponseBuilder;
    use httpmock::prelude::*;
    use serde_json::json;

    use mithril_common::entities::Epoch;
    use mithril_common::era::{EraChecker, SupportedEra};

    use crate::test_tools::TestLogger;

    use super::*;

    /// Starts a mock server and a client targeting it (no relay, no timeout).
    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
        let server = MockServer::start();
        let aggregator_endpoint = server.url("");
        let relay_endpoint = None;
        let era_checker = EraChecker::new(SupportedEra::dummy(), Epoch(1));
        let api_version_provider = APIVersionProvider::new(Arc::new(era_checker));

        (
            server,
            AggregatorHTTPClient::new(
                aggregator_endpoint,
                relay_endpoint,
                Arc::new(api_version_provider),
                None,
                TestLogger::stdout(),
            ),
        )
    }

    /// Builds a response with a plain text body and no `Content-Type` header.
    fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
        HttpResponseBuilder::new()
            .status(status_code)
            .body(body.into())
            .unwrap()
            .into()
    }

    /// Builds a response with a JSON serialized body and a JSON `Content-Type` header.
    fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
        HttpResponseBuilder::new()
            .status(status_code)
            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
            .body(serde_json::to_string(body).unwrap())
            .unwrap()
            .into()
    }

    macro_rules! assert_error_text_contains {
        ($error: expr, $expect_contains: expr) => {
            let error = &$error;
            assert!(
                error.contains($expect_contains),
                "Expected error message to contain '{}'\ngot '{error:?}'",
                $expect_contains,
            );
        };
    }

    #[tokio::test]
    async fn test_epoch_settings_ok_200() {
        let (server, client) = setup_server_and_client();
        let epoch_settings_expected = EpochSettingsMessage::dummy();
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.status(200)
                .body(json!(epoch_settings_expected).to_string());
        });

        let epoch_settings = client
            .retrieve_epoch_settings()
            .await
            .expect("unexpected error")
            .expect("epoch settings should be returned");

        assert_eq!(
            FromEpochSettingsAdapter::try_adapt(epoch_settings_expected).unwrap(),
            epoch_settings
        );
    }

    #[tokio::test]
    async fn test_epoch_settings_ko_412() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.status(412)
                .header(MITHRIL_API_VERSION_HEADER, "0.0.999");
        });

        let error = client.retrieve_epoch_settings().await.unwrap_err();

        assert!(
            error.is_api_version_mismatch(),
            "unexpected error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn test_epoch_settings_ko_412_without_api_version_header() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.status(412);
        });

        let error = client.retrieve_epoch_settings().await.unwrap_err();

        assert!(
            error.is_api_version_mismatch(),
            "unexpected error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn test_epoch_settings_ko_500() {
        let (server, client) = setup_server_and_client();
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.status(500).body("an error occurred");
        });

        match client.retrieve_epoch_settings().await.unwrap_err() {
            AggregatorClientError::RemoteServerTechnical(_) => (),
            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
        };
    }

    #[tokio::test]
    async fn test_epoch_settings_timeout() {
        let (server, mut client) = setup_server_and_client();
        client.timeout_duration = Some(Duration::from_millis(10));
        let _server_mock = server.mock(|when, then| {
            when.path("/epoch-settings");
            then.delay(Duration::from_millis(100));
        });

        let error = client
            .retrieve_epoch_settings()
            .await
            .expect_err("retrieve_epoch_settings should fail");

        assert!(
            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
            "unexpected error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn test_4xx_errors_are_handled_as_remote_server_logical() {
        let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::RemoteServerLogical(..)
            ),
            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_5xx_errors_are_handled_as_remote_server_technical() {
        let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::RemoteServerTechnical(..)
            ),
            "Expected error to be RemoteServerTechnical\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text(
    ) {
        let response = build_text_response(StatusCode::OK, "ok text");
        let handled_error = AggregatorClientError::from_response(response).await;

        assert!(
            matches!(
                handled_error,
                AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
            ),
            "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
        let error_text = "An error occurred; please try again later.";
        let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "expectation failed: An error occurred; please try again later."
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message(
    ) {
        let client_error = ClientError::new("label", "message");
        let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "bad request: label: message"
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_json_formatted_server_error_response_contains_error_label_and_message(
    ) {
        let server_error = ServerError::new("message");
        let response = build_json_response(StatusCode::BAD_REQUEST, &server_error);

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            "bad request: message"
        );
    }

    #[tokio::test]
    async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
        let response = build_json_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            &json!({ "second": "unknown", "first": "foreign" }),
        );

        assert_error_text_contains!(
            AggregatorClientError::get_root_cause(response).await,
            r#"internal server error: {"first":"foreign","second":"unknown"}"#
        );
    }

    #[tokio::test]
    async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
        let response = HttpResponseBuilder::new()
            .status(StatusCode::BAD_REQUEST)
            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
            .body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
            .unwrap()
            .into();

        let root_cause = AggregatorClientError::get_root_cause(response).await;

        assert_error_text_contains!(root_cause, "bad request");
        assert!(
            !root_cause.contains("bad request: "),
            "Expected error message should not contain additional information \ngot '{root_cause:?}'"
        );
    }
}