mithril_signer/services/aggregator_client.rs

1use anyhow::anyhow;
2use async_trait::async_trait;
3use reqwest::header::{self, HeaderValue};
4use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode};
5use slog::{debug, error, Logger};
6use std::{io, sync::Arc, time::Duration};
7use thiserror::Error;
8
9use mithril_common::{
10    api_version::APIVersionProvider,
11    entities::{
12        ClientError, Epoch, ProtocolMessage, ServerError, SignedEntityType, Signer,
13        SingleSignatures,
14    },
15    logging::LoggerExtensions,
16    messages::{
17        AggregatorFeaturesMessage, EpochSettingsMessage, TryFromMessageAdapter, TryToMessageAdapter,
18    },
19    StdError, StdResult, MITHRIL_API_VERSION_HEADER, MITHRIL_SIGNER_VERSION_HEADER,
20};
21
22use crate::entities::SignerEpochSettings;
23use crate::message_adapters::{
24    FromEpochSettingsAdapter, ToRegisterSignatureMessageAdapter, ToRegisterSignerMessageAdapter,
25};
26use crate::services::SignaturePublisher;
27
/// `Content-Type` value identifying a JSON body.
///
/// Used when building an error's root cause to decide whether the response body should be
/// parsed as JSON rather than read as plain text.
const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");
29
/// Error structure for the Aggregator Client.
///
/// Error HTTP responses are turned into the `RemoteServer*`, `RegistrationRoundNotYetOpened`
/// and `UnhandledStatusCode` variants by [`AggregatorClientError::from_response`].
#[derive(Error, Debug)]
pub enum AggregatorClientError {
    /// The aggregator host has returned a technical error.
    ///
    /// Built from 5xx responses, except 550 (see `RegistrationRoundNotYetOpened`).
    #[error("remote server technical error")]
    RemoteServerTechnical(#[source] StdError),

    /// The aggregator host responded it cannot fulfill our request.
    ///
    /// Built from 4xx responses, except 412 which the HTTP client reports as
    /// `ApiVersionMismatch`.
    #[error("remote server logical error")]
    RemoteServerLogical(#[source] StdError),

    /// Could not reach aggregator.
    ///
    /// Built when sending the request fails, which also covers request timeouts.
    #[error("remote server unreachable")]
    RemoteServerUnreachable(#[source] StdError),

    /// Unhandled status code
    ///
    /// Carries the status code (neither 4xx nor 5xx) and the raw response text.
    #[error("unhandled status code: {0}, response text: {1}")]
    UnhandledStatusCode(StatusCode, String),

    /// Could not parse response.
    ///
    /// The body of a successful response could not be deserialized into the expected message.
    #[error("json parsing failed")]
    JsonParseFailed(#[source] StdError),

    /// Mostly network errors.
    ///
    /// NOTE(review): not constructed explicitly in this file; reachable through `?` on an
    /// `io::Error` thanks to `#[from]`.
    #[error("Input/Output error")]
    IOError(#[from] io::Error),

    /// Incompatible API version error
    ///
    /// Built when the aggregator answers `412 Precondition Failed`.
    #[error("HTTP API version mismatch")]
    ApiVersionMismatch(#[source] StdError),

    /// HTTP client creation error
    #[error("HTTP client creation failed")]
    HTTPClientCreation(#[source] StdError),

    /// Proxy creation error
    ///
    /// The configured relay endpoint could not be turned into a proxy.
    #[error("proxy creation failed")]
    ProxyCreation(#[source] StdError),

    /// Adapter error
    ///
    /// Conversion between signer entities and API messages failed.
    #[error("adapter failed")]
    Adapter(#[source] StdError),

    /// No signer registration round opened yet
    ///
    /// Built from a response with the custom `550` status code.
    #[error("a signer registration round is not opened yet, please try again later")]
    RegistrationRoundNotYetOpened(#[source] StdError),
}
77
/// Test-only helpers to inspect an [`AggregatorClientError`].
#[cfg(test)]
impl AggregatorClientError {
    /// Tell whether this error is an [`AggregatorClientError::ApiVersionMismatch`].
    pub(crate) fn is_api_version_mismatch(&self) -> bool {
        matches!(*self, AggregatorClientError::ApiVersionMismatch(..))
    }
}
85
86impl AggregatorClientError {
87    /// Create an `AggregatorClientError` from a response.
88    ///
89    /// This method is meant to be used after handling domain-specific cases leaving only
90    /// 4xx or 5xx status codes.
91    /// Otherwise, it will return an `UnhandledStatusCode` error.
92    pub async fn from_response(response: Response) -> Self {
93        let error_code = response.status();
94
95        if error_code.is_client_error() {
96            let root_cause = Self::get_root_cause(response).await;
97            Self::RemoteServerLogical(anyhow!(root_cause))
98        } else if error_code.is_server_error() {
99            let root_cause = Self::get_root_cause(response).await;
100            match error_code.as_u16() {
101                550 => Self::RegistrationRoundNotYetOpened(anyhow!(root_cause)),
102                _ => Self::RemoteServerTechnical(anyhow!(root_cause)),
103            }
104        } else {
105            let response_text = response.text().await.unwrap_or_default();
106            Self::UnhandledStatusCode(error_code, response_text)
107        }
108    }
109
110    async fn get_root_cause(response: Response) -> String {
111        let error_code = response.status();
112        let canonical_reason = error_code
113            .canonical_reason()
114            .unwrap_or_default()
115            .to_lowercase();
116        let is_json = response
117            .headers()
118            .get(header::CONTENT_TYPE)
119            .is_some_and(|ct| JSON_CONTENT_TYPE == ct);
120
121        if is_json {
122            let json_value: serde_json::Value = response.json().await.unwrap_or_default();
123
124            if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
125                format!(
126                    "{}: {}: {}",
127                    canonical_reason, client_error.label, client_error.message
128                )
129            } else if let Ok(server_error) =
130                serde_json::from_value::<ServerError>(json_value.clone())
131            {
132                format!("{}: {}", canonical_reason, server_error.message)
133            } else if json_value.is_null() {
134                canonical_reason.to_string()
135            } else {
136                format!("{}: {}", canonical_reason, json_value)
137            }
138        } else {
139            let response_text = response.text().await.unwrap_or_default();
140            format!("{}: {}", canonical_reason, response_text)
141        }
142    }
143}
144
/// Client used by the signer to talk to a Mithril aggregator.
///
/// [`AggregatorHTTPClient`] implements it over HTTP. In test builds, `mockall::automock`
/// also generates a `MockAggregatorClient` from this trait, which is what makes it convenient
/// for mocking and testing. Every implementor is also a [`SignaturePublisher`] through the
/// blanket implementation in this module.
#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait AggregatorClient: Sync + Send {
    /// Retrieves epoch settings from the aggregator
    ///
    /// NOTE(review): the `Option` lets implementors report "no settings". The HTTP
    /// implementation in this module returns `Some` on every successful call.
    async fn retrieve_epoch_settings(
        &self,
    ) -> Result<Option<SignerEpochSettings>, AggregatorClientError>;

    /// Registers signer with the aggregator for the given `epoch`.
    async fn register_signer(
        &self,
        epoch: Epoch,
        signer: &Signer,
    ) -> Result<(), AggregatorClientError>;

    /// Registers single signatures with the aggregator.
    ///
    /// The HTTP implementation sends `signed_entity_type` and `protocol_message` along with
    /// the `signatures`. It also treats a `410 Gone` answer (message already certified or
    /// expired) as a success.
    async fn register_signatures(
        &self,
        signed_entity_type: &SignedEntityType,
        signatures: &SingleSignatures,
        protocol_message: &ProtocolMessage,
    ) -> Result<(), AggregatorClientError>;

    /// Retrieves aggregator features message from the aggregator
    async fn retrieve_aggregator_features(
        &self,
    ) -> Result<AggregatorFeaturesMessage, AggregatorClientError>;
}
174
175#[async_trait]
176impl<T: AggregatorClient> SignaturePublisher for T {
177    async fn publish(
178        &self,
179        signed_entity_type: &SignedEntityType,
180        signatures: &SingleSignatures,
181        protocol_message: &ProtocolMessage,
182    ) -> StdResult<()> {
183        self.register_signatures(signed_entity_type, signatures, protocol_message)
184            .await?;
185        Ok(())
186    }
187}
188
/// AggregatorHTTPClient is a http client for an aggregator
pub struct AggregatorHTTPClient {
    /// Base URL of the aggregator; route paths such as `/epoch-settings` are appended to it.
    aggregator_endpoint: String,
    /// Optional relay URL; when set, every request is sent through it as a proxy.
    relay_endpoint: Option<String>,
    /// Computes the API version sent in the `MITHRIL_API_VERSION_HEADER` request header.
    api_version_provider: Arc<APIVersionProvider>,
    /// Optional timeout applied to each request.
    timeout_duration: Option<Duration>,
    logger: Logger,
}
197
198impl AggregatorHTTPClient {
199    /// AggregatorHTTPClient factory
200    pub fn new(
201        aggregator_endpoint: String,
202        relay_endpoint: Option<String>,
203        api_version_provider: Arc<APIVersionProvider>,
204        timeout_duration: Option<Duration>,
205        logger: Logger,
206    ) -> Self {
207        let logger = logger.new_with_component_name::<Self>();
208        debug!(logger, "New AggregatorHTTPClient created");
209        Self {
210            aggregator_endpoint,
211            relay_endpoint,
212            api_version_provider,
213            timeout_duration,
214            logger,
215        }
216    }
217
218    fn prepare_http_client(&self) -> Result<Client, AggregatorClientError> {
219        let client = match &self.relay_endpoint {
220            Some(relay_endpoint) => Client::builder()
221                .proxy(
222                    Proxy::all(relay_endpoint)
223                        .map_err(|e| AggregatorClientError::ProxyCreation(anyhow!(e)))?,
224                )
225                .build()
226                .map_err(|e| AggregatorClientError::HTTPClientCreation(anyhow!(e)))?,
227            None => Client::new(),
228        };
229
230        Ok(client)
231    }
232
233    /// Forge a client request adding protocol version in the headers.
234    pub fn prepare_request_builder(&self, request_builder: RequestBuilder) -> RequestBuilder {
235        let request_builder = request_builder
236            .header(
237                MITHRIL_API_VERSION_HEADER,
238                self.api_version_provider
239                    .compute_current_version()
240                    .unwrap()
241                    .to_string(),
242            )
243            .header(MITHRIL_SIGNER_VERSION_HEADER, env!("CARGO_PKG_VERSION"));
244
245        if let Some(duration) = self.timeout_duration {
246            request_builder.timeout(duration)
247        } else {
248            request_builder
249        }
250    }
251
252    /// API version error handling
253    fn handle_api_error(&self, response: &Response) -> AggregatorClientError {
254        if let Some(version) = response.headers().get(MITHRIL_API_VERSION_HEADER) {
255            AggregatorClientError::ApiVersionMismatch(anyhow!(
256                "server version: '{}', signer version: '{}'",
257                version.to_str().unwrap(),
258                self.api_version_provider.compute_current_version().unwrap()
259            ))
260        } else {
261            AggregatorClientError::ApiVersionMismatch(anyhow!(
262                "version precondition failed, sent version '{}'.",
263                self.api_version_provider.compute_current_version().unwrap()
264            ))
265        }
266    }
267}
268
269#[async_trait]
270impl AggregatorClient for AggregatorHTTPClient {
271    async fn retrieve_epoch_settings(
272        &self,
273    ) -> Result<Option<SignerEpochSettings>, AggregatorClientError> {
274        debug!(self.logger, "Retrieve epoch settings");
275        let url = format!("{}/epoch-settings", self.aggregator_endpoint);
276        let response = self
277            .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
278            .send()
279            .await;
280
281        match response {
282            Ok(response) => match response.status() {
283                StatusCode::OK => match response.json::<EpochSettingsMessage>().await {
284                    Ok(message) => {
285                        let epoch_settings = FromEpochSettingsAdapter::try_adapt(message)
286                            .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
287                        Ok(Some(epoch_settings))
288                    }
289                    Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
290                },
291                StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
292                _ => Err(AggregatorClientError::from_response(response).await),
293            },
294            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
295        }
296    }
297
298    async fn register_signer(
299        &self,
300        epoch: Epoch,
301        signer: &Signer,
302    ) -> Result<(), AggregatorClientError> {
303        debug!(self.logger, "Register signer");
304        let url = format!("{}/register-signer", self.aggregator_endpoint);
305        let register_signer_message =
306            ToRegisterSignerMessageAdapter::try_adapt((epoch, signer.to_owned()))
307                .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
308        let response = self
309            .prepare_request_builder(self.prepare_http_client()?.post(url.clone()))
310            .json(&register_signer_message)
311            .send()
312            .await;
313
314        match response {
315            Ok(response) => match response.status() {
316                StatusCode::CREATED => Ok(()),
317                StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
318                _ => Err(AggregatorClientError::from_response(response).await),
319            },
320            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
321        }
322    }
323
324    async fn register_signatures(
325        &self,
326        signed_entity_type: &SignedEntityType,
327        signatures: &SingleSignatures,
328        protocol_message: &ProtocolMessage,
329    ) -> Result<(), AggregatorClientError> {
330        debug!(self.logger, "Register signatures");
331        let url = format!("{}/register-signatures", self.aggregator_endpoint);
332        let register_single_signature_message = ToRegisterSignatureMessageAdapter::try_adapt((
333            signed_entity_type.to_owned(),
334            signatures.to_owned(),
335            protocol_message,
336        ))
337        .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
338        let response = self
339            .prepare_request_builder(self.prepare_http_client()?.post(url.clone()))
340            .json(&register_single_signature_message)
341            .send()
342            .await;
343
344        match response {
345            Ok(response) => match response.status() {
346                StatusCode::CREATED | StatusCode::ACCEPTED => Ok(()),
347                StatusCode::GONE => {
348                    let root_cause = AggregatorClientError::get_root_cause(response).await;
349                    debug!(self.logger, "Message already certified or expired"; "details" => &root_cause);
350
351                    Ok(())
352                }
353                StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
354                _ => Err(AggregatorClientError::from_response(response).await),
355            },
356            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
357        }
358    }
359
360    async fn retrieve_aggregator_features(
361        &self,
362    ) -> Result<AggregatorFeaturesMessage, AggregatorClientError> {
363        debug!(self.logger, "Retrieve aggregator features message");
364        let url = format!("{}/", self.aggregator_endpoint);
365        let response = self
366            .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
367            .send()
368            .await;
369
370        match response {
371            Ok(response) => match response.status() {
372                StatusCode::OK => Ok(response
373                    .json::<AggregatorFeaturesMessage>()
374                    .await
375                    .map_err(|e| AggregatorClientError::JsonParseFailed(anyhow!(e)))?),
376                StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
377                _ => Err(AggregatorClientError::from_response(response).await),
378            },
379            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
380        }
381    }
382}
383
#[cfg(test)]
pub(crate) mod dumb {
    use tokio::sync::RwLock;

    use super::*;

    /// In-memory stand-in for an aggregator, intended for tests of other services.
    ///
    /// It never talks to a real aggregator host. Instead it serves the epoch settings and
    /// aggregator features it holds, and it remembers the last signer it was asked to
    /// register so a tester can inspect that state.
    pub struct DumbAggregatorClient {
        epoch_settings: RwLock<Option<SignerEpochSettings>>,
        last_registered_signer: RwLock<Option<Signer>>,
        aggregator_features: RwLock<AggregatorFeaturesMessage>,
    }

    impl Default for DumbAggregatorClient {
        fn default() -> Self {
            Self {
                epoch_settings: RwLock::new(Some(SignerEpochSettings::dummy())),
                last_registered_signer: RwLock::new(None),
                aggregator_features: RwLock::new(AggregatorFeaturesMessage::dummy()),
            }
        }
    }

    impl DumbAggregatorClient {
        /// Return the signer passed to the most recent `register_signer` call, if any.
        pub async fn get_last_registered_signer(&self) -> Option<Signer> {
            let stored_signer = self.last_registered_signer.read().await;
            (*stored_signer).clone()
        }

        /// Replace the message returned by subsequent `retrieve_aggregator_features` calls.
        pub async fn set_aggregator_features(
            &self,
            aggregator_features: AggregatorFeaturesMessage,
        ) {
            *self.aggregator_features.write().await = aggregator_features;
        }
    }

    #[async_trait]
    impl AggregatorClient for DumbAggregatorClient {
        async fn retrieve_epoch_settings(
            &self,
        ) -> Result<Option<SignerEpochSettings>, AggregatorClientError> {
            Ok(self.epoch_settings.read().await.clone())
        }

        /// Store a copy of `signer` as the last registered signer; the epoch is ignored.
        async fn register_signer(
            &self,
            _epoch: Epoch,
            signer: &Signer,
        ) -> Result<(), AggregatorClientError> {
            let registered_signer = signer.clone();
            *self.last_registered_signer.write().await = Some(registered_signer);

            Ok(())
        }

        /// Accept any signatures without recording them.
        async fn register_signatures(
            &self,
            _signed_entity_type: &SignedEntityType,
            _signatures: &SingleSignatures,
            _protocol_message: &ProtocolMessage,
        ) -> Result<(), AggregatorClientError> {
            Ok(())
        }

        async fn retrieve_aggregator_features(
            &self,
        ) -> Result<AggregatorFeaturesMessage, AggregatorClientError> {
            let features = self.aggregator_features.read().await.clone();
            Ok(features)
        }
    }
}
465
466#[cfg(test)]
467mod tests {
468    use http::response::Builder as HttpResponseBuilder;
469    use httpmock::prelude::*;
470    use serde_json::json;
471
472    use mithril_common::entities::Epoch;
473    use mithril_common::era::{EraChecker, SupportedEra};
474    use mithril_common::messages::TryFromMessageAdapter;
475    use mithril_common::test_utils::{fake_data, TempDir};
476
477    use crate::test_tools::TestLogger;
478
479    use super::*;
480
481    macro_rules! assert_is_error {
482        ($error:expr, $error_type:pat) => {
483            assert!(
484                matches!($error, $error_type),
485                "Expected {} error, got '{:?}'.",
486                stringify!($error_type),
487                $error
488            );
489        };
490    }
491
492    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
493        let server = MockServer::start();
494        let aggregator_endpoint = server.url("");
495        let relay_endpoint = None;
496        let era_checker = EraChecker::new(SupportedEra::dummy(), Epoch(1));
497        let api_version_provider = APIVersionProvider::new(Arc::new(era_checker));
498
499        (
500            server,
501            AggregatorHTTPClient::new(
502                aggregator_endpoint,
503                relay_endpoint,
504                Arc::new(api_version_provider),
505                None,
506                TestLogger::stdout(),
507            ),
508        )
509    }
510
511    fn set_returning_412(server: &MockServer) {
512        server.mock(|_, then| {
513            then.status(412)
514                .header(MITHRIL_API_VERSION_HEADER, "0.0.999");
515        });
516    }
517
518    fn set_returning_500(server: &MockServer) {
519        server.mock(|_, then| {
520            then.status(500).body("an error occurred");
521        });
522    }
523
524    fn set_unparsable_json(server: &MockServer) {
525        server.mock(|_, then| {
526            then.status(200).body("this is not a json");
527        });
528    }
529
530    fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
531        HttpResponseBuilder::new()
532            .status(status_code)
533            .body(body.into())
534            .unwrap()
535            .into()
536    }
537
538    fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
539        HttpResponseBuilder::new()
540            .status(status_code)
541            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
542            .body(serde_json::to_string(&body).unwrap())
543            .unwrap()
544            .into()
545    }
546
547    macro_rules! assert_error_text_contains {
548        ($error: expr, $expect_contains: expr) => {
549            let error = &$error;
550            assert!(
551                error.contains($expect_contains),
552                "Expected error message to contain '{}'\ngot '{error:?}'",
553                $expect_contains,
554            );
555        };
556    }
557
558    #[tokio::test]
559    async fn test_aggregator_features_ok_200() {
560        let (server, client) = setup_server_and_client();
561        let message_expected = AggregatorFeaturesMessage::dummy();
562        let _server_mock = server.mock(|when, then| {
563            when.path("/");
564            then.status(200).body(json!(message_expected).to_string());
565        });
566
567        let message = client.retrieve_aggregator_features().await.unwrap();
568
569        assert_eq!(message_expected, message);
570    }
571
572    #[tokio::test]
573    async fn test_aggregator_features_ko_412() {
574        let (server, client) = setup_server_and_client();
575        set_returning_412(&server);
576
577        let error = client.retrieve_aggregator_features().await.unwrap_err();
578
579        assert_is_error!(error, AggregatorClientError::ApiVersionMismatch(_));
580    }
581
582    #[tokio::test]
583    async fn test_aggregator_features_ko_500() {
584        let (server, client) = setup_server_and_client();
585        set_returning_500(&server);
586
587        let error = client.retrieve_aggregator_features().await.unwrap_err();
588
589        assert_is_error!(error, AggregatorClientError::RemoteServerTechnical(_));
590    }
591
592    #[tokio::test]
593    async fn test_aggregator_features_ko_json_serialization() {
594        let (server, client) = setup_server_and_client();
595        set_unparsable_json(&server);
596
597        let error = client.retrieve_aggregator_features().await.unwrap_err();
598
599        assert_is_error!(error, AggregatorClientError::JsonParseFailed(_));
600    }
601
602    #[tokio::test]
603    async fn test_aggregator_features_timeout() {
604        let (server, mut client) = setup_server_and_client();
605        client.timeout_duration = Some(Duration::from_millis(10));
606        let _server_mock = server.mock(|when, then| {
607            when.path("/");
608            then.delay(Duration::from_millis(100));
609        });
610
611        let error = client.retrieve_aggregator_features().await.unwrap_err();
612
613        assert_is_error!(error, AggregatorClientError::RemoteServerUnreachable(_));
614    }
615
616    #[tokio::test]
617    async fn test_epoch_settings_ok_200() {
618        let (server, client) = setup_server_and_client();
619        let epoch_settings_expected = EpochSettingsMessage::dummy();
620        let _server_mock = server.mock(|when, then| {
621            when.path("/epoch-settings");
622            then.status(200)
623                .body(json!(epoch_settings_expected).to_string());
624        });
625
626        let epoch_settings = client.retrieve_epoch_settings().await;
627        epoch_settings.as_ref().expect("unexpected error");
628        assert_eq!(
629            FromEpochSettingsAdapter::try_adapt(epoch_settings_expected).unwrap(),
630            epoch_settings.unwrap().unwrap()
631        );
632    }
633
634    #[tokio::test]
635    async fn test_epoch_settings_ko_412() {
636        let (server, client) = setup_server_and_client();
637        let _server_mock = server.mock(|when, then| {
638            when.path("/epoch-settings");
639            then.status(412)
640                .header(MITHRIL_API_VERSION_HEADER, "0.0.999");
641        });
642
643        let epoch_settings = client.retrieve_epoch_settings().await.unwrap_err();
644
645        assert!(epoch_settings.is_api_version_mismatch());
646    }
647
648    #[tokio::test]
649    async fn test_epoch_settings_ko_500() {
650        let (server, client) = setup_server_and_client();
651        let _server_mock = server.mock(|when, then| {
652            when.path("/epoch-settings");
653            then.status(500).body("an error occurred");
654        });
655
656        match client.retrieve_epoch_settings().await.unwrap_err() {
657            AggregatorClientError::RemoteServerTechnical(_) => (),
658            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
659        };
660    }
661
662    #[tokio::test]
663    async fn test_epoch_settings_timeout() {
664        let (server, mut client) = setup_server_and_client();
665        client.timeout_duration = Some(Duration::from_millis(10));
666        let _server_mock = server.mock(|when, then| {
667            when.path("/epoch-settings");
668            then.delay(Duration::from_millis(100));
669        });
670
671        let error = client
672            .retrieve_epoch_settings()
673            .await
674            .expect_err("retrieve_epoch_settings should fail");
675
676        assert!(
677            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
678            "unexpected error type: {error:?}"
679        );
680    }
681
682    #[tokio::test]
683    async fn test_register_signer_ok_201() {
684        let epoch = Epoch(1);
685        let single_signers = fake_data::signers(1);
686        let single_signer = single_signers.first().unwrap();
687        let (server, client) = setup_server_and_client();
688        let _server_mock = server.mock(|when, then| {
689            when.method(POST).path("/register-signer");
690            then.status(201);
691        });
692
693        let register_signer = client.register_signer(epoch, single_signer).await;
694        register_signer.expect("unexpected error");
695    }
696
697    #[tokio::test]
698    async fn test_register_signer_ko_412() {
699        let epoch = Epoch(1);
700        let (server, client) = setup_server_and_client();
701        let _server_mock = server.mock(|when, then| {
702            when.method(POST).path("/register-signer");
703            then.status(412)
704                .header(MITHRIL_API_VERSION_HEADER, "0.0.999");
705        });
706        let single_signers = fake_data::signers(1);
707        let single_signer = single_signers.first().unwrap();
708
709        let error = client
710            .register_signer(epoch, single_signer)
711            .await
712            .unwrap_err();
713
714        assert!(error.is_api_version_mismatch());
715    }
716
717    #[tokio::test]
718    async fn test_register_signer_ko_400() {
719        let epoch = Epoch(1);
720        let single_signers = fake_data::signers(1);
721        let single_signer = single_signers.first().unwrap();
722        let (server, client) = setup_server_and_client();
723        let _server_mock = server.mock(|when, then| {
724            when.method(POST).path("/register-signer");
725            then.status(400).body(
726                serde_json::to_vec(&ClientError::new(
727                    "error".to_string(),
728                    "an error".to_string(),
729                ))
730                .unwrap(),
731            );
732        });
733
734        match client
735            .register_signer(epoch, single_signer)
736            .await
737            .unwrap_err()
738        {
739            AggregatorClientError::RemoteServerLogical(_) => (),
740            err => {
741                panic!(
742                    "Expected a AggregatorClientError::RemoteServerLogical error, got '{err:?}'."
743                )
744            }
745        };
746    }
747
748    #[tokio::test]
749    async fn test_register_signer_ko_500() {
750        let epoch = Epoch(1);
751        let single_signers = fake_data::signers(1);
752        let single_signer = single_signers.first().unwrap();
753        let (server, client) = setup_server_and_client();
754        let _server_mock = server.mock(|when, then| {
755            when.method(POST).path("/register-signer");
756            then.status(500).body("an error occurred");
757        });
758
759        match client
760            .register_signer(epoch, single_signer)
761            .await
762            .unwrap_err()
763        {
764            AggregatorClientError::RemoteServerTechnical(_) => (),
765            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
766        };
767    }
768
769    #[tokio::test]
770    async fn test_register_signer_timeout() {
771        let epoch = Epoch(1);
772        let single_signers = fake_data::signers(1);
773        let single_signer = single_signers.first().unwrap();
774        let (server, mut client) = setup_server_and_client();
775        client.timeout_duration = Some(Duration::from_millis(10));
776        let _server_mock = server.mock(|when, then| {
777            when.method(POST).path("/register-signer");
778            then.delay(Duration::from_millis(100));
779        });
780
781        let error = client
782            .register_signer(epoch, single_signer)
783            .await
784            .expect_err("register_signer should fail");
785
786        assert!(
787            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
788            "unexpected error type: {error:?}"
789        );
790    }
791
792    #[tokio::test]
793    async fn test_register_signatures_ok_201() {
794        let single_signatures = fake_data::single_signatures((1..5).collect());
795        let (server, client) = setup_server_and_client();
796        let _server_mock = server.mock(|when, then| {
797            when.method(POST).path("/register-signatures");
798            then.status(201);
799        });
800
801        let register_signatures = client
802            .register_signatures(
803                &SignedEntityType::dummy(),
804                &single_signatures,
805                &ProtocolMessage::default(),
806            )
807            .await;
808        register_signatures.expect("unexpected error");
809    }
810
811    #[tokio::test]
812    async fn test_register_signatures_ok_202() {
813        let single_signatures = fake_data::single_signatures((1..5).collect());
814        let (server, client) = setup_server_and_client();
815        let _server_mock = server.mock(|when, then| {
816            when.method(POST).path("/register-signatures");
817            then.status(202);
818        });
819
820        let register_signatures = client
821            .register_signatures(
822                &SignedEntityType::dummy(),
823                &single_signatures,
824                &ProtocolMessage::default(),
825            )
826            .await;
827        register_signatures.expect("unexpected error");
828    }
829
830    #[tokio::test]
831    async fn test_register_signatures_ko_412() {
832        let (server, client) = setup_server_and_client();
833        let _server_mock = server.mock(|when, then| {
834            when.method(POST).path("/register-signatures");
835            then.status(412)
836                .header(MITHRIL_API_VERSION_HEADER, "0.0.999");
837        });
838        let single_signatures = fake_data::single_signatures((1..5).collect());
839
840        let error = client
841            .register_signatures(
842                &SignedEntityType::dummy(),
843                &single_signatures,
844                &ProtocolMessage::default(),
845            )
846            .await
847            .unwrap_err();
848
849        assert!(error.is_api_version_mismatch());
850    }
851
852    #[tokio::test]
853    async fn test_register_signatures_ko_400() {
854        let single_signatures = fake_data::single_signatures((1..5).collect());
855        let (server, client) = setup_server_and_client();
856        let _server_mock = server.mock(|when, then| {
857            when.method(POST).path("/register-signatures");
858            then.status(400).body(
859                serde_json::to_vec(&ClientError::new(
860                    "error".to_string(),
861                    "an error".to_string(),
862                ))
863                .unwrap(),
864            );
865        });
866
867        match client
868            .register_signatures(
869                &SignedEntityType::dummy(),
870                &single_signatures,
871                &ProtocolMessage::default(),
872            )
873            .await
874            .unwrap_err()
875        {
876            AggregatorClientError::RemoteServerLogical(_) => (),
877            e => panic!("Expected Aggregator::RemoteServerLogical error, got '{e:?}'."),
878        };
879    }
880
881    #[tokio::test]
882    async fn test_register_signatures_ok_410_log_response_body() {
883        let log_path = TempDir::create(
884            "aggregator_client",
885            "test_register_signatures_ok_410_log_response_body",
886        )
887        .join("test.log");
888
889        let single_signatures = fake_data::single_signatures((1..5).collect());
890        {
891            let (server, mut client) = setup_server_and_client();
892            client.logger = TestLogger::file(&log_path);
893            let _server_mock = server.mock(|when, then| {
894                when.method(POST).path("/register-signatures");
895                then.status(410).body(
896                    serde_json::to_vec(&ClientError::new(
897                        "already_aggregated".to_string(),
898                        "too late".to_string(),
899                    ))
900                    .unwrap(),
901                );
902            });
903
904            client
905                .register_signatures(
906                    &SignedEntityType::dummy(),
907                    &single_signatures,
908                    &ProtocolMessage::default(),
909                )
910                .await
911                .expect("Should not fail when status is 410 (GONE)");
912        }
913
914        let logs = std::fs::read_to_string(&log_path).unwrap();
915        assert!(logs.contains("already_aggregated"));
916        assert!(logs.contains("too late"));
917    }
918
919    #[tokio::test]
920    async fn test_register_signatures_ko_409() {
921        let single_signatures = fake_data::single_signatures((1..5).collect());
922        let (server, client) = setup_server_and_client();
923        let _server_mock = server.mock(|when, then| {
924            when.method(POST).path("/register-signatures");
925            then.status(409);
926        });
927
928        match client
929            .register_signatures(
930                &SignedEntityType::dummy(),
931                &single_signatures,
932                &ProtocolMessage::default(),
933            )
934            .await
935            .unwrap_err()
936        {
937            AggregatorClientError::RemoteServerLogical(_) => (),
938            e => panic!("Expected Aggregator::RemoteServerLogical error, got '{e:?}'."),
939        }
940    }
941
942    #[tokio::test]
943    async fn test_register_signatures_ko_500() {
944        let single_signatures = fake_data::single_signatures((1..5).collect());
945        let (server, client) = setup_server_and_client();
946        let _server_mock = server.mock(|when, then| {
947            when.method(POST).path("/register-signatures");
948            then.status(500).body("an error occurred");
949        });
950
951        match client
952            .register_signatures(
953                &SignedEntityType::dummy(),
954                &single_signatures,
955                &ProtocolMessage::default(),
956            )
957            .await
958            .unwrap_err()
959        {
960            AggregatorClientError::RemoteServerTechnical(_) => (),
961            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
962        };
963    }
964
965    #[tokio::test]
966    async fn test_register_signatures_timeout() {
967        let single_signatures = fake_data::single_signatures((1..5).collect());
968        let (server, mut client) = setup_server_and_client();
969        client.timeout_duration = Some(Duration::from_millis(10));
970        let _server_mock = server.mock(|when, then| {
971            when.method(POST).path("/register-signatures");
972            then.delay(Duration::from_millis(100));
973        });
974
975        let error = client
976            .register_signatures(
977                &SignedEntityType::dummy(),
978                &single_signatures,
979                &ProtocolMessage::default(),
980            )
981            .await
982            .expect_err("register_signatures should fail");
983
984        assert!(
985            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
986            "unexpected error type: {error:?}"
987        );
988    }
989
990    #[tokio::test]
991    async fn test_4xx_errors_are_handled_as_remote_server_logical() {
992        let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
993        let handled_error = AggregatorClientError::from_response(response).await;
994
995        assert!(
996            matches!(
997                handled_error,
998                AggregatorClientError::RemoteServerLogical(..)
999            ),
1000            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
1001        );
1002    }
1003
1004    #[tokio::test]
1005    async fn test_5xx_errors_are_handled_as_remote_server_technical() {
1006        let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
1007        let handled_error = AggregatorClientError::from_response(response).await;
1008
1009        assert!(
1010            matches!(
1011                handled_error,
1012                AggregatorClientError::RemoteServerTechnical(..)
1013            ),
1014            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
1015        );
1016    }
1017
1018    #[tokio::test]
1019    async fn test_550_error_is_handled_as_registration_round_not_yet_opened() {
1020        let response = build_text_response(StatusCode::from_u16(550).unwrap(), "Not yet available");
1021        let handled_error = AggregatorClientError::from_response(response).await;
1022
1023        assert!(
1024            matches!(
1025                handled_error,
1026                AggregatorClientError::RegistrationRoundNotYetOpened(..)
1027            ),
1028            "Expected error to be RegistrationRoundNotYetOpened\ngot '{handled_error:?}'",
1029        );
1030    }
1031
1032    #[tokio::test]
1033    async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text(
1034    ) {
1035        let response = build_text_response(StatusCode::OK, "ok text");
1036        let handled_error = AggregatorClientError::from_response(response).await;
1037
1038        assert!(
1039            matches!(
1040                handled_error,
1041                AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
1042            ),
1043            "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
1044        );
1045    }
1046
1047    #[tokio::test]
1048    async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
1049        let error_text = "An error occurred; please try again later.";
1050        let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);
1051
1052        assert_error_text_contains!(
1053            AggregatorClientError::get_root_cause(response).await,
1054            "expectation failed: An error occurred; please try again later."
1055        );
1056    }
1057
1058    #[tokio::test]
1059    async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message(
1060    ) {
1061        let client_error = ClientError::new("label", "message");
1062        let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);
1063
1064        assert_error_text_contains!(
1065            AggregatorClientError::get_root_cause(response).await,
1066            "bad request: label: message"
1067        );
1068    }
1069
1070    #[tokio::test]
1071    async fn test_root_cause_of_json_formatted_server_error_response_contains_error_label_and_message(
1072    ) {
1073        let server_error = ServerError::new("message");
1074        let response = build_json_response(StatusCode::BAD_REQUEST, &server_error);
1075
1076        assert_error_text_contains!(
1077            AggregatorClientError::get_root_cause(response).await,
1078            "bad request: message"
1079        );
1080    }
1081
1082    #[tokio::test]
1083    async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
1084        let response = build_json_response(
1085            StatusCode::INTERNAL_SERVER_ERROR,
1086            &json!({ "second": "unknown", "first": "foreign" }),
1087        );
1088
1089        assert_error_text_contains!(
1090            AggregatorClientError::get_root_cause(response).await,
1091            r#"internal server error: {"first":"foreign","second":"unknown"}"#
1092        );
1093    }
1094
1095    #[tokio::test]
1096    async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
1097        let response = HttpResponseBuilder::new()
1098            .status(StatusCode::BAD_REQUEST)
1099            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
1100            .body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
1101            .unwrap()
1102            .into();
1103
1104        let root_cause = AggregatorClientError::get_root_cause(response).await;
1105
1106        assert_error_text_contains!(root_cause, "bad request");
1107        assert!(
1108            !root_cause.contains("bad request: "),
1109            "Expected error message should not contain additional information \ngot '{root_cause:?}'"
1110        );
1111    }
1112
1113    #[tokio::test]
1114    async fn test_sends_accept_encoding_header() {
1115        let (server, client) = setup_server_and_client();
1116        server.mock(|when, then| {
1117            when.matches(|req| {
1118                let headers = req.headers.clone().expect("HTTP headers not found");
1119                let accept_encoding_header = headers
1120                    .iter()
1121                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
1122                    .expect("Accept-Encoding header not found");
1123
1124                let header_value = accept_encoding_header.clone().1;
1125                ["gzip", "br", "deflate", "zstd"]
1126                    .iter()
1127                    .all(|&value| header_value.contains(value))
1128            });
1129
1130            then.status(201);
1131        });
1132
1133        client
1134            .register_signatures(
1135                &SignedEntityType::dummy(),
1136                &fake_data::single_signatures((1..5).collect()),
1137                &ProtocolMessage::default(),
1138            )
1139            .await
1140            .expect("Should succeed with Accept-Encoding header");
1141    }
1142}