mithril_signer/services/
aggregator_client.rs

1use anyhow::anyhow;
2use async_trait::async_trait;
3use reqwest::header::{self, HeaderValue};
4use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode};
5use slog::{debug, error, Logger};
6use std::{io, sync::Arc, time::Duration};
7use thiserror::Error;
8
9use mithril_common::{
10    api_version::APIVersionProvider,
11    entities::{
12        ClientError, Epoch, ProtocolMessage, ServerError, SignedEntityType, Signer, SingleSignature,
13    },
14    logging::LoggerExtensions,
15    messages::{
16        AggregatorFeaturesMessage, EpochSettingsMessage, TryFromMessageAdapter, TryToMessageAdapter,
17    },
18    StdError, MITHRIL_API_VERSION_HEADER, MITHRIL_SIGNER_VERSION_HEADER,
19};
20
21use crate::entities::SignerEpochSettings;
22use crate::message_adapters::{
23    FromEpochSettingsAdapter, ToRegisterSignatureMessageAdapter, ToRegisterSignerMessageAdapter,
24};
25
/// `Content-Type` header value identifying a JSON body (`application/json`).
const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");
27
/// Error structure for the Aggregator Client.
///
/// Variants carrying a `StdError` expose it as their error source (`#[source]`).
#[derive(Error, Debug)]
pub enum AggregatorClientError {
    /// The aggregator host has returned a technical error.
    ///
    /// Built by `from_response` for server error (5xx) statuses other than `550`.
    #[error("remote server technical error")]
    RemoteServerTechnical(#[source] StdError),

    /// The aggregator host responded it cannot fulfill our request.
    ///
    /// Built by `from_response` for client error (4xx) statuses.
    #[error("remote server logical error")]
    RemoteServerLogical(#[source] StdError),

    /// Could not reach aggregator.
    ///
    /// Wraps the error returned when sending the HTTP request fails.
    #[error("remote server unreachable")]
    RemoteServerUnreachable(#[source] StdError),

    /// Unhandled status code
    ///
    /// Built by `from_response` when the status is neither a client nor a server error;
    /// holds the status code and the response body text.
    #[error("unhandled status code: {0}, response text: {1}")]
    UnhandledStatusCode(StatusCode, String),

    /// Could not parse response.
    ///
    /// The body of a successful response could not be deserialized into the expected message.
    #[error("json parsing failed")]
    JsonParseFailed(#[source] StdError),

    /// Mostly network errors.
    ///
    /// Converted automatically from `io::Error` (`#[from]`).
    #[error("Input/Output error")]
    IOError(#[from] io::Error),

    /// Incompatible API version error
    ///
    /// Built when the aggregator answers with `412 Precondition Failed`.
    #[error("HTTP API version mismatch")]
    ApiVersionMismatch(#[source] StdError),

    /// HTTP client creation error
    #[error("HTTP client creation failed")]
    HTTPClientCreation(#[source] StdError),

    /// Proxy creation error
    ///
    /// The configured relay endpoint could not be turned into a proxy.
    #[error("proxy creation failed")]
    ProxyCreation(#[source] StdError),

    /// Adapter error
    ///
    /// A message adapter failed to convert between an entity and its API message.
    #[error("adapter failed")]
    Adapter(#[source] StdError),

    /// No signer registration round opened yet
    ///
    /// Built by `from_response` when the status code is `550`.
    #[error("a signer registration round is not opened yet, please try again later")]
    RegistrationRoundNotYetOpened(#[source] StdError),
}
75
#[cfg(test)]
/// Test-only helpers to inspect which variant an error is.
impl AggregatorClientError {
    /// Tell whether this error is the `ApiVersionMismatch` variant, whatever its cause.
    pub(crate) fn is_api_version_mismatch(&self) -> bool {
        matches!(self, Self::ApiVersionMismatch(..))
    }
}
83
84impl AggregatorClientError {
85    /// Create an `AggregatorClientError` from a response.
86    ///
87    /// This method is meant to be used after handling domain-specific cases leaving only
88    /// 4xx or 5xx status codes.
89    /// Otherwise, it will return an `UnhandledStatusCode` error.
90    pub async fn from_response(response: Response) -> Self {
91        let error_code = response.status();
92
93        if error_code.is_client_error() {
94            let root_cause = Self::get_root_cause(response).await;
95            Self::RemoteServerLogical(anyhow!(root_cause))
96        } else if error_code.is_server_error() {
97            let root_cause = Self::get_root_cause(response).await;
98            match error_code.as_u16() {
99                550 => Self::RegistrationRoundNotYetOpened(anyhow!(root_cause)),
100                _ => Self::RemoteServerTechnical(anyhow!(root_cause)),
101            }
102        } else {
103            let response_text = response.text().await.unwrap_or_default();
104            Self::UnhandledStatusCode(error_code, response_text)
105        }
106    }
107
108    async fn get_root_cause(response: Response) -> String {
109        let error_code = response.status();
110        let canonical_reason = error_code
111            .canonical_reason()
112            .unwrap_or_default()
113            .to_lowercase();
114        let is_json = response
115            .headers()
116            .get(header::CONTENT_TYPE)
117            .is_some_and(|ct| JSON_CONTENT_TYPE == ct);
118
119        if is_json {
120            let json_value: serde_json::Value = response.json().await.unwrap_or_default();
121
122            if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
123                format!(
124                    "{}: {}: {}",
125                    canonical_reason, client_error.label, client_error.message
126                )
127            } else if let Ok(server_error) =
128                serde_json::from_value::<ServerError>(json_value.clone())
129            {
130                format!("{}: {}", canonical_reason, server_error.message)
131            } else if json_value.is_null() {
132                canonical_reason.to_string()
133            } else {
134                format!("{}: {}", canonical_reason, json_value)
135            }
136        } else {
137            let response_text = response.text().await.unwrap_or_default();
138            format!("{}: {}", canonical_reason, response_text)
139        }
140    }
141}
142
/// Trait for mocking and testing a `AggregatorClient`
///
/// In test builds, `mockall::automock` generates a mock implementation of this trait.
#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait AggregatorClient: Sync + Send {
    /// Retrieves epoch settings from the aggregator
    ///
    /// Returns the settings wrapped in an `Option`, or an `AggregatorClientError` on failure.
    async fn retrieve_epoch_settings(
        &self,
    ) -> Result<Option<SignerEpochSettings>, AggregatorClientError>;

    /// Registers signer with the aggregator.
    ///
    /// `epoch` is the epoch the registration applies to and `signer` the signer to register.
    async fn register_signer(
        &self,
        epoch: Epoch,
        signer: &Signer,
    ) -> Result<(), AggregatorClientError>;

    /// Registers single signature with the aggregator.
    ///
    /// The signature is sent together with the signed entity type and the protocol message
    /// it relates to.
    async fn register_signature(
        &self,
        signed_entity_type: &SignedEntityType,
        signature: &SingleSignature,
        protocol_message: &ProtocolMessage,
    ) -> Result<(), AggregatorClientError>;

    /// Retrieves aggregator features message from the aggregator
    async fn retrieve_aggregator_features(
        &self,
    ) -> Result<AggregatorFeaturesMessage, AggregatorClientError>;
}
172
/// AggregatorHTTPClient is a http client for an aggregator
pub struct AggregatorHTTPClient {
    /// Base URL of the aggregator; route paths (e.g. `/epoch-settings`) are appended to it.
    aggregator_endpoint: String,
    /// Optional relay address; when set, every request goes through it as a proxy.
    relay_endpoint: Option<String>,
    /// Computes the API version sent in the `MITHRIL_API_VERSION_HEADER` header.
    api_version_provider: Arc<APIVersionProvider>,
    /// Optional timeout applied to each request.
    timeout_duration: Option<Duration>,
    /// Logger used for debug traces of the client.
    logger: Logger,
}
181
182impl AggregatorHTTPClient {
183    /// AggregatorHTTPClient factory
184    pub fn new(
185        aggregator_endpoint: String,
186        relay_endpoint: Option<String>,
187        api_version_provider: Arc<APIVersionProvider>,
188        timeout_duration: Option<Duration>,
189        logger: Logger,
190    ) -> Self {
191        let logger = logger.new_with_component_name::<Self>();
192        debug!(logger, "New AggregatorHTTPClient created");
193        Self {
194            aggregator_endpoint,
195            relay_endpoint,
196            api_version_provider,
197            timeout_duration,
198            logger,
199        }
200    }
201
202    fn prepare_http_client(&self) -> Result<Client, AggregatorClientError> {
203        let client = match &self.relay_endpoint {
204            Some(relay_endpoint) => Client::builder()
205                .proxy(
206                    Proxy::all(relay_endpoint)
207                        .map_err(|e| AggregatorClientError::ProxyCreation(anyhow!(e)))?,
208                )
209                .build()
210                .map_err(|e| AggregatorClientError::HTTPClientCreation(anyhow!(e)))?,
211            None => Client::new(),
212        };
213
214        Ok(client)
215    }
216
217    /// Forge a client request adding protocol version in the headers.
218    pub fn prepare_request_builder(&self, request_builder: RequestBuilder) -> RequestBuilder {
219        let request_builder = request_builder
220            .header(
221                MITHRIL_API_VERSION_HEADER,
222                self.api_version_provider
223                    .compute_current_version()
224                    .unwrap()
225                    .to_string(),
226            )
227            .header(MITHRIL_SIGNER_VERSION_HEADER, env!("CARGO_PKG_VERSION"));
228
229        if let Some(duration) = self.timeout_duration {
230            request_builder.timeout(duration)
231        } else {
232            request_builder
233        }
234    }
235
236    /// API version error handling
237    fn handle_api_error(&self, response: &Response) -> AggregatorClientError {
238        if let Some(version) = response.headers().get(MITHRIL_API_VERSION_HEADER) {
239            AggregatorClientError::ApiVersionMismatch(anyhow!(
240                "server version: '{}', signer version: '{}'",
241                version.to_str().unwrap(),
242                self.api_version_provider.compute_current_version().unwrap()
243            ))
244        } else {
245            AggregatorClientError::ApiVersionMismatch(anyhow!(
246                "version precondition failed, sent version '{}'.",
247                self.api_version_provider.compute_current_version().unwrap()
248            ))
249        }
250    }
251}
252
253#[async_trait]
254impl AggregatorClient for AggregatorHTTPClient {
255    async fn retrieve_epoch_settings(
256        &self,
257    ) -> Result<Option<SignerEpochSettings>, AggregatorClientError> {
258        debug!(self.logger, "Retrieve epoch settings");
259        let url = format!("{}/epoch-settings", self.aggregator_endpoint);
260        let response = self
261            .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
262            .send()
263            .await;
264
265        match response {
266            Ok(response) => match response.status() {
267                StatusCode::OK => match response.json::<EpochSettingsMessage>().await {
268                    Ok(message) => {
269                        let epoch_settings = FromEpochSettingsAdapter::try_adapt(message)
270                            .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
271                        Ok(Some(epoch_settings))
272                    }
273                    Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
274                },
275                StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
276                _ => Err(AggregatorClientError::from_response(response).await),
277            },
278            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
279        }
280    }
281
282    async fn register_signer(
283        &self,
284        epoch: Epoch,
285        signer: &Signer,
286    ) -> Result<(), AggregatorClientError> {
287        debug!(self.logger, "Register signer");
288        let url = format!("{}/register-signer", self.aggregator_endpoint);
289        let register_signer_message =
290            ToRegisterSignerMessageAdapter::try_adapt((epoch, signer.to_owned()))
291                .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
292        let response = self
293            .prepare_request_builder(self.prepare_http_client()?.post(url.clone()))
294            .json(&register_signer_message)
295            .send()
296            .await;
297
298        match response {
299            Ok(response) => match response.status() {
300                StatusCode::CREATED => Ok(()),
301                StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
302                _ => Err(AggregatorClientError::from_response(response).await),
303            },
304            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
305        }
306    }
307
308    async fn register_signature(
309        &self,
310        signed_entity_type: &SignedEntityType,
311        signature: &SingleSignature,
312        protocol_message: &ProtocolMessage,
313    ) -> Result<(), AggregatorClientError> {
314        debug!(self.logger, "Register signature");
315        let url = format!("{}/register-signatures", self.aggregator_endpoint);
316        let register_single_signature_message = ToRegisterSignatureMessageAdapter::try_adapt((
317            signed_entity_type.to_owned(),
318            signature.to_owned(),
319            protocol_message,
320        ))
321        .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
322        let response = self
323            .prepare_request_builder(self.prepare_http_client()?.post(url.clone()))
324            .json(&register_single_signature_message)
325            .send()
326            .await;
327
328        match response {
329            Ok(response) => match response.status() {
330                StatusCode::CREATED | StatusCode::ACCEPTED => Ok(()),
331                StatusCode::GONE => {
332                    let root_cause = AggregatorClientError::get_root_cause(response).await;
333                    debug!(self.logger, "Message already certified or expired"; "details" => &root_cause);
334
335                    Ok(())
336                }
337                StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
338                _ => Err(AggregatorClientError::from_response(response).await),
339            },
340            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
341        }
342    }
343
344    async fn retrieve_aggregator_features(
345        &self,
346    ) -> Result<AggregatorFeaturesMessage, AggregatorClientError> {
347        debug!(self.logger, "Retrieve aggregator features message");
348        let url = format!("{}/", self.aggregator_endpoint);
349        let response = self
350            .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
351            .send()
352            .await;
353
354        match response {
355            Ok(response) => match response.status() {
356                StatusCode::OK => Ok(response
357                    .json::<AggregatorFeaturesMessage>()
358                    .await
359                    .map_err(|e| AggregatorClientError::JsonParseFailed(anyhow!(e)))?),
360                StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
361                _ => Err(AggregatorClientError::from_response(response).await),
362            },
363            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
364        }
365    }
366}
367
#[cfg(test)]
pub(crate) mod dumb {
    use tokio::sync::RwLock;

    use super::*;

    /// In-memory stand-in for an aggregator client, meant for test services.
    ///
    /// It never talks to an aggregator host: it serves the data it holds and records the
    /// last signer registered through it so tests can inspect it.
    pub struct DumbAggregatorClient {
        epoch_settings: RwLock<Option<SignerEpochSettings>>,
        last_registered_signer: RwLock<Option<Signer>>,
        aggregator_features: RwLock<AggregatorFeaturesMessage>,
    }

    impl DumbAggregatorClient {
        /// Return a copy of the signer given to the last `register_signer` call, if any.
        pub async fn get_last_registered_signer(&self) -> Option<Signer> {
            let last_registered_signer = self.last_registered_signer.read().await;

            last_registered_signer.clone()
        }

        /// Replace the features message served by `retrieve_aggregator_features`.
        pub async fn set_aggregator_features(
            &self,
            aggregator_features: AggregatorFeaturesMessage,
        ) {
            *self.aggregator_features.write().await = aggregator_features;
        }
    }

    impl Default for DumbAggregatorClient {
        /// Start with dummy epoch settings, dummy features and no registered signer.
        fn default() -> Self {
            let epoch_settings = Some(SignerEpochSettings::dummy());
            let aggregator_features = AggregatorFeaturesMessage::dummy();

            Self {
                epoch_settings: RwLock::new(epoch_settings),
                last_registered_signer: RwLock::new(None),
                aggregator_features: RwLock::new(aggregator_features),
            }
        }
    }

    #[async_trait]
    impl AggregatorClient for DumbAggregatorClient {
        /// Serve a copy of the stored epoch settings.
        async fn retrieve_epoch_settings(
            &self,
        ) -> Result<Option<SignerEpochSettings>, AggregatorClientError> {
            Ok(self.epoch_settings.read().await.clone())
        }

        /// Remember the given signer as the last registered one; the epoch is ignored.
        async fn register_signer(
            &self,
            _epoch: Epoch,
            signer: &Signer,
        ) -> Result<(), AggregatorClientError> {
            let mut slot = self.last_registered_signer.write().await;
            *slot = Some(signer.clone());

            Ok(())
        }

        /// Accept any signature without recording it.
        async fn register_signature(
            &self,
            _signed_entity_type: &SignedEntityType,
            _signature: &SingleSignature,
            _protocol_message: &ProtocolMessage,
        ) -> Result<(), AggregatorClientError> {
            Ok(())
        }

        /// Serve a copy of the stored aggregator features message.
        async fn retrieve_aggregator_features(
            &self,
        ) -> Result<AggregatorFeaturesMessage, AggregatorClientError> {
            Ok(self.aggregator_features.read().await.clone())
        }
    }
}
449
450#[cfg(test)]
451mod tests {
452    use http::response::Builder as HttpResponseBuilder;
453    use httpmock::prelude::*;
454    use serde_json::json;
455
456    use mithril_common::entities::Epoch;
457    use mithril_common::era::{EraChecker, SupportedEra};
458    use mithril_common::messages::TryFromMessageAdapter;
459    use mithril_common::test_utils::fake_data;
460
461    use crate::test_tools::TestLogger;
462
463    use super::*;
464
    /// Assert that `$error` matches the pattern `$error_type`.
    ///
    /// On failure the panic message shows the expected pattern (stringified) and the
    /// actual value formatted with `Debug`.
    macro_rules! assert_is_error {
        ($error:expr, $error_type:pat) => {
            assert!(
                matches!($error, $error_type),
                "Expected {} error, got '{:?}'.",
                stringify!($error_type),
                $error
            );
        };
    }
475
476    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
477        let server = MockServer::start();
478        let aggregator_endpoint = server.url("");
479        let relay_endpoint = None;
480        let era_checker = EraChecker::new(SupportedEra::dummy(), Epoch(1));
481        let api_version_provider = APIVersionProvider::new(Arc::new(era_checker));
482
483        (
484            server,
485            AggregatorHTTPClient::new(
486                aggregator_endpoint,
487                relay_endpoint,
488                Arc::new(api_version_provider),
489                None,
490                TestLogger::stdout(),
491            ),
492        )
493    }
494
495    fn set_returning_412(server: &MockServer) {
496        server.mock(|_, then| {
497            then.status(412)
498                .header(MITHRIL_API_VERSION_HEADER, "0.0.999");
499        });
500    }
501
502    fn set_returning_500(server: &MockServer) {
503        server.mock(|_, then| {
504            then.status(500).body("an error occurred");
505        });
506    }
507
508    fn set_unparsable_json(server: &MockServer) {
509        server.mock(|_, then| {
510            then.status(200).body("this is not a json");
511        });
512    }
513
514    fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
515        HttpResponseBuilder::new()
516            .status(status_code)
517            .body(body.into())
518            .unwrap()
519            .into()
520    }
521
522    fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
523        HttpResponseBuilder::new()
524            .status(status_code)
525            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
526            .body(serde_json::to_string(&body).unwrap())
527            .unwrap()
528            .into()
529    }
530
    /// Assert that `$error` contains `$expect_contains`, using the `contains` method of
    /// `$error`.
    ///
    /// On failure the panic message shows the expected fragment and the actual value
    /// formatted with `Debug`.
    macro_rules! assert_error_text_contains {
        ($error: expr, $expect_contains: expr) => {
            // Borrow once so `$error` is evaluated a single time.
            let error = &$error;
            assert!(
                error.contains($expect_contains),
                "Expected error message to contain '{}'\ngot '{error:?}'",
                $expect_contains,
            );
        };
    }
541
542    #[tokio::test]
543    async fn test_aggregator_features_ok_200() {
544        let (server, client) = setup_server_and_client();
545        let message_expected = AggregatorFeaturesMessage::dummy();
546        let _server_mock = server.mock(|when, then| {
547            when.path("/");
548            then.status(200).body(json!(message_expected).to_string());
549        });
550
551        let message = client.retrieve_aggregator_features().await.unwrap();
552
553        assert_eq!(message_expected, message);
554    }
555
556    #[tokio::test]
557    async fn test_aggregator_features_ko_412() {
558        let (server, client) = setup_server_and_client();
559        set_returning_412(&server);
560
561        let error = client.retrieve_aggregator_features().await.unwrap_err();
562
563        assert_is_error!(error, AggregatorClientError::ApiVersionMismatch(_));
564    }
565
566    #[tokio::test]
567    async fn test_aggregator_features_ko_500() {
568        let (server, client) = setup_server_and_client();
569        set_returning_500(&server);
570
571        let error = client.retrieve_aggregator_features().await.unwrap_err();
572
573        assert_is_error!(error, AggregatorClientError::RemoteServerTechnical(_));
574    }
575
576    #[tokio::test]
577    async fn test_aggregator_features_ko_json_serialization() {
578        let (server, client) = setup_server_and_client();
579        set_unparsable_json(&server);
580
581        let error = client.retrieve_aggregator_features().await.unwrap_err();
582
583        assert_is_error!(error, AggregatorClientError::JsonParseFailed(_));
584    }
585
586    #[tokio::test]
587    async fn test_aggregator_features_timeout() {
588        let (server, mut client) = setup_server_and_client();
589        client.timeout_duration = Some(Duration::from_millis(10));
590        let _server_mock = server.mock(|when, then| {
591            when.path("/");
592            then.delay(Duration::from_millis(100));
593        });
594
595        let error = client.retrieve_aggregator_features().await.unwrap_err();
596
597        assert_is_error!(error, AggregatorClientError::RemoteServerUnreachable(_));
598    }
599
600    #[tokio::test]
601    async fn test_epoch_settings_ok_200() {
602        let (server, client) = setup_server_and_client();
603        let epoch_settings_expected = EpochSettingsMessage::dummy();
604        let _server_mock = server.mock(|when, then| {
605            when.path("/epoch-settings");
606            then.status(200)
607                .body(json!(epoch_settings_expected).to_string());
608        });
609
610        let epoch_settings = client.retrieve_epoch_settings().await;
611        epoch_settings.as_ref().expect("unexpected error");
612        assert_eq!(
613            FromEpochSettingsAdapter::try_adapt(epoch_settings_expected).unwrap(),
614            epoch_settings.unwrap().unwrap()
615        );
616    }
617
618    #[tokio::test]
619    async fn test_epoch_settings_ko_412() {
620        let (server, client) = setup_server_and_client();
621        let _server_mock = server.mock(|when, then| {
622            when.path("/epoch-settings");
623            then.status(412)
624                .header(MITHRIL_API_VERSION_HEADER, "0.0.999");
625        });
626
627        let epoch_settings = client.retrieve_epoch_settings().await.unwrap_err();
628
629        assert!(epoch_settings.is_api_version_mismatch());
630    }
631
632    #[tokio::test]
633    async fn test_epoch_settings_ko_500() {
634        let (server, client) = setup_server_and_client();
635        let _server_mock = server.mock(|when, then| {
636            when.path("/epoch-settings");
637            then.status(500).body("an error occurred");
638        });
639
640        match client.retrieve_epoch_settings().await.unwrap_err() {
641            AggregatorClientError::RemoteServerTechnical(_) => (),
642            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
643        };
644    }
645
646    #[tokio::test]
647    async fn test_epoch_settings_timeout() {
648        let (server, mut client) = setup_server_and_client();
649        client.timeout_duration = Some(Duration::from_millis(10));
650        let _server_mock = server.mock(|when, then| {
651            when.path("/epoch-settings");
652            then.delay(Duration::from_millis(100));
653        });
654
655        let error = client
656            .retrieve_epoch_settings()
657            .await
658            .expect_err("retrieve_epoch_settings should fail");
659
660        assert!(
661            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
662            "unexpected error type: {error:?}"
663        );
664    }
665
666    #[tokio::test]
667    async fn test_register_signer_ok_201() {
668        let epoch = Epoch(1);
669        let single_signers = fake_data::signers(1);
670        let single_signer = single_signers.first().unwrap();
671        let (server, client) = setup_server_and_client();
672        let _server_mock = server.mock(|when, then| {
673            when.method(POST).path("/register-signer");
674            then.status(201);
675        });
676
677        let register_signer = client.register_signer(epoch, single_signer).await;
678        register_signer.expect("unexpected error");
679    }
680
681    #[tokio::test]
682    async fn test_register_signer_ko_412() {
683        let epoch = Epoch(1);
684        let (server, client) = setup_server_and_client();
685        let _server_mock = server.mock(|when, then| {
686            when.method(POST).path("/register-signer");
687            then.status(412)
688                .header(MITHRIL_API_VERSION_HEADER, "0.0.999");
689        });
690        let single_signers = fake_data::signers(1);
691        let single_signer = single_signers.first().unwrap();
692
693        let error = client
694            .register_signer(epoch, single_signer)
695            .await
696            .unwrap_err();
697
698        assert!(error.is_api_version_mismatch());
699    }
700
701    #[tokio::test]
702    async fn test_register_signer_ko_400() {
703        let epoch = Epoch(1);
704        let single_signers = fake_data::signers(1);
705        let single_signer = single_signers.first().unwrap();
706        let (server, client) = setup_server_and_client();
707        let _server_mock = server.mock(|when, then| {
708            when.method(POST).path("/register-signer");
709            then.status(400).body(
710                serde_json::to_vec(&ClientError::new(
711                    "error".to_string(),
712                    "an error".to_string(),
713                ))
714                .unwrap(),
715            );
716        });
717
718        match client
719            .register_signer(epoch, single_signer)
720            .await
721            .unwrap_err()
722        {
723            AggregatorClientError::RemoteServerLogical(_) => (),
724            err => {
725                panic!(
726                    "Expected a AggregatorClientError::RemoteServerLogical error, got '{err:?}'."
727                )
728            }
729        };
730    }
731
732    #[tokio::test]
733    async fn test_register_signer_ko_500() {
734        let epoch = Epoch(1);
735        let single_signers = fake_data::signers(1);
736        let single_signer = single_signers.first().unwrap();
737        let (server, client) = setup_server_and_client();
738        let _server_mock = server.mock(|when, then| {
739            when.method(POST).path("/register-signer");
740            then.status(500).body("an error occurred");
741        });
742
743        match client
744            .register_signer(epoch, single_signer)
745            .await
746            .unwrap_err()
747        {
748            AggregatorClientError::RemoteServerTechnical(_) => (),
749            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
750        };
751    }
752
753    #[tokio::test]
754    async fn test_register_signer_timeout() {
755        let epoch = Epoch(1);
756        let single_signers = fake_data::signers(1);
757        let single_signer = single_signers.first().unwrap();
758        let (server, mut client) = setup_server_and_client();
759        client.timeout_duration = Some(Duration::from_millis(10));
760        let _server_mock = server.mock(|when, then| {
761            when.method(POST).path("/register-signer");
762            then.delay(Duration::from_millis(100));
763        });
764
765        let error = client
766            .register_signer(epoch, single_signer)
767            .await
768            .expect_err("register_signer should fail");
769
770        assert!(
771            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
772            "unexpected error type: {error:?}"
773        );
774    }
775
776    #[tokio::test]
777    async fn test_register_signature_ok_201() {
778        let single_signature = fake_data::single_signature((1..5).collect());
779        let (server, client) = setup_server_and_client();
780        let _server_mock = server.mock(|when, then| {
781            when.method(POST).path("/register-signatures");
782            then.status(201);
783        });
784
785        let register_signature = client
786            .register_signature(
787                &SignedEntityType::dummy(),
788                &single_signature,
789                &ProtocolMessage::default(),
790            )
791            .await;
792        register_signature.expect("unexpected error");
793    }
794
795    #[tokio::test]
796    async fn test_register_signature_ok_202() {
797        let single_signature = fake_data::single_signature((1..5).collect());
798        let (server, client) = setup_server_and_client();
799        let _server_mock = server.mock(|when, then| {
800            when.method(POST).path("/register-signatures");
801            then.status(202);
802        });
803
804        let register_signature = client
805            .register_signature(
806                &SignedEntityType::dummy(),
807                &single_signature,
808                &ProtocolMessage::default(),
809            )
810            .await;
811        register_signature.expect("unexpected error");
812    }
813
814    #[tokio::test]
815    async fn test_register_signature_ko_412() {
816        let (server, client) = setup_server_and_client();
817        let _server_mock = server.mock(|when, then| {
818            when.method(POST).path("/register-signatures");
819            then.status(412)
820                .header(MITHRIL_API_VERSION_HEADER, "0.0.999");
821        });
822        let single_signature = fake_data::single_signature((1..5).collect());
823
824        let error = client
825            .register_signature(
826                &SignedEntityType::dummy(),
827                &single_signature,
828                &ProtocolMessage::default(),
829            )
830            .await
831            .unwrap_err();
832
833        assert!(error.is_api_version_mismatch());
834    }
835
836    #[tokio::test]
837    async fn test_register_signature_ko_400() {
838        let single_signature = fake_data::single_signature((1..5).collect());
839        let (server, client) = setup_server_and_client();
840        let _server_mock = server.mock(|when, then| {
841            when.method(POST).path("/register-signatures");
842            then.status(400).body(
843                serde_json::to_vec(&ClientError::new(
844                    "error".to_string(),
845                    "an error".to_string(),
846                ))
847                .unwrap(),
848            );
849        });
850
851        match client
852            .register_signature(
853                &SignedEntityType::dummy(),
854                &single_signature,
855                &ProtocolMessage::default(),
856            )
857            .await
858            .unwrap_err()
859        {
860            AggregatorClientError::RemoteServerLogical(_) => (),
861            e => panic!("Expected Aggregator::RemoteServerLogical error, got '{e:?}'."),
862        };
863    }
864
865    #[tokio::test]
866    async fn test_register_signature_ok_410_log_response_body() {
867        let (logger, log_inspector) = TestLogger::memory();
868
869        let single_signature = fake_data::single_signature((1..5).collect());
870        let (server, mut client) = setup_server_and_client();
871        client.logger = logger;
872        let _server_mock = server.mock(|when, then| {
873            when.method(POST).path("/register-signatures");
874            then.status(410).body(
875                serde_json::to_vec(&ClientError::new(
876                    "already_aggregated".to_string(),
877                    "too late".to_string(),
878                ))
879                .unwrap(),
880            );
881        });
882
883        client
884            .register_signature(
885                &SignedEntityType::dummy(),
886                &single_signature,
887                &ProtocolMessage::default(),
888            )
889            .await
890            .expect("Should not fail when status is 410 (GONE)");
891
892        assert!(log_inspector.contains_log("already_aggregated"));
893        assert!(log_inspector.contains_log("too late"));
894    }
895
896    #[tokio::test]
897    async fn test_register_signature_ko_409() {
898        let single_signature = fake_data::single_signature((1..5).collect());
899        let (server, client) = setup_server_and_client();
900        let _server_mock = server.mock(|when, then| {
901            when.method(POST).path("/register-signatures");
902            then.status(409);
903        });
904
905        match client
906            .register_signature(
907                &SignedEntityType::dummy(),
908                &single_signature,
909                &ProtocolMessage::default(),
910            )
911            .await
912            .unwrap_err()
913        {
914            AggregatorClientError::RemoteServerLogical(_) => (),
915            e => panic!("Expected Aggregator::RemoteServerLogical error, got '{e:?}'."),
916        }
917    }
918
919    #[tokio::test]
920    async fn test_register_signature_ko_500() {
921        let single_signature = fake_data::single_signature((1..5).collect());
922        let (server, client) = setup_server_and_client();
923        let _server_mock = server.mock(|when, then| {
924            when.method(POST).path("/register-signatures");
925            then.status(500).body("an error occurred");
926        });
927
928        match client
929            .register_signature(
930                &SignedEntityType::dummy(),
931                &single_signature,
932                &ProtocolMessage::default(),
933            )
934            .await
935            .unwrap_err()
936        {
937            AggregatorClientError::RemoteServerTechnical(_) => (),
938            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
939        };
940    }
941
942    #[tokio::test]
943    async fn test_register_signature_timeout() {
944        let single_signature = fake_data::single_signature((1..5).collect());
945        let (server, mut client) = setup_server_and_client();
946        client.timeout_duration = Some(Duration::from_millis(10));
947        let _server_mock = server.mock(|when, then| {
948            when.method(POST).path("/register-signatures");
949            then.delay(Duration::from_millis(100));
950        });
951
952        let error = client
953            .register_signature(
954                &SignedEntityType::dummy(),
955                &single_signature,
956                &ProtocolMessage::default(),
957            )
958            .await
959            .expect_err("register_signature should fail");
960
961        assert!(
962            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
963            "unexpected error type: {error:?}"
964        );
965    }
966
967    #[tokio::test]
968    async fn test_4xx_errors_are_handled_as_remote_server_logical() {
969        let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
970        let handled_error = AggregatorClientError::from_response(response).await;
971
972        assert!(
973            matches!(
974                handled_error,
975                AggregatorClientError::RemoteServerLogical(..)
976            ),
977            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
978        );
979    }
980
981    #[tokio::test]
982    async fn test_5xx_errors_are_handled_as_remote_server_technical() {
983        let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
984        let handled_error = AggregatorClientError::from_response(response).await;
985
986        assert!(
987            matches!(
988                handled_error,
989                AggregatorClientError::RemoteServerTechnical(..)
990            ),
991            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
992        );
993    }
994
995    #[tokio::test]
996    async fn test_550_error_is_handled_as_registration_round_not_yet_opened() {
997        let response = build_text_response(StatusCode::from_u16(550).unwrap(), "Not yet available");
998        let handled_error = AggregatorClientError::from_response(response).await;
999
1000        assert!(
1001            matches!(
1002                handled_error,
1003                AggregatorClientError::RegistrationRoundNotYetOpened(..)
1004            ),
1005            "Expected error to be RegistrationRoundNotYetOpened\ngot '{handled_error:?}'",
1006        );
1007    }
1008
1009    #[tokio::test]
1010    async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text(
1011    ) {
1012        let response = build_text_response(StatusCode::OK, "ok text");
1013        let handled_error = AggregatorClientError::from_response(response).await;
1014
1015        assert!(
1016            matches!(
1017                handled_error,
1018                AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
1019            ),
1020            "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
1021        );
1022    }
1023
1024    #[tokio::test]
1025    async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
1026        let error_text = "An error occurred; please try again later.";
1027        let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);
1028
1029        assert_error_text_contains!(
1030            AggregatorClientError::get_root_cause(response).await,
1031            "expectation failed: An error occurred; please try again later."
1032        );
1033    }
1034
1035    #[tokio::test]
1036    async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message(
1037    ) {
1038        let client_error = ClientError::new("label", "message");
1039        let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);
1040
1041        assert_error_text_contains!(
1042            AggregatorClientError::get_root_cause(response).await,
1043            "bad request: label: message"
1044        );
1045    }
1046
1047    #[tokio::test]
1048    async fn test_root_cause_of_json_formatted_server_error_response_contains_error_label_and_message(
1049    ) {
1050        let server_error = ServerError::new("message");
1051        let response = build_json_response(StatusCode::BAD_REQUEST, &server_error);
1052
1053        assert_error_text_contains!(
1054            AggregatorClientError::get_root_cause(response).await,
1055            "bad request: message"
1056        );
1057    }
1058
1059    #[tokio::test]
1060    async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
1061        let response = build_json_response(
1062            StatusCode::INTERNAL_SERVER_ERROR,
1063            &json!({ "second": "unknown", "first": "foreign" }),
1064        );
1065
1066        assert_error_text_contains!(
1067            AggregatorClientError::get_root_cause(response).await,
1068            r#"internal server error: {"first":"foreign","second":"unknown"}"#
1069        );
1070    }
1071
1072    #[tokio::test]
1073    async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
1074        let response = HttpResponseBuilder::new()
1075            .status(StatusCode::BAD_REQUEST)
1076            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
1077            .body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
1078            .unwrap()
1079            .into();
1080
1081        let root_cause = AggregatorClientError::get_root_cause(response).await;
1082
1083        assert_error_text_contains!(root_cause, "bad request");
1084        assert!(
1085            !root_cause.contains("bad request: "),
1086            "Expected error message should not contain additional information \ngot '{root_cause:?}'"
1087        );
1088    }
1089
1090    #[tokio::test]
1091    async fn test_sends_accept_encoding_header() {
1092        let (server, client) = setup_server_and_client();
1093        server.mock(|when, then| {
1094            when.matches(|req| {
1095                let headers = req.headers.clone().expect("HTTP headers not found");
1096                let accept_encoding_header = headers
1097                    .iter()
1098                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
1099                    .expect("Accept-Encoding header not found");
1100
1101                let header_value = accept_encoding_header.clone().1;
1102                ["gzip", "br", "deflate", "zstd"]
1103                    .iter()
1104                    .all(|&value| header_value.contains(value))
1105            });
1106
1107            then.status(201);
1108        });
1109
1110        client
1111            .register_signature(
1112                &SignedEntityType::dummy(),
1113                &fake_data::single_signature((1..5).collect()),
1114                &ProtocolMessage::default(),
1115            )
1116            .await
1117            .expect("Should succeed with Accept-Encoding header");
1118    }
1119}