1use anyhow::anyhow;
2use async_trait::async_trait;
3use reqwest::header::{self, HeaderValue};
4use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode};
5use semver::Version;
6use slog::{Logger, debug, error, warn};
7use std::{io, sync::Arc, time::Duration};
8use thiserror::Error;
9
10use mithril_common::{
11 MITHRIL_API_VERSION_HEADER, MITHRIL_SIGNER_VERSION_HEADER, StdError,
12 api_version::APIVersionProvider,
13 entities::{
14 ClientError, Epoch, ProtocolMessage, ServerError, SignedEntityType, Signer, SingleSignature,
15 },
16 logging::LoggerExtensions,
17 messages::{
18 AggregatorFeaturesMessage, EpochSettingsMessage, TryFromMessageAdapter, TryToMessageAdapter,
19 },
20};
21
22use crate::entities::SignerEpochSettings;
23use crate::message_adapters::{
24 FromEpochSettingsAdapter, ToRegisterSignatureMessageAdapter, ToRegisterSignerMessageAdapter,
25};
26
/// `Content-Type` value identifying a JSON payload; used to decide whether a response body
/// should be parsed as JSON when building an error description.
const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");

/// Warning logged when the aggregator advertises a newer OpenAPI version than the signer.
const API_VERSION_MISMATCH_WARNING_MESSAGE: &str =
    "OpenAPI version may be incompatible, please update your Mithril node to the latest version.";
31
32#[derive(Error, Debug)]
34pub enum AggregatorClientError {
35 #[error("remote server technical error")]
37 RemoteServerTechnical(#[source] StdError),
38
39 #[error("remote server logical error")]
41 RemoteServerLogical(#[source] StdError),
42
43 #[error("remote server unreachable")]
45 RemoteServerUnreachable(#[source] StdError),
46
47 #[error("unhandled status code: {0}, response text: {1}")]
49 UnhandledStatusCode(StatusCode, String),
50
51 #[error("json parsing failed")]
53 JsonParseFailed(#[source] StdError),
54
55 #[error("Input/Output error")]
57 IOError(#[from] io::Error),
58
59 #[error("HTTP client creation failed")]
61 HTTPClientCreation(#[source] StdError),
62
63 #[error("proxy creation failed")]
65 ProxyCreation(#[source] StdError),
66
67 #[error("adapter failed")]
69 Adapter(#[source] StdError),
70
71 #[error("a signer registration round is not opened yet, please try again later")]
73 RegistrationRoundNotYetOpened(#[source] StdError),
74}
75
76impl AggregatorClientError {
77 pub async fn from_response(response: Response) -> Self {
83 let error_code = response.status();
84
85 if error_code.is_client_error() {
86 let root_cause = Self::get_root_cause(response).await;
87 Self::RemoteServerLogical(anyhow!(root_cause))
88 } else if error_code.is_server_error() {
89 let root_cause = Self::get_root_cause(response).await;
90 match error_code.as_u16() {
91 550 => Self::RegistrationRoundNotYetOpened(anyhow!(root_cause)),
92 _ => Self::RemoteServerTechnical(anyhow!(root_cause)),
93 }
94 } else {
95 let response_text = response.text().await.unwrap_or_default();
96 Self::UnhandledStatusCode(error_code, response_text)
97 }
98 }
99
100 async fn get_root_cause(response: Response) -> String {
101 let error_code = response.status();
102 let canonical_reason = error_code.canonical_reason().unwrap_or_default().to_lowercase();
103 let is_json = response
104 .headers()
105 .get(header::CONTENT_TYPE)
106 .is_some_and(|ct| JSON_CONTENT_TYPE == ct);
107
108 if is_json {
109 let json_value: serde_json::Value = response.json().await.unwrap_or_default();
110
111 if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
112 format!(
113 "{}: {}: {}",
114 canonical_reason, client_error.label, client_error.message
115 )
116 } else if let Ok(server_error) =
117 serde_json::from_value::<ServerError>(json_value.clone())
118 {
119 format!("{}: {}", canonical_reason, server_error.message)
120 } else if json_value.is_null() {
121 canonical_reason.to_string()
122 } else {
123 format!("{canonical_reason}: {json_value}")
124 }
125 } else {
126 let response_text = response.text().await.unwrap_or_default();
127 format!("{canonical_reason}: {response_text}")
128 }
129 }
130}
131
/// Client used by the signer to communicate with a Mithril aggregator.
#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait AggregatorClient: Sync + Send {
    /// Retrieves the current epoch settings from the aggregator.
    ///
    /// Implementations may return `Ok(None)`; the HTTP implementation in this file always
    /// returns `Some` on success.
    async fn retrieve_epoch_settings(
        &self,
    ) -> Result<Option<SignerEpochSettings>, AggregatorClientError>;

    /// Registers `signer` with the aggregator for the given `epoch`.
    async fn register_signer(
        &self,
        epoch: Epoch,
        signer: &Signer,
    ) -> Result<(), AggregatorClientError>;

    /// Sends a single `signature` for `signed_entity_type` along with its `protocol_message`.
    async fn register_signature(
        &self,
        signed_entity_type: &SignedEntityType,
        signature: &SingleSignature,
        protocol_message: &ProtocolMessage,
    ) -> Result<(), AggregatorClientError>;

    /// Retrieves the features message advertised by the aggregator.
    async fn retrieve_aggregator_features(
        &self,
    ) -> Result<AggregatorFeaturesMessage, AggregatorClientError>;
}
161
/// [`AggregatorClient`] implementation talking to an aggregator over HTTP with `reqwest`.
pub struct AggregatorHTTPClient {
    /// Base URL of the aggregator; route paths such as `/epoch-settings` are appended to it.
    aggregator_endpoint: String,
    /// Optional relay URL; when set it is used as a proxy for every request.
    relay_endpoint: Option<String>,
    /// Provides the API version sent in request headers and compared with the aggregator's.
    api_version_provider: Arc<APIVersionProvider>,
    /// Optional timeout applied to each request.
    timeout_duration: Option<Duration>,
    /// Logger used by this client.
    logger: Logger,
}
170
171impl AggregatorHTTPClient {
172 pub fn new(
174 aggregator_endpoint: String,
175 relay_endpoint: Option<String>,
176 api_version_provider: Arc<APIVersionProvider>,
177 timeout_duration: Option<Duration>,
178 logger: Logger,
179 ) -> Self {
180 let logger = logger.new_with_component_name::<Self>();
181 debug!(logger, "New AggregatorHTTPClient created");
182 Self {
183 aggregator_endpoint,
184 relay_endpoint,
185 api_version_provider,
186 timeout_duration,
187 logger,
188 }
189 }
190
191 fn prepare_http_client(&self) -> Result<Client, AggregatorClientError> {
192 let client = match &self.relay_endpoint {
193 Some(relay_endpoint) => Client::builder()
194 .proxy(
195 Proxy::all(relay_endpoint)
196 .map_err(|e| AggregatorClientError::ProxyCreation(anyhow!(e)))?,
197 )
198 .build()
199 .map_err(|e| AggregatorClientError::HTTPClientCreation(anyhow!(e)))?,
200 None => Client::new(),
201 };
202
203 Ok(client)
204 }
205
206 pub fn prepare_request_builder(&self, request_builder: RequestBuilder) -> RequestBuilder {
208 let request_builder = request_builder
209 .header(
210 MITHRIL_API_VERSION_HEADER,
211 self.api_version_provider
212 .compute_current_version()
213 .unwrap()
214 .to_string(),
215 )
216 .header(MITHRIL_SIGNER_VERSION_HEADER, env!("CARGO_PKG_VERSION"));
217
218 if let Some(duration) = self.timeout_duration {
219 request_builder.timeout(duration)
220 } else {
221 request_builder
222 }
223 }
224
225 fn warn_if_api_version_mismatch(&self, response: &Response) {
227 let aggregator_version = response
228 .headers()
229 .get(MITHRIL_API_VERSION_HEADER)
230 .and_then(|v| v.to_str().ok())
231 .and_then(|s| Version::parse(s).ok());
232
233 let signer_version = self.api_version_provider.compute_current_version();
234
235 match (aggregator_version, signer_version) {
236 (Some(aggregator), Ok(signer)) if signer < aggregator => {
237 warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
238 "aggregator_version" => %aggregator,
239 "signer_version" => %signer,
240 );
241 }
242 (Some(_), Err(error)) => {
243 error!(
244 self.logger,
245 "Failed to compute the current signer API version";
246 "error" => error.to_string()
247 );
248 }
249 _ => {}
250 }
251 }
252}
253
254#[async_trait]
255impl AggregatorClient for AggregatorHTTPClient {
256 async fn retrieve_epoch_settings(
257 &self,
258 ) -> Result<Option<SignerEpochSettings>, AggregatorClientError> {
259 debug!(self.logger, "Retrieve epoch settings");
260 let url = format!("{}/epoch-settings", self.aggregator_endpoint);
261 let response = self
262 .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
263 .send()
264 .await;
265
266 match response {
267 Ok(response) => match response.status() {
268 StatusCode::OK => {
269 self.warn_if_api_version_mismatch(&response);
270 match response.json::<EpochSettingsMessage>().await {
271 Ok(message) => {
272 let epoch_settings = FromEpochSettingsAdapter::try_adapt(message)
273 .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
274 Ok(Some(epoch_settings))
275 }
276 Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
277 }
278 }
279 _ => Err(AggregatorClientError::from_response(response).await),
280 },
281 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
282 }
283 }
284
285 async fn register_signer(
286 &self,
287 epoch: Epoch,
288 signer: &Signer,
289 ) -> Result<(), AggregatorClientError> {
290 debug!(self.logger, "Register signer");
291 let url = format!("{}/register-signer", self.aggregator_endpoint);
292 let register_signer_message =
293 ToRegisterSignerMessageAdapter::try_adapt((epoch, signer.to_owned()))
294 .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
295 let response = self
296 .prepare_request_builder(self.prepare_http_client()?.post(url.clone()))
297 .json(®ister_signer_message)
298 .send()
299 .await;
300
301 match response {
302 Ok(response) => match response.status() {
303 StatusCode::CREATED => {
304 self.warn_if_api_version_mismatch(&response);
305
306 Ok(())
307 }
308 _ => Err(AggregatorClientError::from_response(response).await),
309 },
310 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
311 }
312 }
313
314 async fn register_signature(
315 &self,
316 signed_entity_type: &SignedEntityType,
317 signature: &SingleSignature,
318 protocol_message: &ProtocolMessage,
319 ) -> Result<(), AggregatorClientError> {
320 debug!(self.logger, "Register signature");
321 let url = format!("{}/register-signatures", self.aggregator_endpoint);
322 let register_single_signature_message = ToRegisterSignatureMessageAdapter::try_adapt((
323 signed_entity_type.to_owned(),
324 signature.to_owned(),
325 protocol_message,
326 ))
327 .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
328 let response = self
329 .prepare_request_builder(self.prepare_http_client()?.post(url.clone()))
330 .json(®ister_single_signature_message)
331 .send()
332 .await;
333
334 match response {
335 Ok(response) => match response.status() {
336 StatusCode::CREATED | StatusCode::ACCEPTED => {
337 self.warn_if_api_version_mismatch(&response);
338
339 Ok(())
340 }
341 StatusCode::GONE => {
342 self.warn_if_api_version_mismatch(&response);
343 let root_cause = AggregatorClientError::get_root_cause(response).await;
344 debug!(self.logger, "Message already certified or expired"; "details" => &root_cause);
345
346 Ok(())
347 }
348 _ => Err(AggregatorClientError::from_response(response).await),
349 },
350 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
351 }
352 }
353
354 async fn retrieve_aggregator_features(
355 &self,
356 ) -> Result<AggregatorFeaturesMessage, AggregatorClientError> {
357 debug!(self.logger, "Retrieve aggregator features message");
358 let url = format!("{}/", self.aggregator_endpoint);
359 let response = self
360 .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
361 .send()
362 .await;
363
364 match response {
365 Ok(response) => match response.status() {
366 StatusCode::OK => {
367 self.warn_if_api_version_mismatch(&response);
368
369 Ok(response
370 .json::<AggregatorFeaturesMessage>()
371 .await
372 .map_err(|e| AggregatorClientError::JsonParseFailed(anyhow!(e)))?)
373 }
374 _ => Err(AggregatorClientError::from_response(response).await),
375 },
376 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
377 }
378 }
379}
380
#[cfg(test)]
pub(crate) mod dumb {
    use mithril_common::test::double::Dummy;
    use tokio::sync::RwLock;

    use super::*;

    /// In-memory [`AggregatorClient`] test double.
    ///
    /// It serves fixed epoch settings and aggregator features, remembers the last signer
    /// passed to `register_signer`, and accepts every signature.
    pub struct DumbAggregatorClient {
        epoch_settings: RwLock<Option<SignerEpochSettings>>,
        last_registered_signer: RwLock<Option<Signer>>,
        aggregator_features: RwLock<AggregatorFeaturesMessage>,
    }

    impl DumbAggregatorClient {
        /// Returns a copy of the signer given to the latest `register_signer` call, if any.
        pub async fn get_last_registered_signer(&self) -> Option<Signer> {
            let last_signer = self.last_registered_signer.read().await;
            last_signer.clone()
        }
    }

    impl Default for DumbAggregatorClient {
        /// Dummy epoch settings and aggregator features, with no registered signer yet.
        fn default() -> Self {
            let epoch_settings = Some(SignerEpochSettings::dummy());
            let aggregator_features = AggregatorFeaturesMessage::dummy();

            Self {
                epoch_settings: RwLock::new(epoch_settings),
                last_registered_signer: RwLock::new(None),
                aggregator_features: RwLock::new(aggregator_features),
            }
        }
    }

    #[async_trait]
    impl AggregatorClient for DumbAggregatorClient {
        async fn retrieve_epoch_settings(
            &self,
        ) -> Result<Option<SignerEpochSettings>, AggregatorClientError> {
            let settings_guard = self.epoch_settings.read().await;

            Ok(settings_guard.clone())
        }

        async fn register_signer(
            &self,
            _epoch: Epoch,
            signer: &Signer,
        ) -> Result<(), AggregatorClientError> {
            let mut slot = self.last_registered_signer.write().await;
            *slot = Some(signer.clone());

            Ok(())
        }

        async fn register_signature(
            &self,
            _signed_entity_type: &SignedEntityType,
            _signature: &SingleSignature,
            _protocol_message: &ProtocolMessage,
        ) -> Result<(), AggregatorClientError> {
            Ok(())
        }

        async fn retrieve_aggregator_features(
            &self,
        ) -> Result<AggregatorFeaturesMessage, AggregatorClientError> {
            let features_guard = self.aggregator_features.read().await;

            Ok(features_guard.clone())
        }
    }
}
455
456#[cfg(test)]
457mod tests {
458 use std::collections::HashMap;
459
460 use http::response::Builder as HttpResponseBuilder;
461 use httpmock::prelude::*;
462 use semver::Version;
463 use serde_json::json;
464
465 use mithril_common::entities::Epoch;
466 use mithril_common::messages::TryFromMessageAdapter;
467 use mithril_common::test::{
468 double::{Dummy, DummyApiVersionDiscriminantSource, fake_data},
469 logging::MemoryDrainForTestInspector,
470 };
471
472 use crate::test_tools::TestLogger;
473
474 use super::*;
475
476 macro_rules! assert_is_error {
477 ($error:expr, $error_type:pat) => {
478 assert!(
479 matches!($error, $error_type),
480 "Expected {} error, got '{:?}'.",
481 stringify!($error_type),
482 $error
483 );
484 };
485 }
486
487 fn setup_client<U: Into<String>>(server_url: U) -> AggregatorHTTPClient {
488 let discriminant_source = DummyApiVersionDiscriminantSource::new("dummy");
489 let api_version_provider = APIVersionProvider::new(Arc::new(discriminant_source));
490
491 AggregatorHTTPClient::new(
492 server_url.into(),
493 None,
494 Arc::new(api_version_provider),
495 None,
496 TestLogger::stdout(),
497 )
498 }
499
500 fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
501 let server = MockServer::start();
502 let aggregator_endpoint = server.url("");
503 let client = setup_client(&aggregator_endpoint);
504
505 (server, client)
506 }
507
508 fn set_returning_500(server: &MockServer) {
509 server.mock(|_, then| {
510 then.status(500).body("an error occurred");
511 });
512 }
513
514 fn set_unparsable_json(server: &MockServer) {
515 server.mock(|_, then| {
516 then.status(200).body("this is not a json");
517 });
518 }
519
520 fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
521 HttpResponseBuilder::new()
522 .status(status_code)
523 .body(body.into())
524 .unwrap()
525 .into()
526 }
527
528 fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
529 HttpResponseBuilder::new()
530 .status(status_code)
531 .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
532 .body(serde_json::to_string(&body).unwrap())
533 .unwrap()
534 .into()
535 }
536
537 macro_rules! assert_error_text_contains {
538 ($error: expr, $expect_contains: expr) => {
539 let error = &$error;
540 assert!(
541 error.contains($expect_contains),
542 "Expected error message to contain '{}'\ngot '{error:?}'",
543 $expect_contains,
544 );
545 };
546 }
547
548 #[tokio::test]
549 async fn test_aggregator_features_ok_200() {
550 let (server, client) = setup_server_and_client();
551 let message_expected = AggregatorFeaturesMessage::dummy();
552 let _server_mock = server.mock(|when, then| {
553 when.path("/");
554 then.status(200).body(json!(message_expected).to_string());
555 });
556
557 let message = client.retrieve_aggregator_features().await.unwrap();
558
559 assert_eq!(message_expected, message);
560 }
561
562 #[tokio::test]
563 async fn test_aggregator_features_ko_500() {
564 let (server, client) = setup_server_and_client();
565 set_returning_500(&server);
566
567 let error = client.retrieve_aggregator_features().await.unwrap_err();
568
569 assert_is_error!(error, AggregatorClientError::RemoteServerTechnical(_));
570 }
571
572 #[tokio::test]
573 async fn test_aggregator_features_ko_json_serialization() {
574 let (server, client) = setup_server_and_client();
575 set_unparsable_json(&server);
576
577 let error = client.retrieve_aggregator_features().await.unwrap_err();
578
579 assert_is_error!(error, AggregatorClientError::JsonParseFailed(_));
580 }
581
582 #[tokio::test]
583 async fn test_aggregator_features_timeout() {
584 let (server, mut client) = setup_server_and_client();
585 client.timeout_duration = Some(Duration::from_millis(10));
586 let _server_mock = server.mock(|when, then| {
587 when.path("/");
588 then.delay(Duration::from_millis(100));
589 });
590
591 let error = client.retrieve_aggregator_features().await.unwrap_err();
592
593 assert_is_error!(error, AggregatorClientError::RemoteServerUnreachable(_));
594 }
595
596 #[tokio::test]
597 async fn test_epoch_settings_ok_200() {
598 let (server, client) = setup_server_and_client();
599 let epoch_settings_expected = EpochSettingsMessage::dummy();
600 let _server_mock = server.mock(|when, then| {
601 when.path("/epoch-settings");
602 then.status(200).body(json!(epoch_settings_expected).to_string());
603 });
604
605 let epoch_settings = client.retrieve_epoch_settings().await;
606 epoch_settings.as_ref().expect("unexpected error");
607 assert_eq!(
608 FromEpochSettingsAdapter::try_adapt(epoch_settings_expected).unwrap(),
609 epoch_settings.unwrap().unwrap()
610 );
611 }
612
613 #[tokio::test]
614 async fn test_epoch_settings_ko_500() {
615 let (server, client) = setup_server_and_client();
616 let _server_mock = server.mock(|when, then| {
617 when.path("/epoch-settings");
618 then.status(500).body("an error occurred");
619 });
620
621 match client.retrieve_epoch_settings().await.unwrap_err() {
622 AggregatorClientError::RemoteServerTechnical(_) => (),
623 e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
624 };
625 }
626
627 #[tokio::test]
628 async fn test_epoch_settings_timeout() {
629 let (server, mut client) = setup_server_and_client();
630 client.timeout_duration = Some(Duration::from_millis(10));
631 let _server_mock = server.mock(|when, then| {
632 when.path("/epoch-settings");
633 then.delay(Duration::from_millis(100));
634 });
635
636 let error = client
637 .retrieve_epoch_settings()
638 .await
639 .expect_err("retrieve_epoch_settings should fail");
640
641 assert!(
642 matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
643 "unexpected error type: {error:?}"
644 );
645 }
646
647 #[tokio::test]
648 async fn test_register_signer_ok_201() {
649 let epoch = Epoch(1);
650 let single_signers = fake_data::signers(1);
651 let single_signer = single_signers.first().unwrap();
652 let (server, client) = setup_server_and_client();
653 let _server_mock = server.mock(|when, then| {
654 when.method(POST).path("/register-signer");
655 then.status(201);
656 });
657
658 let register_signer = client.register_signer(epoch, single_signer).await;
659 register_signer.expect("unexpected error");
660 }
661
662 #[tokio::test]
663 async fn test_register_signer_ko_400() {
664 let epoch = Epoch(1);
665 let single_signers = fake_data::signers(1);
666 let single_signer = single_signers.first().unwrap();
667 let (server, client) = setup_server_and_client();
668 let _server_mock = server.mock(|when, then| {
669 when.method(POST).path("/register-signer");
670 then.status(400).body(
671 serde_json::to_vec(&ClientError::new(
672 "error".to_string(),
673 "an error".to_string(),
674 ))
675 .unwrap(),
676 );
677 });
678
679 match client.register_signer(epoch, single_signer).await.unwrap_err() {
680 AggregatorClientError::RemoteServerLogical(_) => (),
681 err => {
682 panic!(
683 "Expected a AggregatorClientError::RemoteServerLogical error, got '{err:?}'."
684 )
685 }
686 };
687 }
688
689 #[tokio::test]
690 async fn test_register_signer_ko_500() {
691 let epoch = Epoch(1);
692 let single_signers = fake_data::signers(1);
693 let single_signer = single_signers.first().unwrap();
694 let (server, client) = setup_server_and_client();
695 let _server_mock = server.mock(|when, then| {
696 when.method(POST).path("/register-signer");
697 then.status(500).body("an error occurred");
698 });
699
700 match client.register_signer(epoch, single_signer).await.unwrap_err() {
701 AggregatorClientError::RemoteServerTechnical(_) => (),
702 e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
703 };
704 }
705
706 #[tokio::test]
707 async fn test_register_signer_timeout() {
708 let epoch = Epoch(1);
709 let single_signers = fake_data::signers(1);
710 let single_signer = single_signers.first().unwrap();
711 let (server, mut client) = setup_server_and_client();
712 client.timeout_duration = Some(Duration::from_millis(10));
713 let _server_mock = server.mock(|when, then| {
714 when.method(POST).path("/register-signer");
715 then.delay(Duration::from_millis(100));
716 });
717
718 let error = client
719 .register_signer(epoch, single_signer)
720 .await
721 .expect_err("register_signer should fail");
722
723 assert!(
724 matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
725 "unexpected error type: {error:?}"
726 );
727 }
728
729 #[tokio::test]
730 async fn test_register_signature_ok_201() {
731 let single_signature = fake_data::single_signature((1..5).collect());
732 let (server, client) = setup_server_and_client();
733 let _server_mock = server.mock(|when, then| {
734 when.method(POST).path("/register-signatures");
735 then.status(201);
736 });
737
738 let register_signature = client
739 .register_signature(
740 &SignedEntityType::dummy(),
741 &single_signature,
742 &ProtocolMessage::default(),
743 )
744 .await;
745 register_signature.expect("unexpected error");
746 }
747
748 #[tokio::test]
749 async fn test_register_signature_ok_202() {
750 let single_signature = fake_data::single_signature((1..5).collect());
751 let (server, client) = setup_server_and_client();
752 let _server_mock = server.mock(|when, then| {
753 when.method(POST).path("/register-signatures");
754 then.status(202);
755 });
756
757 let register_signature = client
758 .register_signature(
759 &SignedEntityType::dummy(),
760 &single_signature,
761 &ProtocolMessage::default(),
762 )
763 .await;
764 register_signature.expect("unexpected error");
765 }
766
767 #[tokio::test]
768 async fn test_register_signature_ko_400() {
769 let single_signature = fake_data::single_signature((1..5).collect());
770 let (server, client) = setup_server_and_client();
771 let _server_mock = server.mock(|when, then| {
772 when.method(POST).path("/register-signatures");
773 then.status(400).body(
774 serde_json::to_vec(&ClientError::new(
775 "error".to_string(),
776 "an error".to_string(),
777 ))
778 .unwrap(),
779 );
780 });
781
782 match client
783 .register_signature(
784 &SignedEntityType::dummy(),
785 &single_signature,
786 &ProtocolMessage::default(),
787 )
788 .await
789 .unwrap_err()
790 {
791 AggregatorClientError::RemoteServerLogical(_) => (),
792 e => panic!("Expected Aggregator::RemoteServerLogical error, got '{e:?}'."),
793 };
794 }
795
796 #[tokio::test]
797 async fn test_register_signature_ok_410_log_response_body() {
798 let (logger, log_inspector) = TestLogger::memory();
799
800 let single_signature = fake_data::single_signature((1..5).collect());
801 let (server, mut client) = setup_server_and_client();
802 client.logger = logger;
803 let _server_mock = server.mock(|when, then| {
804 when.method(POST).path("/register-signatures");
805 then.status(410).body(
806 serde_json::to_vec(&ClientError::new(
807 "already_aggregated".to_string(),
808 "too late".to_string(),
809 ))
810 .unwrap(),
811 );
812 });
813
814 client
815 .register_signature(
816 &SignedEntityType::dummy(),
817 &single_signature,
818 &ProtocolMessage::default(),
819 )
820 .await
821 .expect("Should not fail when status is 410 (GONE)");
822
823 assert!(log_inspector.contains_log("already_aggregated"));
824 assert!(log_inspector.contains_log("too late"));
825 }
826
827 #[tokio::test]
828 async fn test_register_signature_ko_409() {
829 let single_signature = fake_data::single_signature((1..5).collect());
830 let (server, client) = setup_server_and_client();
831 let _server_mock = server.mock(|when, then| {
832 when.method(POST).path("/register-signatures");
833 then.status(409);
834 });
835
836 match client
837 .register_signature(
838 &SignedEntityType::dummy(),
839 &single_signature,
840 &ProtocolMessage::default(),
841 )
842 .await
843 .unwrap_err()
844 {
845 AggregatorClientError::RemoteServerLogical(_) => (),
846 e => panic!("Expected Aggregator::RemoteServerLogical error, got '{e:?}'."),
847 }
848 }
849
850 #[tokio::test]
851 async fn test_register_signature_ko_500() {
852 let single_signature = fake_data::single_signature((1..5).collect());
853 let (server, client) = setup_server_and_client();
854 let _server_mock = server.mock(|when, then| {
855 when.method(POST).path("/register-signatures");
856 then.status(500).body("an error occurred");
857 });
858
859 match client
860 .register_signature(
861 &SignedEntityType::dummy(),
862 &single_signature,
863 &ProtocolMessage::default(),
864 )
865 .await
866 .unwrap_err()
867 {
868 AggregatorClientError::RemoteServerTechnical(_) => (),
869 e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
870 };
871 }
872
873 #[tokio::test]
874 async fn test_register_signature_timeout() {
875 let single_signature = fake_data::single_signature((1..5).collect());
876 let (server, mut client) = setup_server_and_client();
877 client.timeout_duration = Some(Duration::from_millis(10));
878 let _server_mock = server.mock(|when, then| {
879 when.method(POST).path("/register-signatures");
880 then.delay(Duration::from_millis(100));
881 });
882
883 let error = client
884 .register_signature(
885 &SignedEntityType::dummy(),
886 &single_signature,
887 &ProtocolMessage::default(),
888 )
889 .await
890 .expect_err("register_signature should fail");
891
892 assert!(
893 matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
894 "unexpected error type: {error:?}"
895 );
896 }
897
898 #[tokio::test]
899 async fn test_4xx_errors_are_handled_as_remote_server_logical() {
900 let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
901 let handled_error = AggregatorClientError::from_response(response).await;
902
903 assert!(
904 matches!(
905 handled_error,
906 AggregatorClientError::RemoteServerLogical(..)
907 ),
908 "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
909 );
910 }
911
912 #[tokio::test]
913 async fn test_5xx_errors_are_handled_as_remote_server_technical() {
914 let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
915 let handled_error = AggregatorClientError::from_response(response).await;
916
917 assert!(
918 matches!(
919 handled_error,
920 AggregatorClientError::RemoteServerTechnical(..)
921 ),
922 "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
923 );
924 }
925
926 #[tokio::test]
927 async fn test_550_error_is_handled_as_registration_round_not_yet_opened() {
928 let response = build_text_response(StatusCode::from_u16(550).unwrap(), "Not yet available");
929 let handled_error = AggregatorClientError::from_response(response).await;
930
931 assert!(
932 matches!(
933 handled_error,
934 AggregatorClientError::RegistrationRoundNotYetOpened(..)
935 ),
936 "Expected error to be RegistrationRoundNotYetOpened\ngot '{handled_error:?}'",
937 );
938 }
939
940 #[tokio::test]
941 async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text()
942 {
943 let response = build_text_response(StatusCode::OK, "ok text");
944 let handled_error = AggregatorClientError::from_response(response).await;
945
946 assert!(
947 matches!(
948 handled_error,
949 AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
950 ),
951 "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
952 );
953 }
954
955 #[tokio::test]
956 async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
957 let error_text = "An error occurred; please try again later.";
958 let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);
959
960 assert_error_text_contains!(
961 AggregatorClientError::get_root_cause(response).await,
962 "expectation failed: An error occurred; please try again later."
963 );
964 }
965
966 #[tokio::test]
967 async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message()
968 {
969 let client_error = ClientError::new("label", "message");
970 let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);
971
972 assert_error_text_contains!(
973 AggregatorClientError::get_root_cause(response).await,
974 "bad request: label: message"
975 );
976 }
977
978 #[tokio::test]
979 async fn test_root_cause_of_json_formatted_server_error_response_contains_error_label_and_message()
980 {
981 let server_error = ServerError::new("message");
982 let response = build_json_response(StatusCode::BAD_REQUEST, &server_error);
983
984 assert_error_text_contains!(
985 AggregatorClientError::get_root_cause(response).await,
986 "bad request: message"
987 );
988 }
989
990 #[tokio::test]
991 async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
992 let response = build_json_response(
993 StatusCode::INTERNAL_SERVER_ERROR,
994 &json!({ "second": "unknown", "first": "foreign" }),
995 );
996
997 assert_error_text_contains!(
998 AggregatorClientError::get_root_cause(response).await,
999 r#"internal server error: {"first":"foreign","second":"unknown"}"#
1000 );
1001 }
1002
1003 #[tokio::test]
1004 async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
1005 let response = HttpResponseBuilder::new()
1006 .status(StatusCode::BAD_REQUEST)
1007 .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
1008 .body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
1009 .unwrap()
1010 .into();
1011
1012 let root_cause = AggregatorClientError::get_root_cause(response).await;
1013
1014 assert_error_text_contains!(root_cause, "bad request");
1015 assert!(
1016 !root_cause.contains("bad request: "),
1017 "Expected error message should not contain additional information \ngot '{root_cause:?}'"
1018 );
1019 }
1020
1021 #[tokio::test]
1022 async fn test_sends_accept_encoding_header() {
1023 let (server, client) = setup_server_and_client();
1024 server.mock(|when, then| {
1025 when.is_true(|req| {
1026 let headers = req.headers();
1027 let accept_encoding_header = headers
1028 .get("accept-encoding")
1029 .expect("Accept-Encoding header not found");
1030
1031 ["gzip", "br", "deflate", "zstd"].iter().all(|&encoding| {
1032 accept_encoding_header.to_str().is_ok_and(|h| h.contains(encoding))
1033 })
1034 });
1035
1036 then.status(201);
1037 });
1038
1039 client
1040 .register_signature(
1041 &SignedEntityType::dummy(),
1042 &fake_data::single_signature((1..5).collect()),
1043 &ProtocolMessage::default(),
1044 )
1045 .await
1046 .expect("Should succeed with Accept-Encoding header");
1047 }
1048
1049 mod warn_if_api_version_mismatch {
1050 use mithril_common::test::api_version_extensions::ApiVersionProviderTestExtension;
1051
1052 use super::*;
1053
1054 fn version_provider_with_open_api_version<V: Into<String>>(
1055 version: V,
1056 ) -> APIVersionProvider {
1057 let mut version_provider = version_provider_without_open_api_version();
1058 let mut open_api_versions = HashMap::new();
1059 open_api_versions.insert(
1060 "openapi.yaml".to_string(),
1061 Version::parse(&version.into()).unwrap(),
1062 );
1063 version_provider.update_open_api_versions(open_api_versions);
1064
1065 version_provider
1066 }
1067
1068 fn version_provider_without_open_api_version() -> APIVersionProvider {
1069 let mut version_provider =
1070 APIVersionProvider::new(Arc::new(DummyApiVersionDiscriminantSource::new("dummy")));
1071 version_provider.update_open_api_versions(HashMap::new());
1072
1073 version_provider
1074 }
1075
1076 fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
1077 key: K,
1078 value: V,
1079 ) -> Response {
1080 HttpResponseBuilder::new()
1081 .header(key.into(), value.into())
1082 .body("whatever")
1083 .unwrap()
1084 .into()
1085 }
1086
1087 fn assert_api_version_warning_logged<A: Into<String>, S: Into<String>>(
1088 log_inspector: &MemoryDrainForTestInspector,
1089 aggregator_version: A,
1090 signer_version: S,
1091 ) {
1092 assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1093 assert!(
1094 log_inspector
1095 .contains_log(&format!("aggregator_version={}", aggregator_version.into()))
1096 );
1097 assert!(
1098 log_inspector.contains_log(&format!("signer_version={}", signer_version.into()))
1099 );
1100 }
1101
1102 #[test]
1103 fn test_logs_warning_when_aggregator_api_version_is_newer() {
1104 let aggregator_version = "2.0.0";
1105 let signer_version = "1.0.0";
1106 let (logger, log_inspector) = TestLogger::memory();
1107 let version_provider = version_provider_with_open_api_version(signer_version);
1108 let mut client = setup_client("whatever");
1109 client.api_version_provider = Arc::new(version_provider);
1110 client.logger = logger;
1111 let response =
1112 build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1113
1114 assert!(
1115 Version::parse(aggregator_version).unwrap()
1116 > Version::parse(signer_version).unwrap()
1117 );
1118
1119 client.warn_if_api_version_mismatch(&response);
1120
1121 assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1122 }
1123
1124 #[test]
1125 fn test_no_warning_logged_when_versions_match() {
1126 let version = "1.0.0";
1127 let (logger, log_inspector) = TestLogger::memory();
1128 let version_provider = version_provider_with_open_api_version(version);
1129 let mut client = setup_client("whatever");
1130 client.api_version_provider = Arc::new(version_provider);
1131 client.logger = logger;
1132 let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);
1133
1134 client.warn_if_api_version_mismatch(&response);
1135
1136 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1137 }
1138
1139 #[test]
1140 fn test_no_warning_logged_when_aggregator_api_version_is_older() {
1141 let aggregator_version = "1.0.0";
1142 let signer_version = "2.0.0";
1143 let (logger, log_inspector) = TestLogger::memory();
1144 let version_provider = version_provider_with_open_api_version(signer_version);
1145 let mut client = setup_client("whatever");
1146 client.api_version_provider = Arc::new(version_provider);
1147 client.logger = logger;
1148 let response =
1149 build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1150
1151 assert!(
1152 Version::parse(aggregator_version).unwrap()
1153 < Version::parse(signer_version).unwrap()
1154 );
1155
1156 client.warn_if_api_version_mismatch(&response);
1157
1158 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1159 }
1160
1161 #[test]
1162 fn test_does_not_log_or_fail_when_header_is_missing() {
1163 let (logger, log_inspector) = TestLogger::memory();
1164 let mut client = setup_client("whatever");
1165 client.logger = logger;
1166 let response =
1167 build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");
1168
1169 client.warn_if_api_version_mismatch(&response);
1170
1171 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1172 }
1173
1174 #[test]
1175 fn test_does_not_log_or_fail_when_header_is_not_a_version() {
1176 let (logger, log_inspector) = TestLogger::memory();
1177 let mut client = setup_client("whatever");
1178 client.logger = logger;
1179 let response =
1180 build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");
1181
1182 client.warn_if_api_version_mismatch(&response);
1183
1184 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1185 }
1186
1187 #[test]
1188 fn test_logs_error_when_signer_version_cannot_be_computed() {
1189 let (logger, log_inspector) = TestLogger::memory();
1190 let version_provider = version_provider_without_open_api_version();
1191 let mut client = setup_client("whatever");
1192 client.api_version_provider = Arc::new(version_provider);
1193 client.logger = logger;
1194 let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");
1195
1196 client.warn_if_api_version_mismatch(&response);
1197
1198 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1199 }
1200
1201 #[tokio::test]
1202 async fn test_aggregator_features_ok_200_log_warning_if_api_version_mismatch() {
1203 let aggregator_version = "2.0.0";
1204 let signer_version = "1.0.0";
1205 let (server, mut client) = setup_server_and_client();
1206 let (logger, log_inspector) = TestLogger::memory();
1207 let version_provider = version_provider_with_open_api_version(signer_version);
1208 client.api_version_provider = Arc::new(version_provider);
1209 client.logger = logger;
1210
1211 let message_expected = AggregatorFeaturesMessage::dummy();
1212 let _server_mock = server.mock(|when, then| {
1213 when.path("/");
1214 then.status(200)
1215 .header(MITHRIL_API_VERSION_HEADER, aggregator_version)
1216 .body(json!(message_expected).to_string());
1217 });
1218
1219 assert!(
1220 Version::parse(aggregator_version).unwrap()
1221 > Version::parse(signer_version).unwrap()
1222 );
1223
1224 client.retrieve_aggregator_features().await.unwrap();
1225
1226 assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1227 }
1228
1229 #[tokio::test]
1230 async fn test_epoch_settings_ok_200_log_warning_if_api_version_mismatch() {
1231 let aggregator_version = "2.0.0";
1232 let signer_version = "1.0.0";
1233 let (server, mut client) = setup_server_and_client();
1234 let (logger, log_inspector) = TestLogger::memory();
1235 let version_provider = version_provider_with_open_api_version(signer_version);
1236 client.api_version_provider = Arc::new(version_provider);
1237 client.logger = logger;
1238
1239 let epoch_settings_expected = EpochSettingsMessage::dummy();
1240 let _server_mock = server.mock(|when, then| {
1241 when.path("/epoch-settings");
1242 then.status(200)
1243 .header(MITHRIL_API_VERSION_HEADER, aggregator_version)
1244 .body(json!(epoch_settings_expected).to_string());
1245 });
1246
1247 assert!(
1248 Version::parse(aggregator_version).unwrap()
1249 > Version::parse(signer_version).unwrap()
1250 );
1251
1252 client.retrieve_epoch_settings().await.unwrap();
1253
1254 assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1255 }
1256
1257 #[tokio::test]
1258 async fn test_register_signer_ok_201_log_warning_if_api_version_mismatch() {
1259 let aggregator_version = "2.0.0";
1260 let signer_version = "1.0.0";
1261 let epoch = Epoch(1);
1262 let single_signers = fake_data::signers(1);
1263 let single_signer = single_signers.first().unwrap();
1264 let (server, mut client) = setup_server_and_client();
1265 let (logger, log_inspector) = TestLogger::memory();
1266 let version_provider = version_provider_with_open_api_version(signer_version);
1267 client.api_version_provider = Arc::new(version_provider);
1268 client.logger = logger;
1269 let _server_mock = server.mock(|when, then| {
1270 when.method(POST).path("/register-signer");
1271 then.status(201)
1272 .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1273 });
1274
1275 assert!(
1276 Version::parse(aggregator_version).unwrap()
1277 > Version::parse(signer_version).unwrap()
1278 );
1279
1280 client.register_signer(epoch, single_signer).await.unwrap();
1281
1282 assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1283 }
1284
1285 #[tokio::test]
1286 async fn test_register_signature_ok_201_log_warning_if_api_version_mismatch() {
1287 let aggregator_version = "2.0.0";
1288 let signer_version = "1.0.0";
1289 let single_signature = fake_data::single_signature((1..5).collect());
1290 let (server, mut client) = setup_server_and_client();
1291 let (logger, log_inspector) = TestLogger::memory();
1292 let version_provider = version_provider_with_open_api_version(signer_version);
1293 client.api_version_provider = Arc::new(version_provider);
1294 client.logger = logger;
1295 let _server_mock = server.mock(|when, then| {
1296 when.method(POST).path("/register-signatures");
1297 then.status(201)
1298 .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1299 });
1300
1301 assert!(
1302 Version::parse(aggregator_version).unwrap()
1303 > Version::parse(signer_version).unwrap()
1304 );
1305
1306 client
1307 .register_signature(
1308 &SignedEntityType::dummy(),
1309 &single_signature,
1310 &ProtocolMessage::default(),
1311 )
1312 .await
1313 .expect("Should not fail");
1314
1315 assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1316 }
1317
1318 #[tokio::test]
1319 async fn test_register_signature_ok_202_log_warning_if_api_version_mismatch() {
1320 let aggregator_version = "2.0.0";
1321 let signer_version = "1.0.0";
1322 let single_signature = fake_data::single_signature((1..5).collect());
1323 let (server, mut client) = setup_server_and_client();
1324 let (logger, log_inspector) = TestLogger::memory();
1325 let version_provider = version_provider_with_open_api_version(signer_version);
1326 client.api_version_provider = Arc::new(version_provider);
1327 client.logger = logger;
1328 let _server_mock = server.mock(|when, then| {
1329 when.method(POST).path("/register-signatures");
1330 then.status(202)
1331 .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1332 });
1333
1334 assert!(
1335 Version::parse(aggregator_version).unwrap()
1336 > Version::parse(signer_version).unwrap()
1337 );
1338
1339 client
1340 .register_signature(
1341 &SignedEntityType::dummy(),
1342 &single_signature,
1343 &ProtocolMessage::default(),
1344 )
1345 .await
1346 .unwrap();
1347
1348 assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1349 }
1350
1351 #[tokio::test]
1352 async fn test_register_signature_ok_410_log_warning_if_api_version_mismatch() {
1353 let aggregator_version = "2.0.0";
1354 let signer_version = "1.0.0";
1355 let single_signature = fake_data::single_signature((1..5).collect());
1356 let (server, mut client) = setup_server_and_client();
1357 let (logger, log_inspector) = TestLogger::memory();
1358 let version_provider = version_provider_with_open_api_version(signer_version);
1359 client.api_version_provider = Arc::new(version_provider);
1360 client.logger = logger;
1361 let _server_mock = server.mock(|when, then| {
1362 when.method(POST).path("/register-signatures");
1363 then.status(410)
1364 .body(
1365 serde_json::to_vec(&ClientError::new(
1366 "already_aggregated".to_string(),
1367 "too late".to_string(),
1368 ))
1369 .unwrap(),
1370 )
1371 .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1372 });
1373
1374 assert!(
1375 Version::parse(aggregator_version).unwrap()
1376 > Version::parse(signer_version).unwrap()
1377 );
1378
1379 client
1380 .register_signature(
1381 &SignedEntityType::dummy(),
1382 &single_signature,
1383 &ProtocolMessage::default(),
1384 )
1385 .await
1386 .unwrap();
1387
1388 assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1389 }
1390 }
1391}