1use anyhow::anyhow;
2use async_trait::async_trait;
3use reqwest::header::{self, HeaderValue};
4use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode};
5use slog::{debug, error, Logger};
6use std::{io, sync::Arc, time::Duration};
7use thiserror::Error;
8
9use mithril_common::{
10 api_version::APIVersionProvider,
11 entities::{
12 ClientError, Epoch, ProtocolMessage, ServerError, SignedEntityType, Signer,
13 SingleSignatures,
14 },
15 logging::LoggerExtensions,
16 messages::{
17 AggregatorFeaturesMessage, EpochSettingsMessage, TryFromMessageAdapter, TryToMessageAdapter,
18 },
19 StdError, StdResult, MITHRIL_API_VERSION_HEADER, MITHRIL_SIGNER_VERSION_HEADER,
20};
21
22use crate::entities::SignerEpochSettings;
23use crate::message_adapters::{
24 FromEpochSettingsAdapter, ToRegisterSignatureMessageAdapter, ToRegisterSignerMessageAdapter,
25};
26use crate::services::SignaturePublisher;
27
/// Header value of the `application/json` content type.
const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");
29
/// Errors raised while communicating with the aggregator.
#[derive(Error, Debug)]
pub enum AggregatorClientError {
    /// The aggregator answered with a server error (5xx) status code.
    #[error("remote server technical error")]
    RemoteServerTechnical(#[source] StdError),

    /// The aggregator answered with a client error (4xx) status code.
    #[error("remote server logical error")]
    RemoteServerLogical(#[source] StdError),

    /// The request could not be sent or no response was received.
    #[error("remote server unreachable")]
    RemoteServerUnreachable(#[source] StdError),

    /// The aggregator answered with a status code that is neither expected nor a 4xx/5xx;
    /// carries the status code and the response body text.
    #[error("unhandled status code: {0}, response text: {1}")]
    UnhandledStatusCode(StatusCode, String),

    /// The response body could not be decoded as the expected JSON message.
    #[error("json parsing failed")]
    JsonParseFailed(#[source] StdError),

    /// Wrapper of an [`io::Error`] (convertible with `?` through the generated `From` impl).
    #[error("Input/Output error")]
    IOError(#[from] io::Error),

    /// The aggregator rejected the request with a `412 Precondition Failed`.
    #[error("HTTP API version mismatch")]
    ApiVersionMismatch(#[source] StdError),

    /// The underlying HTTP client could not be built.
    #[error("HTTP client creation failed")]
    HTTPClientCreation(#[source] StdError),

    /// The proxy targeting the relay endpoint could not be built.
    #[error("proxy creation failed")]
    ProxyCreation(#[source] StdError),

    /// A message adapter failed to convert between an entity and its message.
    #[error("adapter failed")]
    Adapter(#[source] StdError),

    /// The aggregator answered with the custom `550` status code.
    #[error("a signer registration round is not opened yet, please try again later")]
    RegistrationRoundNotYetOpened(#[source] StdError),
}
77
#[cfg(test)]
impl AggregatorClientError {
    /// Test helper: `true` only when the error is the `ApiVersionMismatch` variant.
    pub(crate) fn is_api_version_mismatch(&self) -> bool {
        matches!(*self, AggregatorClientError::ApiVersionMismatch(..))
    }
}
85
86impl AggregatorClientError {
87 pub async fn from_response(response: Response) -> Self {
93 let error_code = response.status();
94
95 if error_code.is_client_error() {
96 let root_cause = Self::get_root_cause(response).await;
97 Self::RemoteServerLogical(anyhow!(root_cause))
98 } else if error_code.is_server_error() {
99 let root_cause = Self::get_root_cause(response).await;
100 match error_code.as_u16() {
101 550 => Self::RegistrationRoundNotYetOpened(anyhow!(root_cause)),
102 _ => Self::RemoteServerTechnical(anyhow!(root_cause)),
103 }
104 } else {
105 let response_text = response.text().await.unwrap_or_default();
106 Self::UnhandledStatusCode(error_code, response_text)
107 }
108 }
109
110 async fn get_root_cause(response: Response) -> String {
111 let error_code = response.status();
112 let canonical_reason = error_code
113 .canonical_reason()
114 .unwrap_or_default()
115 .to_lowercase();
116 let is_json = response
117 .headers()
118 .get(header::CONTENT_TYPE)
119 .is_some_and(|ct| JSON_CONTENT_TYPE == ct);
120
121 if is_json {
122 let json_value: serde_json::Value = response.json().await.unwrap_or_default();
123
124 if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
125 format!(
126 "{}: {}: {}",
127 canonical_reason, client_error.label, client_error.message
128 )
129 } else if let Ok(server_error) =
130 serde_json::from_value::<ServerError>(json_value.clone())
131 {
132 format!("{}: {}", canonical_reason, server_error.message)
133 } else if json_value.is_null() {
134 canonical_reason.to_string()
135 } else {
136 format!("{}: {}", canonical_reason, json_value)
137 }
138 } else {
139 let response_text = response.text().await.unwrap_or_default();
140 format!("{}: {}", canonical_reason, response_text)
141 }
142 }
143}
144
/// Trait for the client side of the aggregator API used by the signer.
#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait AggregatorClient: Sync + Send {
    /// Retrieves the epoch settings from the aggregator.
    async fn retrieve_epoch_settings(
        &self,
    ) -> Result<Option<SignerEpochSettings>, AggregatorClientError>;

    /// Registers a signer with the aggregator for the given epoch.
    async fn register_signer(
        &self,
        epoch: Epoch,
        signer: &Signer,
    ) -> Result<(), AggregatorClientError>;

    /// Registers single signatures with the aggregator for the given signed entity type
    /// and protocol message.
    async fn register_signatures(
        &self,
        signed_entity_type: &SignedEntityType,
        signatures: &SingleSignatures,
        protocol_message: &ProtocolMessage,
    ) -> Result<(), AggregatorClientError>;

    /// Retrieves the aggregator features message from the aggregator.
    async fn retrieve_aggregator_features(
        &self,
    ) -> Result<AggregatorFeaturesMessage, AggregatorClientError>;
}
174
175#[async_trait]
176impl<T: AggregatorClient> SignaturePublisher for T {
177 async fn publish(
178 &self,
179 signed_entity_type: &SignedEntityType,
180 signatures: &SingleSignatures,
181 protocol_message: &ProtocolMessage,
182 ) -> StdResult<()> {
183 self.register_signatures(signed_entity_type, signatures, protocol_message)
184 .await?;
185 Ok(())
186 }
187}
188
/// [`AggregatorClient`] implementation that talks to the aggregator over HTTP.
pub struct AggregatorHTTPClient {
    // Base url of the aggregator; route paths are appended to it.
    aggregator_endpoint: String,
    // When set, every request goes through a proxy targeting this endpoint.
    relay_endpoint: Option<String>,
    // Source of the API version sent in the request headers.
    api_version_provider: Arc<APIVersionProvider>,
    // When set, timeout applied to every request.
    timeout_duration: Option<Duration>,
    logger: Logger,
}
197
198impl AggregatorHTTPClient {
199 pub fn new(
201 aggregator_endpoint: String,
202 relay_endpoint: Option<String>,
203 api_version_provider: Arc<APIVersionProvider>,
204 timeout_duration: Option<Duration>,
205 logger: Logger,
206 ) -> Self {
207 let logger = logger.new_with_component_name::<Self>();
208 debug!(logger, "New AggregatorHTTPClient created");
209 Self {
210 aggregator_endpoint,
211 relay_endpoint,
212 api_version_provider,
213 timeout_duration,
214 logger,
215 }
216 }
217
218 fn prepare_http_client(&self) -> Result<Client, AggregatorClientError> {
219 let client = match &self.relay_endpoint {
220 Some(relay_endpoint) => Client::builder()
221 .proxy(
222 Proxy::all(relay_endpoint)
223 .map_err(|e| AggregatorClientError::ProxyCreation(anyhow!(e)))?,
224 )
225 .build()
226 .map_err(|e| AggregatorClientError::HTTPClientCreation(anyhow!(e)))?,
227 None => Client::new(),
228 };
229
230 Ok(client)
231 }
232
233 pub fn prepare_request_builder(&self, request_builder: RequestBuilder) -> RequestBuilder {
235 let request_builder = request_builder
236 .header(
237 MITHRIL_API_VERSION_HEADER,
238 self.api_version_provider
239 .compute_current_version()
240 .unwrap()
241 .to_string(),
242 )
243 .header(MITHRIL_SIGNER_VERSION_HEADER, env!("CARGO_PKG_VERSION"));
244
245 if let Some(duration) = self.timeout_duration {
246 request_builder.timeout(duration)
247 } else {
248 request_builder
249 }
250 }
251
252 fn handle_api_error(&self, response: &Response) -> AggregatorClientError {
254 if let Some(version) = response.headers().get(MITHRIL_API_VERSION_HEADER) {
255 AggregatorClientError::ApiVersionMismatch(anyhow!(
256 "server version: '{}', signer version: '{}'",
257 version.to_str().unwrap(),
258 self.api_version_provider.compute_current_version().unwrap()
259 ))
260 } else {
261 AggregatorClientError::ApiVersionMismatch(anyhow!(
262 "version precondition failed, sent version '{}'.",
263 self.api_version_provider.compute_current_version().unwrap()
264 ))
265 }
266 }
267}
268
269#[async_trait]
270impl AggregatorClient for AggregatorHTTPClient {
271 async fn retrieve_epoch_settings(
272 &self,
273 ) -> Result<Option<SignerEpochSettings>, AggregatorClientError> {
274 debug!(self.logger, "Retrieve epoch settings");
275 let url = format!("{}/epoch-settings", self.aggregator_endpoint);
276 let response = self
277 .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
278 .send()
279 .await;
280
281 match response {
282 Ok(response) => match response.status() {
283 StatusCode::OK => match response.json::<EpochSettingsMessage>().await {
284 Ok(message) => {
285 let epoch_settings = FromEpochSettingsAdapter::try_adapt(message)
286 .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
287 Ok(Some(epoch_settings))
288 }
289 Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
290 },
291 StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
292 _ => Err(AggregatorClientError::from_response(response).await),
293 },
294 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
295 }
296 }
297
298 async fn register_signer(
299 &self,
300 epoch: Epoch,
301 signer: &Signer,
302 ) -> Result<(), AggregatorClientError> {
303 debug!(self.logger, "Register signer");
304 let url = format!("{}/register-signer", self.aggregator_endpoint);
305 let register_signer_message =
306 ToRegisterSignerMessageAdapter::try_adapt((epoch, signer.to_owned()))
307 .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
308 let response = self
309 .prepare_request_builder(self.prepare_http_client()?.post(url.clone()))
310 .json(®ister_signer_message)
311 .send()
312 .await;
313
314 match response {
315 Ok(response) => match response.status() {
316 StatusCode::CREATED => Ok(()),
317 StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
318 _ => Err(AggregatorClientError::from_response(response).await),
319 },
320 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
321 }
322 }
323
324 async fn register_signatures(
325 &self,
326 signed_entity_type: &SignedEntityType,
327 signatures: &SingleSignatures,
328 protocol_message: &ProtocolMessage,
329 ) -> Result<(), AggregatorClientError> {
330 debug!(self.logger, "Register signatures");
331 let url = format!("{}/register-signatures", self.aggregator_endpoint);
332 let register_single_signature_message = ToRegisterSignatureMessageAdapter::try_adapt((
333 signed_entity_type.to_owned(),
334 signatures.to_owned(),
335 protocol_message,
336 ))
337 .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
338 let response = self
339 .prepare_request_builder(self.prepare_http_client()?.post(url.clone()))
340 .json(®ister_single_signature_message)
341 .send()
342 .await;
343
344 match response {
345 Ok(response) => match response.status() {
346 StatusCode::CREATED | StatusCode::ACCEPTED => Ok(()),
347 StatusCode::GONE => {
348 let root_cause = AggregatorClientError::get_root_cause(response).await;
349 debug!(self.logger, "Message already certified or expired"; "details" => &root_cause);
350
351 Ok(())
352 }
353 StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
354 _ => Err(AggregatorClientError::from_response(response).await),
355 },
356 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
357 }
358 }
359
360 async fn retrieve_aggregator_features(
361 &self,
362 ) -> Result<AggregatorFeaturesMessage, AggregatorClientError> {
363 debug!(self.logger, "Retrieve aggregator features message");
364 let url = format!("{}/", self.aggregator_endpoint);
365 let response = self
366 .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
367 .send()
368 .await;
369
370 match response {
371 Ok(response) => match response.status() {
372 StatusCode::OK => Ok(response
373 .json::<AggregatorFeaturesMessage>()
374 .await
375 .map_err(|e| AggregatorClientError::JsonParseFailed(anyhow!(e)))?),
376 StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
377 _ => Err(AggregatorClientError::from_response(response).await),
378 },
379 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
380 }
381 }
382}
383
#[cfg(test)]
pub(crate) mod dumb {
    use tokio::sync::RwLock;

    use super::*;

    /// In-memory [`AggregatorClient`] for tests: it never fails, serves the values it
    /// stores and remembers the last signer registered.
    pub struct DumbAggregatorClient {
        epoch_settings: RwLock<Option<SignerEpochSettings>>,
        last_registered_signer: RwLock<Option<Signer>>,
        aggregator_features: RwLock<AggregatorFeaturesMessage>,
    }

    impl DumbAggregatorClient {
        /// Copy of the signer given to the latest `register_signer` call, if any.
        pub async fn get_last_registered_signer(&self) -> Option<Signer> {
            let stored_signer = self.last_registered_signer.read().await;
            (*stored_signer).clone()
        }

        /// Replace the features message served by `retrieve_aggregator_features`.
        pub async fn set_aggregator_features(
            &self,
            aggregator_features: AggregatorFeaturesMessage,
        ) {
            *self.aggregator_features.write().await = aggregator_features;
        }
    }

    impl Default for DumbAggregatorClient {
        /// Starts with dummy epoch settings, dummy features and no registered signer.
        fn default() -> Self {
            let epoch_settings = Some(SignerEpochSettings::dummy());
            let aggregator_features = AggregatorFeaturesMessage::dummy();

            Self {
                epoch_settings: RwLock::new(epoch_settings),
                last_registered_signer: RwLock::new(None),
                aggregator_features: RwLock::new(aggregator_features),
            }
        }
    }

    #[async_trait]
    impl AggregatorClient for DumbAggregatorClient {
        async fn retrieve_epoch_settings(
            &self,
        ) -> Result<Option<SignerEpochSettings>, AggregatorClientError> {
            let stored_settings = self.epoch_settings.read().await;

            Ok((*stored_settings).clone())
        }

        async fn register_signer(
            &self,
            _epoch: Epoch,
            signer: &Signer,
        ) -> Result<(), AggregatorClientError> {
            let mut signer_slot = self.last_registered_signer.write().await;
            signer_slot.replace(signer.clone());

            Ok(())
        }

        async fn register_signatures(
            &self,
            _signed_entity_type: &SignedEntityType,
            _signatures: &SingleSignatures,
            _protocol_message: &ProtocolMessage,
        ) -> Result<(), AggregatorClientError> {
            Ok(())
        }

        async fn retrieve_aggregator_features(
            &self,
        ) -> Result<AggregatorFeaturesMessage, AggregatorClientError> {
            let stored_features = self.aggregator_features.read().await.clone();

            Ok(stored_features)
        }
    }
}
465
466#[cfg(test)]
467mod tests {
468 use http::response::Builder as HttpResponseBuilder;
469 use httpmock::prelude::*;
470 use serde_json::json;
471
472 use mithril_common::entities::Epoch;
473 use mithril_common::era::{EraChecker, SupportedEra};
474 use mithril_common::messages::TryFromMessageAdapter;
475 use mithril_common::test_utils::{fake_data, TempDir};
476
477 use crate::test_tools::TestLogger;
478
479 use super::*;
480
    // Asserts that `$error` matches the `$error_type` pattern; on failure the message shows
    // the expected pattern and the debug representation of the actual error.
    // Note: `$error` is expanded twice, so pass a variable rather than an expression with
    // side effects.
    macro_rules! assert_is_error {
        ($error:expr, $error_type:pat) => {
            assert!(
                matches!($error, $error_type),
                "Expected {} error, got '{:?}'.",
                stringify!($error_type),
                $error
            );
        };
    }
491
492 fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
493 let server = MockServer::start();
494 let aggregator_endpoint = server.url("");
495 let relay_endpoint = None;
496 let era_checker = EraChecker::new(SupportedEra::dummy(), Epoch(1));
497 let api_version_provider = APIVersionProvider::new(Arc::new(era_checker));
498
499 (
500 server,
501 AggregatorHTTPClient::new(
502 aggregator_endpoint,
503 relay_endpoint,
504 Arc::new(api_version_provider),
505 None,
506 TestLogger::stdout(),
507 ),
508 )
509 }
510
511 fn set_returning_412(server: &MockServer) {
512 server.mock(|_, then| {
513 then.status(412)
514 .header(MITHRIL_API_VERSION_HEADER, "0.0.999");
515 });
516 }
517
518 fn set_returning_500(server: &MockServer) {
519 server.mock(|_, then| {
520 then.status(500).body("an error occurred");
521 });
522 }
523
524 fn set_unparsable_json(server: &MockServer) {
525 server.mock(|_, then| {
526 then.status(200).body("this is not a json");
527 });
528 }
529
530 fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
531 HttpResponseBuilder::new()
532 .status(status_code)
533 .body(body.into())
534 .unwrap()
535 .into()
536 }
537
538 fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
539 HttpResponseBuilder::new()
540 .status(status_code)
541 .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
542 .body(serde_json::to_string(&body).unwrap())
543 .unwrap()
544 .into()
545 }
546
    // Asserts that the text `$error` contains `$expect_contains`.
    // `$error` is bound to a local first, so the expression is evaluated only once.
    macro_rules! assert_error_text_contains {
        ($error: expr, $expect_contains: expr) => {
            let error = &$error;
            assert!(
                error.contains($expect_contains),
                "Expected error message to contain '{}'\ngot '{error:?}'",
                $expect_contains,
            );
        };
    }
557
558 #[tokio::test]
559 async fn test_aggregator_features_ok_200() {
560 let (server, client) = setup_server_and_client();
561 let message_expected = AggregatorFeaturesMessage::dummy();
562 let _server_mock = server.mock(|when, then| {
563 when.path("/");
564 then.status(200).body(json!(message_expected).to_string());
565 });
566
567 let message = client.retrieve_aggregator_features().await.unwrap();
568
569 assert_eq!(message_expected, message);
570 }
571
572 #[tokio::test]
573 async fn test_aggregator_features_ko_412() {
574 let (server, client) = setup_server_and_client();
575 set_returning_412(&server);
576
577 let error = client.retrieve_aggregator_features().await.unwrap_err();
578
579 assert_is_error!(error, AggregatorClientError::ApiVersionMismatch(_));
580 }
581
582 #[tokio::test]
583 async fn test_aggregator_features_ko_500() {
584 let (server, client) = setup_server_and_client();
585 set_returning_500(&server);
586
587 let error = client.retrieve_aggregator_features().await.unwrap_err();
588
589 assert_is_error!(error, AggregatorClientError::RemoteServerTechnical(_));
590 }
591
592 #[tokio::test]
593 async fn test_aggregator_features_ko_json_serialization() {
594 let (server, client) = setup_server_and_client();
595 set_unparsable_json(&server);
596
597 let error = client.retrieve_aggregator_features().await.unwrap_err();
598
599 assert_is_error!(error, AggregatorClientError::JsonParseFailed(_));
600 }
601
602 #[tokio::test]
603 async fn test_aggregator_features_timeout() {
604 let (server, mut client) = setup_server_and_client();
605 client.timeout_duration = Some(Duration::from_millis(10));
606 let _server_mock = server.mock(|when, then| {
607 when.path("/");
608 then.delay(Duration::from_millis(100));
609 });
610
611 let error = client.retrieve_aggregator_features().await.unwrap_err();
612
613 assert_is_error!(error, AggregatorClientError::RemoteServerUnreachable(_));
614 }
615
616 #[tokio::test]
617 async fn test_epoch_settings_ok_200() {
618 let (server, client) = setup_server_and_client();
619 let epoch_settings_expected = EpochSettingsMessage::dummy();
620 let _server_mock = server.mock(|when, then| {
621 when.path("/epoch-settings");
622 then.status(200)
623 .body(json!(epoch_settings_expected).to_string());
624 });
625
626 let epoch_settings = client.retrieve_epoch_settings().await;
627 epoch_settings.as_ref().expect("unexpected error");
628 assert_eq!(
629 FromEpochSettingsAdapter::try_adapt(epoch_settings_expected).unwrap(),
630 epoch_settings.unwrap().unwrap()
631 );
632 }
633
634 #[tokio::test]
635 async fn test_epoch_settings_ko_412() {
636 let (server, client) = setup_server_and_client();
637 let _server_mock = server.mock(|when, then| {
638 when.path("/epoch-settings");
639 then.status(412)
640 .header(MITHRIL_API_VERSION_HEADER, "0.0.999");
641 });
642
643 let epoch_settings = client.retrieve_epoch_settings().await.unwrap_err();
644
645 assert!(epoch_settings.is_api_version_mismatch());
646 }
647
648 #[tokio::test]
649 async fn test_epoch_settings_ko_500() {
650 let (server, client) = setup_server_and_client();
651 let _server_mock = server.mock(|when, then| {
652 when.path("/epoch-settings");
653 then.status(500).body("an error occurred");
654 });
655
656 match client.retrieve_epoch_settings().await.unwrap_err() {
657 AggregatorClientError::RemoteServerTechnical(_) => (),
658 e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
659 };
660 }
661
662 #[tokio::test]
663 async fn test_epoch_settings_timeout() {
664 let (server, mut client) = setup_server_and_client();
665 client.timeout_duration = Some(Duration::from_millis(10));
666 let _server_mock = server.mock(|when, then| {
667 when.path("/epoch-settings");
668 then.delay(Duration::from_millis(100));
669 });
670
671 let error = client
672 .retrieve_epoch_settings()
673 .await
674 .expect_err("retrieve_epoch_settings should fail");
675
676 assert!(
677 matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
678 "unexpected error type: {error:?}"
679 );
680 }
681
682 #[tokio::test]
683 async fn test_register_signer_ok_201() {
684 let epoch = Epoch(1);
685 let single_signers = fake_data::signers(1);
686 let single_signer = single_signers.first().unwrap();
687 let (server, client) = setup_server_and_client();
688 let _server_mock = server.mock(|when, then| {
689 when.method(POST).path("/register-signer");
690 then.status(201);
691 });
692
693 let register_signer = client.register_signer(epoch, single_signer).await;
694 register_signer.expect("unexpected error");
695 }
696
697 #[tokio::test]
698 async fn test_register_signer_ko_412() {
699 let epoch = Epoch(1);
700 let (server, client) = setup_server_and_client();
701 let _server_mock = server.mock(|when, then| {
702 when.method(POST).path("/register-signer");
703 then.status(412)
704 .header(MITHRIL_API_VERSION_HEADER, "0.0.999");
705 });
706 let single_signers = fake_data::signers(1);
707 let single_signer = single_signers.first().unwrap();
708
709 let error = client
710 .register_signer(epoch, single_signer)
711 .await
712 .unwrap_err();
713
714 assert!(error.is_api_version_mismatch());
715 }
716
717 #[tokio::test]
718 async fn test_register_signer_ko_400() {
719 let epoch = Epoch(1);
720 let single_signers = fake_data::signers(1);
721 let single_signer = single_signers.first().unwrap();
722 let (server, client) = setup_server_and_client();
723 let _server_mock = server.mock(|when, then| {
724 when.method(POST).path("/register-signer");
725 then.status(400).body(
726 serde_json::to_vec(&ClientError::new(
727 "error".to_string(),
728 "an error".to_string(),
729 ))
730 .unwrap(),
731 );
732 });
733
734 match client
735 .register_signer(epoch, single_signer)
736 .await
737 .unwrap_err()
738 {
739 AggregatorClientError::RemoteServerLogical(_) => (),
740 err => {
741 panic!(
742 "Expected a AggregatorClientError::RemoteServerLogical error, got '{err:?}'."
743 )
744 }
745 };
746 }
747
748 #[tokio::test]
749 async fn test_register_signer_ko_500() {
750 let epoch = Epoch(1);
751 let single_signers = fake_data::signers(1);
752 let single_signer = single_signers.first().unwrap();
753 let (server, client) = setup_server_and_client();
754 let _server_mock = server.mock(|when, then| {
755 when.method(POST).path("/register-signer");
756 then.status(500).body("an error occurred");
757 });
758
759 match client
760 .register_signer(epoch, single_signer)
761 .await
762 .unwrap_err()
763 {
764 AggregatorClientError::RemoteServerTechnical(_) => (),
765 e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
766 };
767 }
768
769 #[tokio::test]
770 async fn test_register_signer_timeout() {
771 let epoch = Epoch(1);
772 let single_signers = fake_data::signers(1);
773 let single_signer = single_signers.first().unwrap();
774 let (server, mut client) = setup_server_and_client();
775 client.timeout_duration = Some(Duration::from_millis(10));
776 let _server_mock = server.mock(|when, then| {
777 when.method(POST).path("/register-signer");
778 then.delay(Duration::from_millis(100));
779 });
780
781 let error = client
782 .register_signer(epoch, single_signer)
783 .await
784 .expect_err("register_signer should fail");
785
786 assert!(
787 matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
788 "unexpected error type: {error:?}"
789 );
790 }
791
792 #[tokio::test]
793 async fn test_register_signatures_ok_201() {
794 let single_signatures = fake_data::single_signatures((1..5).collect());
795 let (server, client) = setup_server_and_client();
796 let _server_mock = server.mock(|when, then| {
797 when.method(POST).path("/register-signatures");
798 then.status(201);
799 });
800
801 let register_signatures = client
802 .register_signatures(
803 &SignedEntityType::dummy(),
804 &single_signatures,
805 &ProtocolMessage::default(),
806 )
807 .await;
808 register_signatures.expect("unexpected error");
809 }
810
811 #[tokio::test]
812 async fn test_register_signatures_ok_202() {
813 let single_signatures = fake_data::single_signatures((1..5).collect());
814 let (server, client) = setup_server_and_client();
815 let _server_mock = server.mock(|when, then| {
816 when.method(POST).path("/register-signatures");
817 then.status(202);
818 });
819
820 let register_signatures = client
821 .register_signatures(
822 &SignedEntityType::dummy(),
823 &single_signatures,
824 &ProtocolMessage::default(),
825 )
826 .await;
827 register_signatures.expect("unexpected error");
828 }
829
830 #[tokio::test]
831 async fn test_register_signatures_ko_412() {
832 let (server, client) = setup_server_and_client();
833 let _server_mock = server.mock(|when, then| {
834 when.method(POST).path("/register-signatures");
835 then.status(412)
836 .header(MITHRIL_API_VERSION_HEADER, "0.0.999");
837 });
838 let single_signatures = fake_data::single_signatures((1..5).collect());
839
840 let error = client
841 .register_signatures(
842 &SignedEntityType::dummy(),
843 &single_signatures,
844 &ProtocolMessage::default(),
845 )
846 .await
847 .unwrap_err();
848
849 assert!(error.is_api_version_mismatch());
850 }
851
852 #[tokio::test]
853 async fn test_register_signatures_ko_400() {
854 let single_signatures = fake_data::single_signatures((1..5).collect());
855 let (server, client) = setup_server_and_client();
856 let _server_mock = server.mock(|when, then| {
857 when.method(POST).path("/register-signatures");
858 then.status(400).body(
859 serde_json::to_vec(&ClientError::new(
860 "error".to_string(),
861 "an error".to_string(),
862 ))
863 .unwrap(),
864 );
865 });
866
867 match client
868 .register_signatures(
869 &SignedEntityType::dummy(),
870 &single_signatures,
871 &ProtocolMessage::default(),
872 )
873 .await
874 .unwrap_err()
875 {
876 AggregatorClientError::RemoteServerLogical(_) => (),
877 e => panic!("Expected Aggregator::RemoteServerLogical error, got '{e:?}'."),
878 };
879 }
880
881 #[tokio::test]
882 async fn test_register_signatures_ok_410_log_response_body() {
883 let log_path = TempDir::create(
884 "aggregator_client",
885 "test_register_signatures_ok_410_log_response_body",
886 )
887 .join("test.log");
888
889 let single_signatures = fake_data::single_signatures((1..5).collect());
890 {
891 let (server, mut client) = setup_server_and_client();
892 client.logger = TestLogger::file(&log_path);
893 let _server_mock = server.mock(|when, then| {
894 when.method(POST).path("/register-signatures");
895 then.status(410).body(
896 serde_json::to_vec(&ClientError::new(
897 "already_aggregated".to_string(),
898 "too late".to_string(),
899 ))
900 .unwrap(),
901 );
902 });
903
904 client
905 .register_signatures(
906 &SignedEntityType::dummy(),
907 &single_signatures,
908 &ProtocolMessage::default(),
909 )
910 .await
911 .expect("Should not fail when status is 410 (GONE)");
912 }
913
914 let logs = std::fs::read_to_string(&log_path).unwrap();
915 assert!(logs.contains("already_aggregated"));
916 assert!(logs.contains("too late"));
917 }
918
919 #[tokio::test]
920 async fn test_register_signatures_ko_409() {
921 let single_signatures = fake_data::single_signatures((1..5).collect());
922 let (server, client) = setup_server_and_client();
923 let _server_mock = server.mock(|when, then| {
924 when.method(POST).path("/register-signatures");
925 then.status(409);
926 });
927
928 match client
929 .register_signatures(
930 &SignedEntityType::dummy(),
931 &single_signatures,
932 &ProtocolMessage::default(),
933 )
934 .await
935 .unwrap_err()
936 {
937 AggregatorClientError::RemoteServerLogical(_) => (),
938 e => panic!("Expected Aggregator::RemoteServerLogical error, got '{e:?}'."),
939 }
940 }
941
942 #[tokio::test]
943 async fn test_register_signatures_ko_500() {
944 let single_signatures = fake_data::single_signatures((1..5).collect());
945 let (server, client) = setup_server_and_client();
946 let _server_mock = server.mock(|when, then| {
947 when.method(POST).path("/register-signatures");
948 then.status(500).body("an error occurred");
949 });
950
951 match client
952 .register_signatures(
953 &SignedEntityType::dummy(),
954 &single_signatures,
955 &ProtocolMessage::default(),
956 )
957 .await
958 .unwrap_err()
959 {
960 AggregatorClientError::RemoteServerTechnical(_) => (),
961 e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
962 };
963 }
964
965 #[tokio::test]
966 async fn test_register_signatures_timeout() {
967 let single_signatures = fake_data::single_signatures((1..5).collect());
968 let (server, mut client) = setup_server_and_client();
969 client.timeout_duration = Some(Duration::from_millis(10));
970 let _server_mock = server.mock(|when, then| {
971 when.method(POST).path("/register-signatures");
972 then.delay(Duration::from_millis(100));
973 });
974
975 let error = client
976 .register_signatures(
977 &SignedEntityType::dummy(),
978 &single_signatures,
979 &ProtocolMessage::default(),
980 )
981 .await
982 .expect_err("register_signatures should fail");
983
984 assert!(
985 matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
986 "unexpected error type: {error:?}"
987 );
988 }
989
990 #[tokio::test]
991 async fn test_4xx_errors_are_handled_as_remote_server_logical() {
992 let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
993 let handled_error = AggregatorClientError::from_response(response).await;
994
995 assert!(
996 matches!(
997 handled_error,
998 AggregatorClientError::RemoteServerLogical(..)
999 ),
1000 "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
1001 );
1002 }
1003
1004 #[tokio::test]
1005 async fn test_5xx_errors_are_handled_as_remote_server_technical() {
1006 let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
1007 let handled_error = AggregatorClientError::from_response(response).await;
1008
1009 assert!(
1010 matches!(
1011 handled_error,
1012 AggregatorClientError::RemoteServerTechnical(..)
1013 ),
1014 "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
1015 );
1016 }
1017
1018 #[tokio::test]
1019 async fn test_550_error_is_handled_as_registration_round_not_yet_opened() {
1020 let response = build_text_response(StatusCode::from_u16(550).unwrap(), "Not yet available");
1021 let handled_error = AggregatorClientError::from_response(response).await;
1022
1023 assert!(
1024 matches!(
1025 handled_error,
1026 AggregatorClientError::RegistrationRoundNotYetOpened(..)
1027 ),
1028 "Expected error to be RegistrationRoundNotYetOpened\ngot '{handled_error:?}'",
1029 );
1030 }
1031
1032 #[tokio::test]
1033 async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text(
1034 ) {
1035 let response = build_text_response(StatusCode::OK, "ok text");
1036 let handled_error = AggregatorClientError::from_response(response).await;
1037
1038 assert!(
1039 matches!(
1040 handled_error,
1041 AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
1042 ),
1043 "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
1044 );
1045 }
1046
1047 #[tokio::test]
1048 async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
1049 let error_text = "An error occurred; please try again later.";
1050 let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);
1051
1052 assert_error_text_contains!(
1053 AggregatorClientError::get_root_cause(response).await,
1054 "expectation failed: An error occurred; please try again later."
1055 );
1056 }
1057
1058 #[tokio::test]
1059 async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message(
1060 ) {
1061 let client_error = ClientError::new("label", "message");
1062 let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);
1063
1064 assert_error_text_contains!(
1065 AggregatorClientError::get_root_cause(response).await,
1066 "bad request: label: message"
1067 );
1068 }
1069
1070 #[tokio::test]
1071 async fn test_root_cause_of_json_formatted_server_error_response_contains_error_label_and_message(
1072 ) {
1073 let server_error = ServerError::new("message");
1074 let response = build_json_response(StatusCode::BAD_REQUEST, &server_error);
1075
1076 assert_error_text_contains!(
1077 AggregatorClientError::get_root_cause(response).await,
1078 "bad request: message"
1079 );
1080 }
1081
1082 #[tokio::test]
1083 async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
1084 let response = build_json_response(
1085 StatusCode::INTERNAL_SERVER_ERROR,
1086 &json!({ "second": "unknown", "first": "foreign" }),
1087 );
1088
1089 assert_error_text_contains!(
1090 AggregatorClientError::get_root_cause(response).await,
1091 r#"internal server error: {"first":"foreign","second":"unknown"}"#
1092 );
1093 }
1094
1095 #[tokio::test]
1096 async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
1097 let response = HttpResponseBuilder::new()
1098 .status(StatusCode::BAD_REQUEST)
1099 .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
1100 .body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
1101 .unwrap()
1102 .into();
1103
1104 let root_cause = AggregatorClientError::get_root_cause(response).await;
1105
1106 assert_error_text_contains!(root_cause, "bad request");
1107 assert!(
1108 !root_cause.contains("bad request: "),
1109 "Expected error message should not contain additional information \ngot '{root_cause:?}'"
1110 );
1111 }
1112
1113 #[tokio::test]
1114 async fn test_sends_accept_encoding_header() {
1115 let (server, client) = setup_server_and_client();
1116 server.mock(|when, then| {
1117 when.matches(|req| {
1118 let headers = req.headers.clone().expect("HTTP headers not found");
1119 let accept_encoding_header = headers
1120 .iter()
1121 .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
1122 .expect("Accept-Encoding header not found");
1123
1124 let header_value = accept_encoding_header.clone().1;
1125 ["gzip", "br", "deflate", "zstd"]
1126 .iter()
1127 .all(|&value| header_value.contains(value))
1128 });
1129
1130 then.status(201);
1131 });
1132
1133 client
1134 .register_signatures(
1135 &SignedEntityType::dummy(),
1136 &fake_data::single_signatures((1..5).collect()),
1137 &ProtocolMessage::default(),
1138 )
1139 .await
1140 .expect("Should succeed with Accept-Encoding header");
1141 }
1142}