1use anyhow::anyhow;
2use async_trait::async_trait;
3use reqwest::header::{self, HeaderValue};
4use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode};
5use semver::Version;
6use slog::{Logger, debug, error, warn};
7use std::{io, sync::Arc, time::Duration};
8use thiserror::Error;
9
10use mithril_common::{
11 MITHRIL_API_VERSION_HEADER, MITHRIL_SIGNER_VERSION_HEADER, StdError,
12 api_version::APIVersionProvider,
13 entities::{
14 ClientError, Epoch, ProtocolMessage, ServerError, SignedEntityType, Signer, SingleSignature,
15 },
16 logging::LoggerExtensions,
17 messages::{
18 AggregatorFeaturesMessage, EpochSettingsMessage, TryFromMessageAdapter, TryToMessageAdapter,
19 },
20};
21
22use crate::entities::SignerEpochSettings;
23use crate::message_adapters::{
24 FromEpochSettingsAdapter, ToRegisterSignatureMessageAdapter, ToRegisterSignerMessageAdapter,
25};
26
/// `Content-Type` header value identifying a JSON body.
const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");

/// Warning logged when the aggregator advertises a newer OpenAPI version than the one
/// computed by this signer.
const API_VERSION_MISMATCH_WARNING_MESSAGE: &str =
    "OpenAPI version may be incompatible, please update your Mithril node to the latest version.";
31
/// Errors returned by the [AggregatorClient] implementations of this module.
#[derive(Error, Debug)]
pub enum AggregatorClientError {
    /// The aggregator answered with a server error (5xx) status code, other than 550.
    #[error("remote server technical error")]
    RemoteServerTechnical(#[source] StdError),

    /// The aggregator answered with a client error (4xx) status code.
    #[error("remote server logical error")]
    RemoteServerLogical(#[source] StdError),

    /// The request could not be sent or no response was received (e.g. connection
    /// failure or request timeout).
    #[error("remote server unreachable")]
    RemoteServerUnreachable(#[source] StdError),

    /// The aggregator answered with an unexpected status code that is neither a 4xx nor a
    /// 5xx; carries the status code and the raw response text.
    #[error("unhandled status code: {0}, response text: {1}")]
    UnhandledStatusCode(StatusCode, String),

    /// The response body could not be deserialized into the expected message.
    #[error("json parsing failed")]
    JsonParseFailed(#[source] StdError),

    /// Wraps a [std::io::Error].
    #[error("Input/Output error")]
    IOError(#[from] io::Error),

    /// The underlying HTTP client could not be built.
    #[error("HTTP client creation failed")]
    HTTPClientCreation(#[source] StdError),

    /// The relay endpoint could not be turned into an HTTP proxy.
    #[error("proxy creation failed")]
    ProxyCreation(#[source] StdError),

    /// A conversion between a message and a domain entity failed.
    #[error("adapter failed")]
    Adapter(#[source] StdError),

    /// The aggregator answered with the non-standard status code 550.
    #[error("a signer registration round is not opened yet, please try again later")]
    RegistrationRoundNotYetOpened(#[source] StdError),
}
75
76impl AggregatorClientError {
77 pub async fn from_response(response: Response) -> Self {
83 let error_code = response.status();
84
85 if error_code.is_client_error() {
86 let root_cause = Self::get_root_cause(response).await;
87 Self::RemoteServerLogical(anyhow!(root_cause))
88 } else if error_code.is_server_error() {
89 let root_cause = Self::get_root_cause(response).await;
90 match error_code.as_u16() {
91 550 => Self::RegistrationRoundNotYetOpened(anyhow!(root_cause)),
92 _ => Self::RemoteServerTechnical(anyhow!(root_cause)),
93 }
94 } else {
95 let response_text = response.text().await.unwrap_or_default();
96 Self::UnhandledStatusCode(error_code, response_text)
97 }
98 }
99
100 async fn get_root_cause(response: Response) -> String {
101 let error_code = response.status();
102 let canonical_reason = error_code.canonical_reason().unwrap_or_default().to_lowercase();
103 let is_json = response
104 .headers()
105 .get(header::CONTENT_TYPE)
106 .is_some_and(|ct| JSON_CONTENT_TYPE == ct);
107
108 if is_json {
109 let json_value: serde_json::Value = response.json().await.unwrap_or_default();
110
111 if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
112 format!(
113 "{}: {}: {}",
114 canonical_reason, client_error.label, client_error.message
115 )
116 } else if let Ok(server_error) =
117 serde_json::from_value::<ServerError>(json_value.clone())
118 {
119 format!("{}: {}", canonical_reason, server_error.message)
120 } else if json_value.is_null() {
121 canonical_reason.to_string()
122 } else {
123 format!("{canonical_reason}: {json_value}")
124 }
125 } else {
126 let response_text = response.text().await.unwrap_or_default();
127 format!("{canonical_reason}: {response_text}")
128 }
129 }
130}
131
/// Client used by the signer to communicate with a Mithril aggregator.
#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait AggregatorClient: Sync + Send {
    /// Retrieves the current epoch settings from the aggregator.
    async fn retrieve_epoch_settings(
        &self,
    ) -> Result<Option<SignerEpochSettings>, AggregatorClientError>;

    /// Registers the given signer with the aggregator for the given epoch.
    async fn register_signer(
        &self,
        epoch: Epoch,
        signer: &Signer,
    ) -> Result<(), AggregatorClientError>;

    /// Sends a single signature, computed for the given signed entity type and protocol
    /// message, to the aggregator.
    async fn register_signature(
        &self,
        signed_entity_type: &SignedEntityType,
        signature: &SingleSignature,
        protocol_message: &ProtocolMessage,
    ) -> Result<(), AggregatorClientError>;

    /// Retrieves the features message advertised by the aggregator.
    async fn retrieve_aggregator_features(
        &self,
    ) -> Result<AggregatorFeaturesMessage, AggregatorClientError>;
}
161
/// HTTP implementation of [AggregatorClient].
pub struct AggregatorHTTPClient {
    /// Base URL of the aggregator; route paths (e.g. `/epoch-settings`) are appended to it.
    aggregator_endpoint: String,
    /// Optional relay URL, used as a proxy for every request when set.
    relay_endpoint: Option<String>,
    /// Computes the signer API version sent with requests and compared with the aggregator's.
    api_version_provider: Arc<APIVersionProvider>,
    /// Optional per-request timeout; when `None`, no per-request timeout is configured.
    timeout_duration: Option<Duration>,
    logger: Logger,
}
170
171impl AggregatorHTTPClient {
172 pub fn new(
174 aggregator_endpoint: String,
175 relay_endpoint: Option<String>,
176 api_version_provider: Arc<APIVersionProvider>,
177 timeout_duration: Option<Duration>,
178 logger: Logger,
179 ) -> Self {
180 let logger = logger.new_with_component_name::<Self>();
181 debug!(logger, "New AggregatorHTTPClient created");
182 Self {
183 aggregator_endpoint,
184 relay_endpoint,
185 api_version_provider,
186 timeout_duration,
187 logger,
188 }
189 }
190
191 fn prepare_http_client(&self) -> Result<Client, AggregatorClientError> {
192 let client = match &self.relay_endpoint {
193 Some(relay_endpoint) => Client::builder()
194 .proxy(
195 Proxy::all(relay_endpoint)
196 .map_err(|e| AggregatorClientError::ProxyCreation(anyhow!(e)))?,
197 )
198 .build()
199 .map_err(|e| AggregatorClientError::HTTPClientCreation(anyhow!(e)))?,
200 None => Client::new(),
201 };
202
203 Ok(client)
204 }
205
206 pub fn prepare_request_builder(&self, request_builder: RequestBuilder) -> RequestBuilder {
208 let request_builder = request_builder
209 .header(
210 MITHRIL_API_VERSION_HEADER,
211 self.api_version_provider
212 .compute_current_version()
213 .unwrap()
214 .to_string(),
215 )
216 .header(MITHRIL_SIGNER_VERSION_HEADER, env!("CARGO_PKG_VERSION"));
217
218 if let Some(duration) = self.timeout_duration {
219 request_builder.timeout(duration)
220 } else {
221 request_builder
222 }
223 }
224
225 fn warn_if_api_version_mismatch(&self, response: &Response) {
227 let aggregator_version = response
228 .headers()
229 .get(MITHRIL_API_VERSION_HEADER)
230 .and_then(|v| v.to_str().ok())
231 .and_then(|s| Version::parse(s).ok());
232
233 let signer_version = self.api_version_provider.compute_current_version();
234
235 match (aggregator_version, signer_version) {
236 (Some(aggregator), Ok(signer)) if signer < aggregator => {
237 warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
238 "aggregator_version" => %aggregator,
239 "signer_version" => %signer,
240 );
241 }
242 (Some(_), Err(error)) => {
243 error!(
244 self.logger,
245 "Failed to compute the current signer API version";
246 "error" => error.to_string()
247 );
248 }
249 _ => {}
250 }
251 }
252}
253
254#[async_trait]
255impl AggregatorClient for AggregatorHTTPClient {
256 async fn retrieve_epoch_settings(
257 &self,
258 ) -> Result<Option<SignerEpochSettings>, AggregatorClientError> {
259 debug!(self.logger, "Retrieve epoch settings");
260 let url = format!("{}/epoch-settings", self.aggregator_endpoint);
261 let response = self
262 .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
263 .send()
264 .await;
265
266 match response {
267 Ok(response) => match response.status() {
268 StatusCode::OK => {
269 self.warn_if_api_version_mismatch(&response);
270 match response.json::<EpochSettingsMessage>().await {
271 Ok(message) => {
272 let epoch_settings = FromEpochSettingsAdapter::try_adapt(message)
273 .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
274 Ok(Some(epoch_settings))
275 }
276 Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
277 }
278 }
279 _ => Err(AggregatorClientError::from_response(response).await),
280 },
281 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
282 }
283 }
284
285 async fn register_signer(
286 &self,
287 epoch: Epoch,
288 signer: &Signer,
289 ) -> Result<(), AggregatorClientError> {
290 debug!(self.logger, "Register signer");
291 let url = format!("{}/register-signer", self.aggregator_endpoint);
292 let register_signer_message =
293 ToRegisterSignerMessageAdapter::try_adapt((epoch, signer.to_owned()))
294 .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
295 let response = self
296 .prepare_request_builder(self.prepare_http_client()?.post(url.clone()))
297 .json(®ister_signer_message)
298 .send()
299 .await;
300
301 match response {
302 Ok(response) => match response.status() {
303 StatusCode::CREATED => {
304 self.warn_if_api_version_mismatch(&response);
305
306 Ok(())
307 }
308 _ => Err(AggregatorClientError::from_response(response).await),
309 },
310 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
311 }
312 }
313
314 async fn register_signature(
315 &self,
316 signed_entity_type: &SignedEntityType,
317 signature: &SingleSignature,
318 protocol_message: &ProtocolMessage,
319 ) -> Result<(), AggregatorClientError> {
320 debug!(self.logger, "Register signature");
321 let url = format!("{}/register-signatures", self.aggregator_endpoint);
322 let register_single_signature_message = ToRegisterSignatureMessageAdapter::try_adapt((
323 signed_entity_type.to_owned(),
324 signature.to_owned(),
325 protocol_message,
326 ))
327 .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
328 let response = self
329 .prepare_request_builder(self.prepare_http_client()?.post(url.clone()))
330 .json(®ister_single_signature_message)
331 .send()
332 .await;
333
334 match response {
335 Ok(response) => match response.status() {
336 StatusCode::CREATED | StatusCode::ACCEPTED => {
337 self.warn_if_api_version_mismatch(&response);
338
339 Ok(())
340 }
341 StatusCode::GONE => {
342 self.warn_if_api_version_mismatch(&response);
343 let root_cause = AggregatorClientError::get_root_cause(response).await;
344 debug!(self.logger, "Message already certified or expired"; "details" => &root_cause);
345
346 Ok(())
347 }
348 _ => Err(AggregatorClientError::from_response(response).await),
349 },
350 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
351 }
352 }
353
354 async fn retrieve_aggregator_features(
355 &self,
356 ) -> Result<AggregatorFeaturesMessage, AggregatorClientError> {
357 debug!(self.logger, "Retrieve aggregator features message");
358 let url = format!("{}/", self.aggregator_endpoint);
359 let response = self
360 .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
361 .send()
362 .await;
363
364 match response {
365 Ok(response) => match response.status() {
366 StatusCode::OK => {
367 self.warn_if_api_version_mismatch(&response);
368
369 Ok(response
370 .json::<AggregatorFeaturesMessage>()
371 .await
372 .map_err(|e| AggregatorClientError::JsonParseFailed(anyhow!(e)))?)
373 }
374 _ => Err(AggregatorClientError::from_response(response).await),
375 },
376 Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
377 }
378 }
379}
380
#[cfg(test)]
pub(crate) mod dumb {
    use mithril_common::test::double::Dummy;
    use tokio::sync::RwLock;

    use super::*;

    /// In-memory [AggregatorClient] test double.
    ///
    /// It serves fixed epoch settings and configurable aggregator features, remembers the
    /// last registered signer, and accepts every signature.
    pub struct DumbAggregatorClient {
        epoch_settings: RwLock<Option<SignerEpochSettings>>,
        last_registered_signer: RwLock<Option<Signer>>,
        aggregator_features: RwLock<AggregatorFeaturesMessage>,
    }

    impl DumbAggregatorClient {
        /// Returns a copy of the last signer passed to `register_signer`, if any.
        pub async fn get_last_registered_signer(&self) -> Option<Signer> {
            let signer_guard = self.last_registered_signer.read().await;
            signer_guard.clone()
        }

        /// Replaces the features message returned by `retrieve_aggregator_features`.
        pub async fn set_aggregator_features(
            &self,
            aggregator_features: AggregatorFeaturesMessage,
        ) {
            *self.aggregator_features.write().await = aggregator_features;
        }
    }

    impl Default for DumbAggregatorClient {
        /// Starts with dummy epoch settings, dummy aggregator features and no registered signer.
        fn default() -> Self {
            let epoch_settings = RwLock::new(Some(SignerEpochSettings::dummy()));
            let aggregator_features = RwLock::new(AggregatorFeaturesMessage::dummy());

            Self {
                epoch_settings,
                last_registered_signer: RwLock::default(),
                aggregator_features,
            }
        }
    }

    #[async_trait]
    impl AggregatorClient for DumbAggregatorClient {
        async fn retrieve_epoch_settings(
            &self,
        ) -> Result<Option<SignerEpochSettings>, AggregatorClientError> {
            let settings_guard = self.epoch_settings.read().await;
            Ok(settings_guard.clone())
        }

        async fn register_signer(
            &self,
            _epoch: Epoch,
            signer: &Signer,
        ) -> Result<(), AggregatorClientError> {
            *self.last_registered_signer.write().await = Some(signer.clone());

            Ok(())
        }

        async fn register_signature(
            &self,
            _signed_entity_type: &SignedEntityType,
            _signature: &SingleSignature,
            _protocol_message: &ProtocolMessage,
        ) -> Result<(), AggregatorClientError> {
            Ok(())
        }

        async fn retrieve_aggregator_features(
            &self,
        ) -> Result<AggregatorFeaturesMessage, AggregatorClientError> {
            let features_guard = self.aggregator_features.read().await;
            Ok(features_guard.clone())
        }
    }
}
463
464#[cfg(test)]
465mod tests {
466 use std::collections::HashMap;
467
468 use http::response::Builder as HttpResponseBuilder;
469 use httpmock::prelude::*;
470 use semver::Version;
471 use serde_json::json;
472
473 use mithril_common::entities::Epoch;
474 use mithril_common::messages::TryFromMessageAdapter;
475 use mithril_common::test::{
476 double::{Dummy, DummyApiVersionDiscriminantSource, fake_data},
477 logging::MemoryDrainForTestInspector,
478 };
479
480 use crate::test_tools::TestLogger;
481
482 use super::*;
483
484 macro_rules! assert_is_error {
485 ($error:expr, $error_type:pat) => {
486 assert!(
487 matches!($error, $error_type),
488 "Expected {} error, got '{:?}'.",
489 stringify!($error_type),
490 $error
491 );
492 };
493 }
494
495 fn setup_client<U: Into<String>>(server_url: U) -> AggregatorHTTPClient {
496 let discriminant_source = DummyApiVersionDiscriminantSource::new("dummy");
497 let api_version_provider = APIVersionProvider::new(Arc::new(discriminant_source));
498
499 AggregatorHTTPClient::new(
500 server_url.into(),
501 None,
502 Arc::new(api_version_provider),
503 None,
504 TestLogger::stdout(),
505 )
506 }
507
508 fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
509 let server = MockServer::start();
510 let aggregator_endpoint = server.url("");
511 let client = setup_client(&aggregator_endpoint);
512
513 (server, client)
514 }
515
516 fn set_returning_500(server: &MockServer) {
517 server.mock(|_, then| {
518 then.status(500).body("an error occurred");
519 });
520 }
521
522 fn set_unparsable_json(server: &MockServer) {
523 server.mock(|_, then| {
524 then.status(200).body("this is not a json");
525 });
526 }
527
528 fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
529 HttpResponseBuilder::new()
530 .status(status_code)
531 .body(body.into())
532 .unwrap()
533 .into()
534 }
535
536 fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
537 HttpResponseBuilder::new()
538 .status(status_code)
539 .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
540 .body(serde_json::to_string(&body).unwrap())
541 .unwrap()
542 .into()
543 }
544
545 macro_rules! assert_error_text_contains {
546 ($error: expr, $expect_contains: expr) => {
547 let error = &$error;
548 assert!(
549 error.contains($expect_contains),
550 "Expected error message to contain '{}'\ngot '{error:?}'",
551 $expect_contains,
552 );
553 };
554 }
555
556 #[tokio::test]
557 async fn test_aggregator_features_ok_200() {
558 let (server, client) = setup_server_and_client();
559 let message_expected = AggregatorFeaturesMessage::dummy();
560 let _server_mock = server.mock(|when, then| {
561 when.path("/");
562 then.status(200).body(json!(message_expected).to_string());
563 });
564
565 let message = client.retrieve_aggregator_features().await.unwrap();
566
567 assert_eq!(message_expected, message);
568 }
569
570 #[tokio::test]
571 async fn test_aggregator_features_ko_500() {
572 let (server, client) = setup_server_and_client();
573 set_returning_500(&server);
574
575 let error = client.retrieve_aggregator_features().await.unwrap_err();
576
577 assert_is_error!(error, AggregatorClientError::RemoteServerTechnical(_));
578 }
579
580 #[tokio::test]
581 async fn test_aggregator_features_ko_json_serialization() {
582 let (server, client) = setup_server_and_client();
583 set_unparsable_json(&server);
584
585 let error = client.retrieve_aggregator_features().await.unwrap_err();
586
587 assert_is_error!(error, AggregatorClientError::JsonParseFailed(_));
588 }
589
590 #[tokio::test]
591 async fn test_aggregator_features_timeout() {
592 let (server, mut client) = setup_server_and_client();
593 client.timeout_duration = Some(Duration::from_millis(10));
594 let _server_mock = server.mock(|when, then| {
595 when.path("/");
596 then.delay(Duration::from_millis(100));
597 });
598
599 let error = client.retrieve_aggregator_features().await.unwrap_err();
600
601 assert_is_error!(error, AggregatorClientError::RemoteServerUnreachable(_));
602 }
603
604 #[tokio::test]
605 async fn test_epoch_settings_ok_200() {
606 let (server, client) = setup_server_and_client();
607 let epoch_settings_expected = EpochSettingsMessage::dummy();
608 let _server_mock = server.mock(|when, then| {
609 when.path("/epoch-settings");
610 then.status(200).body(json!(epoch_settings_expected).to_string());
611 });
612
613 let epoch_settings = client.retrieve_epoch_settings().await;
614 epoch_settings.as_ref().expect("unexpected error");
615 assert_eq!(
616 FromEpochSettingsAdapter::try_adapt(epoch_settings_expected).unwrap(),
617 epoch_settings.unwrap().unwrap()
618 );
619 }
620
621 #[tokio::test]
622 async fn test_epoch_settings_ko_500() {
623 let (server, client) = setup_server_and_client();
624 let _server_mock = server.mock(|when, then| {
625 when.path("/epoch-settings");
626 then.status(500).body("an error occurred");
627 });
628
629 match client.retrieve_epoch_settings().await.unwrap_err() {
630 AggregatorClientError::RemoteServerTechnical(_) => (),
631 e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
632 };
633 }
634
635 #[tokio::test]
636 async fn test_epoch_settings_timeout() {
637 let (server, mut client) = setup_server_and_client();
638 client.timeout_duration = Some(Duration::from_millis(10));
639 let _server_mock = server.mock(|when, then| {
640 when.path("/epoch-settings");
641 then.delay(Duration::from_millis(100));
642 });
643
644 let error = client
645 .retrieve_epoch_settings()
646 .await
647 .expect_err("retrieve_epoch_settings should fail");
648
649 assert!(
650 matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
651 "unexpected error type: {error:?}"
652 );
653 }
654
655 #[tokio::test]
656 async fn test_register_signer_ok_201() {
657 let epoch = Epoch(1);
658 let single_signers = fake_data::signers(1);
659 let single_signer = single_signers.first().unwrap();
660 let (server, client) = setup_server_and_client();
661 let _server_mock = server.mock(|when, then| {
662 when.method(POST).path("/register-signer");
663 then.status(201);
664 });
665
666 let register_signer = client.register_signer(epoch, single_signer).await;
667 register_signer.expect("unexpected error");
668 }
669
670 #[tokio::test]
671 async fn test_register_signer_ko_400() {
672 let epoch = Epoch(1);
673 let single_signers = fake_data::signers(1);
674 let single_signer = single_signers.first().unwrap();
675 let (server, client) = setup_server_and_client();
676 let _server_mock = server.mock(|when, then| {
677 when.method(POST).path("/register-signer");
678 then.status(400).body(
679 serde_json::to_vec(&ClientError::new(
680 "error".to_string(),
681 "an error".to_string(),
682 ))
683 .unwrap(),
684 );
685 });
686
687 match client.register_signer(epoch, single_signer).await.unwrap_err() {
688 AggregatorClientError::RemoteServerLogical(_) => (),
689 err => {
690 panic!(
691 "Expected a AggregatorClientError::RemoteServerLogical error, got '{err:?}'."
692 )
693 }
694 };
695 }
696
697 #[tokio::test]
698 async fn test_register_signer_ko_500() {
699 let epoch = Epoch(1);
700 let single_signers = fake_data::signers(1);
701 let single_signer = single_signers.first().unwrap();
702 let (server, client) = setup_server_and_client();
703 let _server_mock = server.mock(|when, then| {
704 when.method(POST).path("/register-signer");
705 then.status(500).body("an error occurred");
706 });
707
708 match client.register_signer(epoch, single_signer).await.unwrap_err() {
709 AggregatorClientError::RemoteServerTechnical(_) => (),
710 e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
711 };
712 }
713
714 #[tokio::test]
715 async fn test_register_signer_timeout() {
716 let epoch = Epoch(1);
717 let single_signers = fake_data::signers(1);
718 let single_signer = single_signers.first().unwrap();
719 let (server, mut client) = setup_server_and_client();
720 client.timeout_duration = Some(Duration::from_millis(10));
721 let _server_mock = server.mock(|when, then| {
722 when.method(POST).path("/register-signer");
723 then.delay(Duration::from_millis(100));
724 });
725
726 let error = client
727 .register_signer(epoch, single_signer)
728 .await
729 .expect_err("register_signer should fail");
730
731 assert!(
732 matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
733 "unexpected error type: {error:?}"
734 );
735 }
736
737 #[tokio::test]
738 async fn test_register_signature_ok_201() {
739 let single_signature = fake_data::single_signature((1..5).collect());
740 let (server, client) = setup_server_and_client();
741 let _server_mock = server.mock(|when, then| {
742 when.method(POST).path("/register-signatures");
743 then.status(201);
744 });
745
746 let register_signature = client
747 .register_signature(
748 &SignedEntityType::dummy(),
749 &single_signature,
750 &ProtocolMessage::default(),
751 )
752 .await;
753 register_signature.expect("unexpected error");
754 }
755
756 #[tokio::test]
757 async fn test_register_signature_ok_202() {
758 let single_signature = fake_data::single_signature((1..5).collect());
759 let (server, client) = setup_server_and_client();
760 let _server_mock = server.mock(|when, then| {
761 when.method(POST).path("/register-signatures");
762 then.status(202);
763 });
764
765 let register_signature = client
766 .register_signature(
767 &SignedEntityType::dummy(),
768 &single_signature,
769 &ProtocolMessage::default(),
770 )
771 .await;
772 register_signature.expect("unexpected error");
773 }
774
775 #[tokio::test]
776 async fn test_register_signature_ko_400() {
777 let single_signature = fake_data::single_signature((1..5).collect());
778 let (server, client) = setup_server_and_client();
779 let _server_mock = server.mock(|when, then| {
780 when.method(POST).path("/register-signatures");
781 then.status(400).body(
782 serde_json::to_vec(&ClientError::new(
783 "error".to_string(),
784 "an error".to_string(),
785 ))
786 .unwrap(),
787 );
788 });
789
790 match client
791 .register_signature(
792 &SignedEntityType::dummy(),
793 &single_signature,
794 &ProtocolMessage::default(),
795 )
796 .await
797 .unwrap_err()
798 {
799 AggregatorClientError::RemoteServerLogical(_) => (),
800 e => panic!("Expected Aggregator::RemoteServerLogical error, got '{e:?}'."),
801 };
802 }
803
804 #[tokio::test]
805 async fn test_register_signature_ok_410_log_response_body() {
806 let (logger, log_inspector) = TestLogger::memory();
807
808 let single_signature = fake_data::single_signature((1..5).collect());
809 let (server, mut client) = setup_server_and_client();
810 client.logger = logger;
811 let _server_mock = server.mock(|when, then| {
812 when.method(POST).path("/register-signatures");
813 then.status(410).body(
814 serde_json::to_vec(&ClientError::new(
815 "already_aggregated".to_string(),
816 "too late".to_string(),
817 ))
818 .unwrap(),
819 );
820 });
821
822 client
823 .register_signature(
824 &SignedEntityType::dummy(),
825 &single_signature,
826 &ProtocolMessage::default(),
827 )
828 .await
829 .expect("Should not fail when status is 410 (GONE)");
830
831 assert!(log_inspector.contains_log("already_aggregated"));
832 assert!(log_inspector.contains_log("too late"));
833 }
834
835 #[tokio::test]
836 async fn test_register_signature_ko_409() {
837 let single_signature = fake_data::single_signature((1..5).collect());
838 let (server, client) = setup_server_and_client();
839 let _server_mock = server.mock(|when, then| {
840 when.method(POST).path("/register-signatures");
841 then.status(409);
842 });
843
844 match client
845 .register_signature(
846 &SignedEntityType::dummy(),
847 &single_signature,
848 &ProtocolMessage::default(),
849 )
850 .await
851 .unwrap_err()
852 {
853 AggregatorClientError::RemoteServerLogical(_) => (),
854 e => panic!("Expected Aggregator::RemoteServerLogical error, got '{e:?}'."),
855 }
856 }
857
858 #[tokio::test]
859 async fn test_register_signature_ko_500() {
860 let single_signature = fake_data::single_signature((1..5).collect());
861 let (server, client) = setup_server_and_client();
862 let _server_mock = server.mock(|when, then| {
863 when.method(POST).path("/register-signatures");
864 then.status(500).body("an error occurred");
865 });
866
867 match client
868 .register_signature(
869 &SignedEntityType::dummy(),
870 &single_signature,
871 &ProtocolMessage::default(),
872 )
873 .await
874 .unwrap_err()
875 {
876 AggregatorClientError::RemoteServerTechnical(_) => (),
877 e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
878 };
879 }
880
881 #[tokio::test]
882 async fn test_register_signature_timeout() {
883 let single_signature = fake_data::single_signature((1..5).collect());
884 let (server, mut client) = setup_server_and_client();
885 client.timeout_duration = Some(Duration::from_millis(10));
886 let _server_mock = server.mock(|when, then| {
887 when.method(POST).path("/register-signatures");
888 then.delay(Duration::from_millis(100));
889 });
890
891 let error = client
892 .register_signature(
893 &SignedEntityType::dummy(),
894 &single_signature,
895 &ProtocolMessage::default(),
896 )
897 .await
898 .expect_err("register_signature should fail");
899
900 assert!(
901 matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
902 "unexpected error type: {error:?}"
903 );
904 }
905
906 #[tokio::test]
907 async fn test_4xx_errors_are_handled_as_remote_server_logical() {
908 let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
909 let handled_error = AggregatorClientError::from_response(response).await;
910
911 assert!(
912 matches!(
913 handled_error,
914 AggregatorClientError::RemoteServerLogical(..)
915 ),
916 "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
917 );
918 }
919
920 #[tokio::test]
921 async fn test_5xx_errors_are_handled_as_remote_server_technical() {
922 let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
923 let handled_error = AggregatorClientError::from_response(response).await;
924
925 assert!(
926 matches!(
927 handled_error,
928 AggregatorClientError::RemoteServerTechnical(..)
929 ),
930 "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
931 );
932 }
933
934 #[tokio::test]
935 async fn test_550_error_is_handled_as_registration_round_not_yet_opened() {
936 let response = build_text_response(StatusCode::from_u16(550).unwrap(), "Not yet available");
937 let handled_error = AggregatorClientError::from_response(response).await;
938
939 assert!(
940 matches!(
941 handled_error,
942 AggregatorClientError::RegistrationRoundNotYetOpened(..)
943 ),
944 "Expected error to be RegistrationRoundNotYetOpened\ngot '{handled_error:?}'",
945 );
946 }
947
948 #[tokio::test]
949 async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text()
950 {
951 let response = build_text_response(StatusCode::OK, "ok text");
952 let handled_error = AggregatorClientError::from_response(response).await;
953
954 assert!(
955 matches!(
956 handled_error,
957 AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
958 ),
959 "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
960 );
961 }
962
963 #[tokio::test]
964 async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
965 let error_text = "An error occurred; please try again later.";
966 let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);
967
968 assert_error_text_contains!(
969 AggregatorClientError::get_root_cause(response).await,
970 "expectation failed: An error occurred; please try again later."
971 );
972 }
973
974 #[tokio::test]
975 async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message()
976 {
977 let client_error = ClientError::new("label", "message");
978 let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);
979
980 assert_error_text_contains!(
981 AggregatorClientError::get_root_cause(response).await,
982 "bad request: label: message"
983 );
984 }
985
986 #[tokio::test]
987 async fn test_root_cause_of_json_formatted_server_error_response_contains_error_label_and_message()
988 {
989 let server_error = ServerError::new("message");
990 let response = build_json_response(StatusCode::BAD_REQUEST, &server_error);
991
992 assert_error_text_contains!(
993 AggregatorClientError::get_root_cause(response).await,
994 "bad request: message"
995 );
996 }
997
998 #[tokio::test]
999 async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
1000 let response = build_json_response(
1001 StatusCode::INTERNAL_SERVER_ERROR,
1002 &json!({ "second": "unknown", "first": "foreign" }),
1003 );
1004
1005 assert_error_text_contains!(
1006 AggregatorClientError::get_root_cause(response).await,
1007 r#"internal server error: {"first":"foreign","second":"unknown"}"#
1008 );
1009 }
1010
1011 #[tokio::test]
1012 async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
1013 let response = HttpResponseBuilder::new()
1014 .status(StatusCode::BAD_REQUEST)
1015 .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
1016 .body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
1017 .unwrap()
1018 .into();
1019
1020 let root_cause = AggregatorClientError::get_root_cause(response).await;
1021
1022 assert_error_text_contains!(root_cause, "bad request");
1023 assert!(
1024 !root_cause.contains("bad request: "),
1025 "Expected error message should not contain additional information \ngot '{root_cause:?}'"
1026 );
1027 }
1028
1029 #[tokio::test]
1030 async fn test_sends_accept_encoding_header() {
1031 let (server, client) = setup_server_and_client();
1032 server.mock(|when, then| {
1033 when.matches(|req| {
1034 let headers = req.headers.clone().expect("HTTP headers not found");
1035 let accept_encoding_header = headers
1036 .iter()
1037 .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
1038 .expect("Accept-Encoding header not found");
1039
1040 let header_value = accept_encoding_header.clone().1;
1041 ["gzip", "br", "deflate", "zstd"]
1042 .iter()
1043 .all(|&value| header_value.contains(value))
1044 });
1045
1046 then.status(201);
1047 });
1048
1049 client
1050 .register_signature(
1051 &SignedEntityType::dummy(),
1052 &fake_data::single_signature((1..5).collect()),
1053 &ProtocolMessage::default(),
1054 )
1055 .await
1056 .expect("Should succeed with Accept-Encoding header");
1057 }
1058
1059 mod warn_if_api_version_mismatch {
1060 use mithril_common::test::api_version_extensions::ApiVersionProviderTestExtension;
1061
1062 use super::*;
1063
1064 fn version_provider_with_open_api_version<V: Into<String>>(
1065 version: V,
1066 ) -> APIVersionProvider {
1067 let mut version_provider = version_provider_without_open_api_version();
1068 let mut open_api_versions = HashMap::new();
1069 open_api_versions.insert(
1070 "openapi.yaml".to_string(),
1071 Version::parse(&version.into()).unwrap(),
1072 );
1073 version_provider.update_open_api_versions(open_api_versions);
1074
1075 version_provider
1076 }
1077
1078 fn version_provider_without_open_api_version() -> APIVersionProvider {
1079 let mut version_provider =
1080 APIVersionProvider::new(Arc::new(DummyApiVersionDiscriminantSource::new("dummy")));
1081 version_provider.update_open_api_versions(HashMap::new());
1082
1083 version_provider
1084 }
1085
1086 fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
1087 key: K,
1088 value: V,
1089 ) -> Response {
1090 HttpResponseBuilder::new()
1091 .header(key.into(), value.into())
1092 .body("whatever")
1093 .unwrap()
1094 .into()
1095 }
1096
1097 fn assert_api_version_warning_logged<A: Into<String>, S: Into<String>>(
1098 log_inspector: &MemoryDrainForTestInspector,
1099 aggregator_version: A,
1100 signer_version: S,
1101 ) {
1102 assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1103 assert!(
1104 log_inspector
1105 .contains_log(&format!("aggregator_version={}", aggregator_version.into()))
1106 );
1107 assert!(
1108 log_inspector.contains_log(&format!("signer_version={}", signer_version.into()))
1109 );
1110 }
1111
1112 #[test]
1113 fn test_logs_warning_when_aggregator_api_version_is_newer() {
1114 let aggregator_version = "2.0.0";
1115 let signer_version = "1.0.0";
1116 let (logger, log_inspector) = TestLogger::memory();
1117 let version_provider = version_provider_with_open_api_version(signer_version);
1118 let mut client = setup_client("whatever");
1119 client.api_version_provider = Arc::new(version_provider);
1120 client.logger = logger;
1121 let response =
1122 build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1123
1124 assert!(
1125 Version::parse(aggregator_version).unwrap()
1126 > Version::parse(signer_version).unwrap()
1127 );
1128
1129 client.warn_if_api_version_mismatch(&response);
1130
1131 assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1132 }
1133
1134 #[test]
1135 fn test_no_warning_logged_when_versions_match() {
1136 let version = "1.0.0";
1137 let (logger, log_inspector) = TestLogger::memory();
1138 let version_provider = version_provider_with_open_api_version(version);
1139 let mut client = setup_client("whatever");
1140 client.api_version_provider = Arc::new(version_provider);
1141 client.logger = logger;
1142 let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);
1143
1144 client.warn_if_api_version_mismatch(&response);
1145
1146 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1147 }
1148
1149 #[test]
1150 fn test_no_warning_logged_when_aggregator_api_version_is_older() {
1151 let aggregator_version = "1.0.0";
1152 let signer_version = "2.0.0";
1153 let (logger, log_inspector) = TestLogger::memory();
1154 let version_provider = version_provider_with_open_api_version(signer_version);
1155 let mut client = setup_client("whatever");
1156 client.api_version_provider = Arc::new(version_provider);
1157 client.logger = logger;
1158 let response =
1159 build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1160
1161 assert!(
1162 Version::parse(aggregator_version).unwrap()
1163 < Version::parse(signer_version).unwrap()
1164 );
1165
1166 client.warn_if_api_version_mismatch(&response);
1167
1168 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1169 }
1170
1171 #[test]
1172 fn test_does_not_log_or_fail_when_header_is_missing() {
1173 let (logger, log_inspector) = TestLogger::memory();
1174 let mut client = setup_client("whatever");
1175 client.logger = logger;
1176 let response =
1177 build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");
1178
1179 client.warn_if_api_version_mismatch(&response);
1180
1181 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1182 }
1183
1184 #[test]
1185 fn test_does_not_log_or_fail_when_header_is_not_a_version() {
1186 let (logger, log_inspector) = TestLogger::memory();
1187 let mut client = setup_client("whatever");
1188 client.logger = logger;
1189 let response =
1190 build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");
1191
1192 client.warn_if_api_version_mismatch(&response);
1193
1194 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1195 }
1196
1197 #[test]
1198 fn test_logs_error_when_signer_version_cannot_be_computed() {
1199 let (logger, log_inspector) = TestLogger::memory();
1200 let version_provider = version_provider_without_open_api_version();
1201 let mut client = setup_client("whatever");
1202 client.api_version_provider = Arc::new(version_provider);
1203 client.logger = logger;
1204 let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");
1205
1206 client.warn_if_api_version_mismatch(&response);
1207
1208 assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1209 }
1210
1211 #[tokio::test]
1212 async fn test_aggregator_features_ok_200_log_warning_if_api_version_mismatch() {
1213 let aggregator_version = "2.0.0";
1214 let signer_version = "1.0.0";
1215 let (server, mut client) = setup_server_and_client();
1216 let (logger, log_inspector) = TestLogger::memory();
1217 let version_provider = version_provider_with_open_api_version(signer_version);
1218 client.api_version_provider = Arc::new(version_provider);
1219 client.logger = logger;
1220
1221 let message_expected = AggregatorFeaturesMessage::dummy();
1222 let _server_mock = server.mock(|when, then| {
1223 when.path("/");
1224 then.status(200)
1225 .header(MITHRIL_API_VERSION_HEADER, aggregator_version)
1226 .body(json!(message_expected).to_string());
1227 });
1228
1229 assert!(
1230 Version::parse(aggregator_version).unwrap()
1231 > Version::parse(signer_version).unwrap()
1232 );
1233
1234 client.retrieve_aggregator_features().await.unwrap();
1235
1236 assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1237 }
1238
1239 #[tokio::test]
1240 async fn test_epoch_settings_ok_200_log_warning_if_api_version_mismatch() {
1241 let aggregator_version = "2.0.0";
1242 let signer_version = "1.0.0";
1243 let (server, mut client) = setup_server_and_client();
1244 let (logger, log_inspector) = TestLogger::memory();
1245 let version_provider = version_provider_with_open_api_version(signer_version);
1246 client.api_version_provider = Arc::new(version_provider);
1247 client.logger = logger;
1248
1249 let epoch_settings_expected = EpochSettingsMessage::dummy();
1250 let _server_mock = server.mock(|when, then| {
1251 when.path("/epoch-settings");
1252 then.status(200)
1253 .header(MITHRIL_API_VERSION_HEADER, aggregator_version)
1254 .body(json!(epoch_settings_expected).to_string());
1255 });
1256
1257 assert!(
1258 Version::parse(aggregator_version).unwrap()
1259 > Version::parse(signer_version).unwrap()
1260 );
1261
1262 client.retrieve_epoch_settings().await.unwrap();
1263
1264 assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1265 }
1266
1267 #[tokio::test]
1268 async fn test_register_signer_ok_201_log_warning_if_api_version_mismatch() {
1269 let aggregator_version = "2.0.0";
1270 let signer_version = "1.0.0";
1271 let epoch = Epoch(1);
1272 let single_signers = fake_data::signers(1);
1273 let single_signer = single_signers.first().unwrap();
1274 let (server, mut client) = setup_server_and_client();
1275 let (logger, log_inspector) = TestLogger::memory();
1276 let version_provider = version_provider_with_open_api_version(signer_version);
1277 client.api_version_provider = Arc::new(version_provider);
1278 client.logger = logger;
1279 let _server_mock = server.mock(|when, then| {
1280 when.method(POST).path("/register-signer");
1281 then.status(201)
1282 .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1283 });
1284
1285 assert!(
1286 Version::parse(aggregator_version).unwrap()
1287 > Version::parse(signer_version).unwrap()
1288 );
1289
1290 client.register_signer(epoch, single_signer).await.unwrap();
1291
1292 assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1293 }
1294
1295 #[tokio::test]
1296 async fn test_register_signature_ok_201_log_warning_if_api_version_mismatch() {
1297 let aggregator_version = "2.0.0";
1298 let signer_version = "1.0.0";
1299 let single_signature = fake_data::single_signature((1..5).collect());
1300 let (server, mut client) = setup_server_and_client();
1301 let (logger, log_inspector) = TestLogger::memory();
1302 let version_provider = version_provider_with_open_api_version(signer_version);
1303 client.api_version_provider = Arc::new(version_provider);
1304 client.logger = logger;
1305 let _server_mock = server.mock(|when, then| {
1306 when.method(POST).path("/register-signatures");
1307 then.status(201)
1308 .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1309 });
1310
1311 assert!(
1312 Version::parse(aggregator_version).unwrap()
1313 > Version::parse(signer_version).unwrap()
1314 );
1315
1316 client
1317 .register_signature(
1318 &SignedEntityType::dummy(),
1319 &single_signature,
1320 &ProtocolMessage::default(),
1321 )
1322 .await
1323 .expect("Should not fail");
1324
1325 assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1326 }
1327
1328 #[tokio::test]
1329 async fn test_register_signature_ok_202_log_warning_if_api_version_mismatch() {
1330 let aggregator_version = "2.0.0";
1331 let signer_version = "1.0.0";
1332 let single_signature = fake_data::single_signature((1..5).collect());
1333 let (server, mut client) = setup_server_and_client();
1334 let (logger, log_inspector) = TestLogger::memory();
1335 let version_provider = version_provider_with_open_api_version(signer_version);
1336 client.api_version_provider = Arc::new(version_provider);
1337 client.logger = logger;
1338 let _server_mock = server.mock(|when, then| {
1339 when.method(POST).path("/register-signatures");
1340 then.status(202)
1341 .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1342 });
1343
1344 assert!(
1345 Version::parse(aggregator_version).unwrap()
1346 > Version::parse(signer_version).unwrap()
1347 );
1348
1349 client
1350 .register_signature(
1351 &SignedEntityType::dummy(),
1352 &single_signature,
1353 &ProtocolMessage::default(),
1354 )
1355 .await
1356 .unwrap();
1357
1358 assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1359 }
1360
1361 #[tokio::test]
1362 async fn test_register_signature_ok_410_log_warning_if_api_version_mismatch() {
1363 let aggregator_version = "2.0.0";
1364 let signer_version = "1.0.0";
1365 let single_signature = fake_data::single_signature((1..5).collect());
1366 let (server, mut client) = setup_server_and_client();
1367 let (logger, log_inspector) = TestLogger::memory();
1368 let version_provider = version_provider_with_open_api_version(signer_version);
1369 client.api_version_provider = Arc::new(version_provider);
1370 client.logger = logger;
1371 let _server_mock = server.mock(|when, then| {
1372 when.method(POST).path("/register-signatures");
1373 then.status(410)
1374 .body(
1375 serde_json::to_vec(&ClientError::new(
1376 "already_aggregated".to_string(),
1377 "too late".to_string(),
1378 ))
1379 .unwrap(),
1380 )
1381 .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1382 });
1383
1384 assert!(
1385 Version::parse(aggregator_version).unwrap()
1386 > Version::parse(signer_version).unwrap()
1387 );
1388
1389 client
1390 .register_signature(
1391 &SignedEntityType::dummy(),
1392 &single_signature,
1393 &ProtocolMessage::default(),
1394 )
1395 .await
1396 .unwrap();
1397
1398 assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1399 }
1400 }
1401}