mithril_signer/services/
aggregator_client.rs

1use anyhow::anyhow;
2use async_trait::async_trait;
3use reqwest::header::{self, HeaderValue};
4use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode};
5use semver::Version;
6use slog::{Logger, debug, error, warn};
7use std::{io, sync::Arc, time::Duration};
8use thiserror::Error;
9
10use mithril_common::{
11    MITHRIL_API_VERSION_HEADER, MITHRIL_SIGNER_VERSION_HEADER, StdError,
12    api_version::APIVersionProvider,
13    entities::{
14        ClientError, Epoch, ProtocolMessage, ServerError, SignedEntityType, Signer, SingleSignature,
15    },
16    logging::LoggerExtensions,
17    messages::{
18        AggregatorFeaturesMessage, EpochSettingsMessage, TryFromMessageAdapter, TryToMessageAdapter,
19    },
20};
21
22use crate::entities::SignerEpochSettings;
23use crate::message_adapters::{
24    FromEpochSettingsAdapter, ToRegisterSignatureMessageAdapter, ToRegisterSignerMessageAdapter,
25};
26
27const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");
28
/// Warning logged when the aggregator advertises a more recent API version than this signer.
const API_VERSION_MISMATCH_WARNING_MESSAGE: &str =
    "OpenAPI version may be incompatible, please update your Mithril node to the latest version.";
31
32/// Error structure for the Aggregator Client.
33#[derive(Error, Debug)]
34pub enum AggregatorClientError {
35    /// The aggregator host has returned a technical error.
36    #[error("remote server technical error")]
37    RemoteServerTechnical(#[source] StdError),
38
39    /// The aggregator host responded it cannot fulfill our request.
40    #[error("remote server logical error")]
41    RemoteServerLogical(#[source] StdError),
42
43    /// Could not reach aggregator.
44    #[error("remote server unreachable")]
45    RemoteServerUnreachable(#[source] StdError),
46
47    /// Unhandled status code
48    #[error("unhandled status code: {0}, response text: {1}")]
49    UnhandledStatusCode(StatusCode, String),
50
51    /// Could not parse response.
52    #[error("json parsing failed")]
53    JsonParseFailed(#[source] StdError),
54
55    /// Mostly network errors.
56    #[error("Input/Output error")]
57    IOError(#[from] io::Error),
58
59    /// HTTP client creation error
60    #[error("HTTP client creation failed")]
61    HTTPClientCreation(#[source] StdError),
62
63    /// Proxy creation error
64    #[error("proxy creation failed")]
65    ProxyCreation(#[source] StdError),
66
67    /// Adapter error
68    #[error("adapter failed")]
69    Adapter(#[source] StdError),
70
71    /// No signer registration round opened yet
72    #[error("a signer registration round is not opened yet, please try again later")]
73    RegistrationRoundNotYetOpened(#[source] StdError),
74}
75
76impl AggregatorClientError {
77    /// Create an `AggregatorClientError` from a response.
78    ///
79    /// This method is meant to be used after handling domain-specific cases leaving only
80    /// 4xx or 5xx status codes.
81    /// Otherwise, it will return an `UnhandledStatusCode` error.
82    pub async fn from_response(response: Response) -> Self {
83        let error_code = response.status();
84
85        if error_code.is_client_error() {
86            let root_cause = Self::get_root_cause(response).await;
87            Self::RemoteServerLogical(anyhow!(root_cause))
88        } else if error_code.is_server_error() {
89            let root_cause = Self::get_root_cause(response).await;
90            match error_code.as_u16() {
91                550 => Self::RegistrationRoundNotYetOpened(anyhow!(root_cause)),
92                _ => Self::RemoteServerTechnical(anyhow!(root_cause)),
93            }
94        } else {
95            let response_text = response.text().await.unwrap_or_default();
96            Self::UnhandledStatusCode(error_code, response_text)
97        }
98    }
99
100    async fn get_root_cause(response: Response) -> String {
101        let error_code = response.status();
102        let canonical_reason = error_code.canonical_reason().unwrap_or_default().to_lowercase();
103        let is_json = response
104            .headers()
105            .get(header::CONTENT_TYPE)
106            .is_some_and(|ct| JSON_CONTENT_TYPE == ct);
107
108        if is_json {
109            let json_value: serde_json::Value = response.json().await.unwrap_or_default();
110
111            if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
112                format!(
113                    "{}: {}: {}",
114                    canonical_reason, client_error.label, client_error.message
115                )
116            } else if let Ok(server_error) =
117                serde_json::from_value::<ServerError>(json_value.clone())
118            {
119                format!("{}: {}", canonical_reason, server_error.message)
120            } else if json_value.is_null() {
121                canonical_reason.to_string()
122            } else {
123                format!("{canonical_reason}: {json_value}")
124            }
125        } else {
126            let response_text = response.text().await.unwrap_or_default();
127            format!("{canonical_reason}: {response_text}")
128        }
129    }
130}
131
132/// Trait for mocking and testing a `AggregatorClient`
133#[cfg_attr(test, mockall::automock)]
134#[async_trait]
135pub trait AggregatorClient: Sync + Send {
136    /// Retrieves epoch settings from the aggregator
137    async fn retrieve_epoch_settings(
138        &self,
139    ) -> Result<Option<SignerEpochSettings>, AggregatorClientError>;
140
141    /// Registers signer with the aggregator.
142    async fn register_signer(
143        &self,
144        epoch: Epoch,
145        signer: &Signer,
146    ) -> Result<(), AggregatorClientError>;
147
148    /// Registers single signature with the aggregator.
149    async fn register_signature(
150        &self,
151        signed_entity_type: &SignedEntityType,
152        signature: &SingleSignature,
153        protocol_message: &ProtocolMessage,
154    ) -> Result<(), AggregatorClientError>;
155
156    /// Retrieves aggregator features message from the aggregator
157    async fn retrieve_aggregator_features(
158        &self,
159    ) -> Result<AggregatorFeaturesMessage, AggregatorClientError>;
160}
161
162/// AggregatorHTTPClient is a http client for an aggregator
163pub struct AggregatorHTTPClient {
164    aggregator_endpoint: String,
165    relay_endpoint: Option<String>,
166    api_version_provider: Arc<APIVersionProvider>,
167    timeout_duration: Option<Duration>,
168    logger: Logger,
169}
170
171impl AggregatorHTTPClient {
172    /// AggregatorHTTPClient factory
173    pub fn new(
174        aggregator_endpoint: String,
175        relay_endpoint: Option<String>,
176        api_version_provider: Arc<APIVersionProvider>,
177        timeout_duration: Option<Duration>,
178        logger: Logger,
179    ) -> Self {
180        let logger = logger.new_with_component_name::<Self>();
181        debug!(logger, "New AggregatorHTTPClient created");
182        Self {
183            aggregator_endpoint,
184            relay_endpoint,
185            api_version_provider,
186            timeout_duration,
187            logger,
188        }
189    }
190
191    fn prepare_http_client(&self) -> Result<Client, AggregatorClientError> {
192        let client = match &self.relay_endpoint {
193            Some(relay_endpoint) => Client::builder()
194                .proxy(
195                    Proxy::all(relay_endpoint)
196                        .map_err(|e| AggregatorClientError::ProxyCreation(anyhow!(e)))?,
197                )
198                .build()
199                .map_err(|e| AggregatorClientError::HTTPClientCreation(anyhow!(e)))?,
200            None => Client::new(),
201        };
202
203        Ok(client)
204    }
205
206    /// Forge a client request adding protocol version in the headers.
207    pub fn prepare_request_builder(&self, request_builder: RequestBuilder) -> RequestBuilder {
208        let request_builder = request_builder
209            .header(
210                MITHRIL_API_VERSION_HEADER,
211                self.api_version_provider
212                    .compute_current_version()
213                    .unwrap()
214                    .to_string(),
215            )
216            .header(MITHRIL_SIGNER_VERSION_HEADER, env!("CARGO_PKG_VERSION"));
217
218        if let Some(duration) = self.timeout_duration {
219            request_builder.timeout(duration)
220        } else {
221            request_builder
222        }
223    }
224
225    /// Check API version mismatch and log a warning if the aggregator's version is more recent.
226    fn warn_if_api_version_mismatch(&self, response: &Response) {
227        let aggregator_version = response
228            .headers()
229            .get(MITHRIL_API_VERSION_HEADER)
230            .and_then(|v| v.to_str().ok())
231            .and_then(|s| Version::parse(s).ok());
232
233        let signer_version = self.api_version_provider.compute_current_version();
234
235        match (aggregator_version, signer_version) {
236            (Some(aggregator), Ok(signer)) if signer < aggregator => {
237                warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
238                    "aggregator_version" => %aggregator,
239                    "signer_version" => %signer,
240                );
241            }
242            (Some(_), Err(error)) => {
243                error!(
244                    self.logger,
245                    "Failed to compute the current signer API version";
246                    "error" => error.to_string()
247                );
248            }
249            _ => {}
250        }
251    }
252}
253
254#[async_trait]
255impl AggregatorClient for AggregatorHTTPClient {
256    async fn retrieve_epoch_settings(
257        &self,
258    ) -> Result<Option<SignerEpochSettings>, AggregatorClientError> {
259        debug!(self.logger, "Retrieve epoch settings");
260        let url = format!("{}/epoch-settings", self.aggregator_endpoint);
261        let response = self
262            .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
263            .send()
264            .await;
265
266        match response {
267            Ok(response) => match response.status() {
268                StatusCode::OK => {
269                    self.warn_if_api_version_mismatch(&response);
270                    match response.json::<EpochSettingsMessage>().await {
271                        Ok(message) => {
272                            let epoch_settings = FromEpochSettingsAdapter::try_adapt(message)
273                                .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
274                            Ok(Some(epoch_settings))
275                        }
276                        Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
277                    }
278                }
279                _ => Err(AggregatorClientError::from_response(response).await),
280            },
281            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
282        }
283    }
284
285    async fn register_signer(
286        &self,
287        epoch: Epoch,
288        signer: &Signer,
289    ) -> Result<(), AggregatorClientError> {
290        debug!(self.logger, "Register signer");
291        let url = format!("{}/register-signer", self.aggregator_endpoint);
292        let register_signer_message =
293            ToRegisterSignerMessageAdapter::try_adapt((epoch, signer.to_owned()))
294                .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
295        let response = self
296            .prepare_request_builder(self.prepare_http_client()?.post(url.clone()))
297            .json(&register_signer_message)
298            .send()
299            .await;
300
301        match response {
302            Ok(response) => match response.status() {
303                StatusCode::CREATED => {
304                    self.warn_if_api_version_mismatch(&response);
305
306                    Ok(())
307                }
308                _ => Err(AggregatorClientError::from_response(response).await),
309            },
310            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
311        }
312    }
313
314    async fn register_signature(
315        &self,
316        signed_entity_type: &SignedEntityType,
317        signature: &SingleSignature,
318        protocol_message: &ProtocolMessage,
319    ) -> Result<(), AggregatorClientError> {
320        debug!(self.logger, "Register signature");
321        let url = format!("{}/register-signatures", self.aggregator_endpoint);
322        let register_single_signature_message = ToRegisterSignatureMessageAdapter::try_adapt((
323            signed_entity_type.to_owned(),
324            signature.to_owned(),
325            protocol_message,
326        ))
327        .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
328        let response = self
329            .prepare_request_builder(self.prepare_http_client()?.post(url.clone()))
330            .json(&register_single_signature_message)
331            .send()
332            .await;
333
334        match response {
335            Ok(response) => match response.status() {
336                StatusCode::CREATED | StatusCode::ACCEPTED => {
337                    self.warn_if_api_version_mismatch(&response);
338
339                    Ok(())
340                }
341                StatusCode::GONE => {
342                    self.warn_if_api_version_mismatch(&response);
343                    let root_cause = AggregatorClientError::get_root_cause(response).await;
344                    debug!(self.logger, "Message already certified or expired"; "details" => &root_cause);
345
346                    Ok(())
347                }
348                _ => Err(AggregatorClientError::from_response(response).await),
349            },
350            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
351        }
352    }
353
354    async fn retrieve_aggregator_features(
355        &self,
356    ) -> Result<AggregatorFeaturesMessage, AggregatorClientError> {
357        debug!(self.logger, "Retrieve aggregator features message");
358        let url = format!("{}/", self.aggregator_endpoint);
359        let response = self
360            .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
361            .send()
362            .await;
363
364        match response {
365            Ok(response) => match response.status() {
366                StatusCode::OK => {
367                    self.warn_if_api_version_mismatch(&response);
368
369                    Ok(response
370                        .json::<AggregatorFeaturesMessage>()
371                        .await
372                        .map_err(|e| AggregatorClientError::JsonParseFailed(anyhow!(e)))?)
373                }
374                _ => Err(AggregatorClientError::from_response(response).await),
375            },
376            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
377        }
378    }
379}
380
#[cfg(test)]
pub(crate) mod dumb {
    use mithril_common::test::double::Dummy;
    use tokio::sync::RwLock;

    use super::*;

    /// In-memory stand-in for an aggregator, for use by test services.
    ///
    /// No network traffic is involved: every call is answered from the state held here, and
    /// the last signer passed to `register_signer` is remembered so tests can inspect it.
    pub struct DumbAggregatorClient {
        epoch_settings: RwLock<Option<SignerEpochSettings>>,
        last_registered_signer: RwLock<Option<Signer>>,
        aggregator_features: RwLock<AggregatorFeaturesMessage>,
    }

    impl DumbAggregatorClient {
        /// Copy of the signer given to the most recent `register_signer` call, if any.
        pub async fn get_last_registered_signer(&self) -> Option<Signer> {
            let recorded = self.last_registered_signer.read().await;
            recorded.clone()
        }
    }

    impl Default for DumbAggregatorClient {
        /// Starts with dummy epoch settings and features, and no registered signer.
        fn default() -> Self {
            let epoch_settings = Some(SignerEpochSettings::dummy());
            let aggregator_features = AggregatorFeaturesMessage::dummy();

            Self {
                epoch_settings: RwLock::new(epoch_settings),
                last_registered_signer: RwLock::new(None),
                aggregator_features: RwLock::new(aggregator_features),
            }
        }
    }

    #[async_trait]
    impl AggregatorClient for DumbAggregatorClient {
        /// Hands back a copy of the stored epoch settings.
        async fn retrieve_epoch_settings(
            &self,
        ) -> Result<Option<SignerEpochSettings>, AggregatorClientError> {
            Ok(self.epoch_settings.read().await.clone())
        }

        /// Remembers `signer` as the last registered one; the epoch is ignored.
        async fn register_signer(
            &self,
            _epoch: Epoch,
            signer: &Signer,
        ) -> Result<(), AggregatorClientError> {
            *self.last_registered_signer.write().await = Some(signer.clone());

            Ok(())
        }

        /// Accepts any signature without recording it.
        async fn register_signature(
            &self,
            _signed_entity_type: &SignedEntityType,
            _signature: &SingleSignature,
            _protocol_message: &ProtocolMessage,
        ) -> Result<(), AggregatorClientError> {
            Ok(())
        }

        /// Hands back a copy of the stored aggregator features message.
        async fn retrieve_aggregator_features(
            &self,
        ) -> Result<AggregatorFeaturesMessage, AggregatorClientError> {
            Ok(self.aggregator_features.read().await.clone())
        }
    }
}
455
456#[cfg(test)]
457mod tests {
458    use std::collections::HashMap;
459
460    use http::response::Builder as HttpResponseBuilder;
461    use httpmock::prelude::*;
462    use semver::Version;
463    use serde_json::json;
464
465    use mithril_common::entities::Epoch;
466    use mithril_common::messages::TryFromMessageAdapter;
467    use mithril_common::test::{
468        double::{Dummy, DummyApiVersionDiscriminantSource, fake_data},
469        logging::MemoryDrainForTestInspector,
470    };
471
472    use crate::test_tools::TestLogger;
473
474    use super::*;
475
476    macro_rules! assert_is_error {
477        ($error:expr, $error_type:pat) => {
478            assert!(
479                matches!($error, $error_type),
480                "Expected {} error, got '{:?}'.",
481                stringify!($error_type),
482                $error
483            );
484        };
485    }
486
487    fn setup_client<U: Into<String>>(server_url: U) -> AggregatorHTTPClient {
488        let discriminant_source = DummyApiVersionDiscriminantSource::new("dummy");
489        let api_version_provider = APIVersionProvider::new(Arc::new(discriminant_source));
490
491        AggregatorHTTPClient::new(
492            server_url.into(),
493            None,
494            Arc::new(api_version_provider),
495            None,
496            TestLogger::stdout(),
497        )
498    }
499
500    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
501        let server = MockServer::start();
502        let aggregator_endpoint = server.url("");
503        let client = setup_client(&aggregator_endpoint);
504
505        (server, client)
506    }
507
508    fn set_returning_500(server: &MockServer) {
509        server.mock(|_, then| {
510            then.status(500).body("an error occurred");
511        });
512    }
513
514    fn set_unparsable_json(server: &MockServer) {
515        server.mock(|_, then| {
516            then.status(200).body("this is not a json");
517        });
518    }
519
520    fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
521        HttpResponseBuilder::new()
522            .status(status_code)
523            .body(body.into())
524            .unwrap()
525            .into()
526    }
527
528    fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
529        HttpResponseBuilder::new()
530            .status(status_code)
531            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
532            .body(serde_json::to_string(&body).unwrap())
533            .unwrap()
534            .into()
535    }
536
537    macro_rules! assert_error_text_contains {
538        ($error: expr, $expect_contains: expr) => {
539            let error = &$error;
540            assert!(
541                error.contains($expect_contains),
542                "Expected error message to contain '{}'\ngot '{error:?}'",
543                $expect_contains,
544            );
545        };
546    }
547
548    #[tokio::test]
549    async fn test_aggregator_features_ok_200() {
550        let (server, client) = setup_server_and_client();
551        let message_expected = AggregatorFeaturesMessage::dummy();
552        let _server_mock = server.mock(|when, then| {
553            when.path("/");
554            then.status(200).body(json!(message_expected).to_string());
555        });
556
557        let message = client.retrieve_aggregator_features().await.unwrap();
558
559        assert_eq!(message_expected, message);
560    }
561
562    #[tokio::test]
563    async fn test_aggregator_features_ko_500() {
564        let (server, client) = setup_server_and_client();
565        set_returning_500(&server);
566
567        let error = client.retrieve_aggregator_features().await.unwrap_err();
568
569        assert_is_error!(error, AggregatorClientError::RemoteServerTechnical(_));
570    }
571
572    #[tokio::test]
573    async fn test_aggregator_features_ko_json_serialization() {
574        let (server, client) = setup_server_and_client();
575        set_unparsable_json(&server);
576
577        let error = client.retrieve_aggregator_features().await.unwrap_err();
578
579        assert_is_error!(error, AggregatorClientError::JsonParseFailed(_));
580    }
581
582    #[tokio::test]
583    async fn test_aggregator_features_timeout() {
584        let (server, mut client) = setup_server_and_client();
585        client.timeout_duration = Some(Duration::from_millis(10));
586        let _server_mock = server.mock(|when, then| {
587            when.path("/");
588            then.delay(Duration::from_millis(100));
589        });
590
591        let error = client.retrieve_aggregator_features().await.unwrap_err();
592
593        assert_is_error!(error, AggregatorClientError::RemoteServerUnreachable(_));
594    }
595
596    #[tokio::test]
597    async fn test_epoch_settings_ok_200() {
598        let (server, client) = setup_server_and_client();
599        let epoch_settings_expected = EpochSettingsMessage::dummy();
600        let _server_mock = server.mock(|when, then| {
601            when.path("/epoch-settings");
602            then.status(200).body(json!(epoch_settings_expected).to_string());
603        });
604
605        let epoch_settings = client.retrieve_epoch_settings().await;
606        epoch_settings.as_ref().expect("unexpected error");
607        assert_eq!(
608            FromEpochSettingsAdapter::try_adapt(epoch_settings_expected).unwrap(),
609            epoch_settings.unwrap().unwrap()
610        );
611    }
612
613    #[tokio::test]
614    async fn test_epoch_settings_ko_500() {
615        let (server, client) = setup_server_and_client();
616        let _server_mock = server.mock(|when, then| {
617            when.path("/epoch-settings");
618            then.status(500).body("an error occurred");
619        });
620
621        match client.retrieve_epoch_settings().await.unwrap_err() {
622            AggregatorClientError::RemoteServerTechnical(_) => (),
623            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
624        };
625    }
626
627    #[tokio::test]
628    async fn test_epoch_settings_timeout() {
629        let (server, mut client) = setup_server_and_client();
630        client.timeout_duration = Some(Duration::from_millis(10));
631        let _server_mock = server.mock(|when, then| {
632            when.path("/epoch-settings");
633            then.delay(Duration::from_millis(100));
634        });
635
636        let error = client
637            .retrieve_epoch_settings()
638            .await
639            .expect_err("retrieve_epoch_settings should fail");
640
641        assert!(
642            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
643            "unexpected error type: {error:?}"
644        );
645    }
646
647    #[tokio::test]
648    async fn test_register_signer_ok_201() {
649        let epoch = Epoch(1);
650        let single_signers = fake_data::signers(1);
651        let single_signer = single_signers.first().unwrap();
652        let (server, client) = setup_server_and_client();
653        let _server_mock = server.mock(|when, then| {
654            when.method(POST).path("/register-signer");
655            then.status(201);
656        });
657
658        let register_signer = client.register_signer(epoch, single_signer).await;
659        register_signer.expect("unexpected error");
660    }
661
662    #[tokio::test]
663    async fn test_register_signer_ko_400() {
664        let epoch = Epoch(1);
665        let single_signers = fake_data::signers(1);
666        let single_signer = single_signers.first().unwrap();
667        let (server, client) = setup_server_and_client();
668        let _server_mock = server.mock(|when, then| {
669            when.method(POST).path("/register-signer");
670            then.status(400).body(
671                serde_json::to_vec(&ClientError::new(
672                    "error".to_string(),
673                    "an error".to_string(),
674                ))
675                .unwrap(),
676            );
677        });
678
679        match client.register_signer(epoch, single_signer).await.unwrap_err() {
680            AggregatorClientError::RemoteServerLogical(_) => (),
681            err => {
682                panic!(
683                    "Expected a AggregatorClientError::RemoteServerLogical error, got '{err:?}'."
684                )
685            }
686        };
687    }
688
689    #[tokio::test]
690    async fn test_register_signer_ko_500() {
691        let epoch = Epoch(1);
692        let single_signers = fake_data::signers(1);
693        let single_signer = single_signers.first().unwrap();
694        let (server, client) = setup_server_and_client();
695        let _server_mock = server.mock(|when, then| {
696            when.method(POST).path("/register-signer");
697            then.status(500).body("an error occurred");
698        });
699
700        match client.register_signer(epoch, single_signer).await.unwrap_err() {
701            AggregatorClientError::RemoteServerTechnical(_) => (),
702            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
703        };
704    }
705
706    #[tokio::test]
707    async fn test_register_signer_timeout() {
708        let epoch = Epoch(1);
709        let single_signers = fake_data::signers(1);
710        let single_signer = single_signers.first().unwrap();
711        let (server, mut client) = setup_server_and_client();
712        client.timeout_duration = Some(Duration::from_millis(10));
713        let _server_mock = server.mock(|when, then| {
714            when.method(POST).path("/register-signer");
715            then.delay(Duration::from_millis(100));
716        });
717
718        let error = client
719            .register_signer(epoch, single_signer)
720            .await
721            .expect_err("register_signer should fail");
722
723        assert!(
724            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
725            "unexpected error type: {error:?}"
726        );
727    }
728
729    #[tokio::test]
730    async fn test_register_signature_ok_201() {
731        let single_signature = fake_data::single_signature((1..5).collect());
732        let (server, client) = setup_server_and_client();
733        let _server_mock = server.mock(|when, then| {
734            when.method(POST).path("/register-signatures");
735            then.status(201);
736        });
737
738        let register_signature = client
739            .register_signature(
740                &SignedEntityType::dummy(),
741                &single_signature,
742                &ProtocolMessage::default(),
743            )
744            .await;
745        register_signature.expect("unexpected error");
746    }
747
748    #[tokio::test]
749    async fn test_register_signature_ok_202() {
750        let single_signature = fake_data::single_signature((1..5).collect());
751        let (server, client) = setup_server_and_client();
752        let _server_mock = server.mock(|when, then| {
753            when.method(POST).path("/register-signatures");
754            then.status(202);
755        });
756
757        let register_signature = client
758            .register_signature(
759                &SignedEntityType::dummy(),
760                &single_signature,
761                &ProtocolMessage::default(),
762            )
763            .await;
764        register_signature.expect("unexpected error");
765    }
766
767    #[tokio::test]
768    async fn test_register_signature_ko_400() {
769        let single_signature = fake_data::single_signature((1..5).collect());
770        let (server, client) = setup_server_and_client();
771        let _server_mock = server.mock(|when, then| {
772            when.method(POST).path("/register-signatures");
773            then.status(400).body(
774                serde_json::to_vec(&ClientError::new(
775                    "error".to_string(),
776                    "an error".to_string(),
777                ))
778                .unwrap(),
779            );
780        });
781
782        match client
783            .register_signature(
784                &SignedEntityType::dummy(),
785                &single_signature,
786                &ProtocolMessage::default(),
787            )
788            .await
789            .unwrap_err()
790        {
791            AggregatorClientError::RemoteServerLogical(_) => (),
792            e => panic!("Expected Aggregator::RemoteServerLogical error, got '{e:?}'."),
793        };
794    }
795
796    #[tokio::test]
797    async fn test_register_signature_ok_410_log_response_body() {
798        let (logger, log_inspector) = TestLogger::memory();
799
800        let single_signature = fake_data::single_signature((1..5).collect());
801        let (server, mut client) = setup_server_and_client();
802        client.logger = logger;
803        let _server_mock = server.mock(|when, then| {
804            when.method(POST).path("/register-signatures");
805            then.status(410).body(
806                serde_json::to_vec(&ClientError::new(
807                    "already_aggregated".to_string(),
808                    "too late".to_string(),
809                ))
810                .unwrap(),
811            );
812        });
813
814        client
815            .register_signature(
816                &SignedEntityType::dummy(),
817                &single_signature,
818                &ProtocolMessage::default(),
819            )
820            .await
821            .expect("Should not fail when status is 410 (GONE)");
822
823        assert!(log_inspector.contains_log("already_aggregated"));
824        assert!(log_inspector.contains_log("too late"));
825    }
826
827    #[tokio::test]
828    async fn test_register_signature_ko_409() {
829        let single_signature = fake_data::single_signature((1..5).collect());
830        let (server, client) = setup_server_and_client();
831        let _server_mock = server.mock(|when, then| {
832            when.method(POST).path("/register-signatures");
833            then.status(409);
834        });
835
836        match client
837            .register_signature(
838                &SignedEntityType::dummy(),
839                &single_signature,
840                &ProtocolMessage::default(),
841            )
842            .await
843            .unwrap_err()
844        {
845            AggregatorClientError::RemoteServerLogical(_) => (),
846            e => panic!("Expected Aggregator::RemoteServerLogical error, got '{e:?}'."),
847        }
848    }
849
850    #[tokio::test]
851    async fn test_register_signature_ko_500() {
852        let single_signature = fake_data::single_signature((1..5).collect());
853        let (server, client) = setup_server_and_client();
854        let _server_mock = server.mock(|when, then| {
855            when.method(POST).path("/register-signatures");
856            then.status(500).body("an error occurred");
857        });
858
859        match client
860            .register_signature(
861                &SignedEntityType::dummy(),
862                &single_signature,
863                &ProtocolMessage::default(),
864            )
865            .await
866            .unwrap_err()
867        {
868            AggregatorClientError::RemoteServerTechnical(_) => (),
869            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
870        };
871    }
872
873    #[tokio::test]
874    async fn test_register_signature_timeout() {
875        let single_signature = fake_data::single_signature((1..5).collect());
876        let (server, mut client) = setup_server_and_client();
877        client.timeout_duration = Some(Duration::from_millis(10));
878        let _server_mock = server.mock(|when, then| {
879            when.method(POST).path("/register-signatures");
880            then.delay(Duration::from_millis(100));
881        });
882
883        let error = client
884            .register_signature(
885                &SignedEntityType::dummy(),
886                &single_signature,
887                &ProtocolMessage::default(),
888            )
889            .await
890            .expect_err("register_signature should fail");
891
892        assert!(
893            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
894            "unexpected error type: {error:?}"
895        );
896    }
897
898    #[tokio::test]
899    async fn test_4xx_errors_are_handled_as_remote_server_logical() {
900        let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
901        let handled_error = AggregatorClientError::from_response(response).await;
902
903        assert!(
904            matches!(
905                handled_error,
906                AggregatorClientError::RemoteServerLogical(..)
907            ),
908            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
909        );
910    }
911
912    #[tokio::test]
913    async fn test_5xx_errors_are_handled_as_remote_server_technical() {
914        let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
915        let handled_error = AggregatorClientError::from_response(response).await;
916
917        assert!(
918            matches!(
919                handled_error,
920                AggregatorClientError::RemoteServerTechnical(..)
921            ),
922            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
923        );
924    }
925
926    #[tokio::test]
927    async fn test_550_error_is_handled_as_registration_round_not_yet_opened() {
928        let response = build_text_response(StatusCode::from_u16(550).unwrap(), "Not yet available");
929        let handled_error = AggregatorClientError::from_response(response).await;
930
931        assert!(
932            matches!(
933                handled_error,
934                AggregatorClientError::RegistrationRoundNotYetOpened(..)
935            ),
936            "Expected error to be RegistrationRoundNotYetOpened\ngot '{handled_error:?}'",
937        );
938    }
939
940    #[tokio::test]
941    async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text()
942     {
943        let response = build_text_response(StatusCode::OK, "ok text");
944        let handled_error = AggregatorClientError::from_response(response).await;
945
946        assert!(
947            matches!(
948                handled_error,
949                AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
950            ),
951            "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
952        );
953    }
954
955    #[tokio::test]
956    async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
957        let error_text = "An error occurred; please try again later.";
958        let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);
959
960        assert_error_text_contains!(
961            AggregatorClientError::get_root_cause(response).await,
962            "expectation failed: An error occurred; please try again later."
963        );
964    }
965
966    #[tokio::test]
967    async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message()
968     {
969        let client_error = ClientError::new("label", "message");
970        let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);
971
972        assert_error_text_contains!(
973            AggregatorClientError::get_root_cause(response).await,
974            "bad request: label: message"
975        );
976    }
977
978    #[tokio::test]
979    async fn test_root_cause_of_json_formatted_server_error_response_contains_error_label_and_message()
980     {
981        let server_error = ServerError::new("message");
982        let response = build_json_response(StatusCode::BAD_REQUEST, &server_error);
983
984        assert_error_text_contains!(
985            AggregatorClientError::get_root_cause(response).await,
986            "bad request: message"
987        );
988    }
989
990    #[tokio::test]
991    async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
992        let response = build_json_response(
993            StatusCode::INTERNAL_SERVER_ERROR,
994            &json!({ "second": "unknown", "first": "foreign" }),
995        );
996
997        assert_error_text_contains!(
998            AggregatorClientError::get_root_cause(response).await,
999            r#"internal server error: {"first":"foreign","second":"unknown"}"#
1000        );
1001    }
1002
1003    #[tokio::test]
1004    async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
1005        let response = HttpResponseBuilder::new()
1006            .status(StatusCode::BAD_REQUEST)
1007            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
1008            .body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
1009            .unwrap()
1010            .into();
1011
1012        let root_cause = AggregatorClientError::get_root_cause(response).await;
1013
1014        assert_error_text_contains!(root_cause, "bad request");
1015        assert!(
1016            !root_cause.contains("bad request: "),
1017            "Expected error message should not contain additional information \ngot '{root_cause:?}'"
1018        );
1019    }
1020
1021    #[tokio::test]
1022    async fn test_sends_accept_encoding_header() {
1023        let (server, client) = setup_server_and_client();
1024        server.mock(|when, then| {
1025            when.is_true(|req| {
1026                let headers = req.headers();
1027                let accept_encoding_header = headers
1028                    .get("accept-encoding")
1029                    .expect("Accept-Encoding header not found");
1030
1031                ["gzip", "br", "deflate", "zstd"].iter().all(|&encoding| {
1032                    accept_encoding_header.to_str().is_ok_and(|h| h.contains(encoding))
1033                })
1034            });
1035
1036            then.status(201);
1037        });
1038
1039        client
1040            .register_signature(
1041                &SignedEntityType::dummy(),
1042                &fake_data::single_signature((1..5).collect()),
1043                &ProtocolMessage::default(),
1044            )
1045            .await
1046            .expect("Should succeed with Accept-Encoding header");
1047    }
1048
1049    mod warn_if_api_version_mismatch {
1050        use mithril_common::test::api_version_extensions::ApiVersionProviderTestExtension;
1051
1052        use super::*;
1053
1054        fn version_provider_with_open_api_version<V: Into<String>>(
1055            version: V,
1056        ) -> APIVersionProvider {
1057            let mut version_provider = version_provider_without_open_api_version();
1058            let mut open_api_versions = HashMap::new();
1059            open_api_versions.insert(
1060                "openapi.yaml".to_string(),
1061                Version::parse(&version.into()).unwrap(),
1062            );
1063            version_provider.update_open_api_versions(open_api_versions);
1064
1065            version_provider
1066        }
1067
1068        fn version_provider_without_open_api_version() -> APIVersionProvider {
1069            let mut version_provider =
1070                APIVersionProvider::new(Arc::new(DummyApiVersionDiscriminantSource::new("dummy")));
1071            version_provider.update_open_api_versions(HashMap::new());
1072
1073            version_provider
1074        }
1075
1076        fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
1077            key: K,
1078            value: V,
1079        ) -> Response {
1080            HttpResponseBuilder::new()
1081                .header(key.into(), value.into())
1082                .body("whatever")
1083                .unwrap()
1084                .into()
1085        }
1086
1087        fn assert_api_version_warning_logged<A: Into<String>, S: Into<String>>(
1088            log_inspector: &MemoryDrainForTestInspector,
1089            aggregator_version: A,
1090            signer_version: S,
1091        ) {
1092            assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1093            assert!(
1094                log_inspector
1095                    .contains_log(&format!("aggregator_version={}", aggregator_version.into()))
1096            );
1097            assert!(
1098                log_inspector.contains_log(&format!("signer_version={}", signer_version.into()))
1099            );
1100        }
1101
1102        #[test]
1103        fn test_logs_warning_when_aggregator_api_version_is_newer() {
1104            let aggregator_version = "2.0.0";
1105            let signer_version = "1.0.0";
1106            let (logger, log_inspector) = TestLogger::memory();
1107            let version_provider = version_provider_with_open_api_version(signer_version);
1108            let mut client = setup_client("whatever");
1109            client.api_version_provider = Arc::new(version_provider);
1110            client.logger = logger;
1111            let response =
1112                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1113
1114            assert!(
1115                Version::parse(aggregator_version).unwrap()
1116                    > Version::parse(signer_version).unwrap()
1117            );
1118
1119            client.warn_if_api_version_mismatch(&response);
1120
1121            assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1122        }
1123
1124        #[test]
1125        fn test_no_warning_logged_when_versions_match() {
1126            let version = "1.0.0";
1127            let (logger, log_inspector) = TestLogger::memory();
1128            let version_provider = version_provider_with_open_api_version(version);
1129            let mut client = setup_client("whatever");
1130            client.api_version_provider = Arc::new(version_provider);
1131            client.logger = logger;
1132            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);
1133
1134            client.warn_if_api_version_mismatch(&response);
1135
1136            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1137        }
1138
1139        #[test]
1140        fn test_no_warning_logged_when_aggregator_api_version_is_older() {
1141            let aggregator_version = "1.0.0";
1142            let signer_version = "2.0.0";
1143            let (logger, log_inspector) = TestLogger::memory();
1144            let version_provider = version_provider_with_open_api_version(signer_version);
1145            let mut client = setup_client("whatever");
1146            client.api_version_provider = Arc::new(version_provider);
1147            client.logger = logger;
1148            let response =
1149                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1150
1151            assert!(
1152                Version::parse(aggregator_version).unwrap()
1153                    < Version::parse(signer_version).unwrap()
1154            );
1155
1156            client.warn_if_api_version_mismatch(&response);
1157
1158            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1159        }
1160
1161        #[test]
1162        fn test_does_not_log_or_fail_when_header_is_missing() {
1163            let (logger, log_inspector) = TestLogger::memory();
1164            let mut client = setup_client("whatever");
1165            client.logger = logger;
1166            let response =
1167                build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");
1168
1169            client.warn_if_api_version_mismatch(&response);
1170
1171            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1172        }
1173
1174        #[test]
1175        fn test_does_not_log_or_fail_when_header_is_not_a_version() {
1176            let (logger, log_inspector) = TestLogger::memory();
1177            let mut client = setup_client("whatever");
1178            client.logger = logger;
1179            let response =
1180                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");
1181
1182            client.warn_if_api_version_mismatch(&response);
1183
1184            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1185        }
1186
1187        #[test]
1188        fn test_logs_error_when_signer_version_cannot_be_computed() {
1189            let (logger, log_inspector) = TestLogger::memory();
1190            let version_provider = version_provider_without_open_api_version();
1191            let mut client = setup_client("whatever");
1192            client.api_version_provider = Arc::new(version_provider);
1193            client.logger = logger;
1194            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");
1195
1196            client.warn_if_api_version_mismatch(&response);
1197
1198            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1199        }
1200
1201        #[tokio::test]
1202        async fn test_aggregator_features_ok_200_log_warning_if_api_version_mismatch() {
1203            let aggregator_version = "2.0.0";
1204            let signer_version = "1.0.0";
1205            let (server, mut client) = setup_server_and_client();
1206            let (logger, log_inspector) = TestLogger::memory();
1207            let version_provider = version_provider_with_open_api_version(signer_version);
1208            client.api_version_provider = Arc::new(version_provider);
1209            client.logger = logger;
1210
1211            let message_expected = AggregatorFeaturesMessage::dummy();
1212            let _server_mock = server.mock(|when, then| {
1213                when.path("/");
1214                then.status(200)
1215                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version)
1216                    .body(json!(message_expected).to_string());
1217            });
1218
1219            assert!(
1220                Version::parse(aggregator_version).unwrap()
1221                    > Version::parse(signer_version).unwrap()
1222            );
1223
1224            client.retrieve_aggregator_features().await.unwrap();
1225
1226            assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1227        }
1228
1229        #[tokio::test]
1230        async fn test_epoch_settings_ok_200_log_warning_if_api_version_mismatch() {
1231            let aggregator_version = "2.0.0";
1232            let signer_version = "1.0.0";
1233            let (server, mut client) = setup_server_and_client();
1234            let (logger, log_inspector) = TestLogger::memory();
1235            let version_provider = version_provider_with_open_api_version(signer_version);
1236            client.api_version_provider = Arc::new(version_provider);
1237            client.logger = logger;
1238
1239            let epoch_settings_expected = EpochSettingsMessage::dummy();
1240            let _server_mock = server.mock(|when, then| {
1241                when.path("/epoch-settings");
1242                then.status(200)
1243                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version)
1244                    .body(json!(epoch_settings_expected).to_string());
1245            });
1246
1247            assert!(
1248                Version::parse(aggregator_version).unwrap()
1249                    > Version::parse(signer_version).unwrap()
1250            );
1251
1252            client.retrieve_epoch_settings().await.unwrap();
1253
1254            assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1255        }
1256
1257        #[tokio::test]
1258        async fn test_register_signer_ok_201_log_warning_if_api_version_mismatch() {
1259            let aggregator_version = "2.0.0";
1260            let signer_version = "1.0.0";
1261            let epoch = Epoch(1);
1262            let single_signers = fake_data::signers(1);
1263            let single_signer = single_signers.first().unwrap();
1264            let (server, mut client) = setup_server_and_client();
1265            let (logger, log_inspector) = TestLogger::memory();
1266            let version_provider = version_provider_with_open_api_version(signer_version);
1267            client.api_version_provider = Arc::new(version_provider);
1268            client.logger = logger;
1269            let _server_mock = server.mock(|when, then| {
1270                when.method(POST).path("/register-signer");
1271                then.status(201)
1272                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1273            });
1274
1275            assert!(
1276                Version::parse(aggregator_version).unwrap()
1277                    > Version::parse(signer_version).unwrap()
1278            );
1279
1280            client.register_signer(epoch, single_signer).await.unwrap();
1281
1282            assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1283        }
1284
1285        #[tokio::test]
1286        async fn test_register_signature_ok_201_log_warning_if_api_version_mismatch() {
1287            let aggregator_version = "2.0.0";
1288            let signer_version = "1.0.0";
1289            let single_signature = fake_data::single_signature((1..5).collect());
1290            let (server, mut client) = setup_server_and_client();
1291            let (logger, log_inspector) = TestLogger::memory();
1292            let version_provider = version_provider_with_open_api_version(signer_version);
1293            client.api_version_provider = Arc::new(version_provider);
1294            client.logger = logger;
1295            let _server_mock = server.mock(|when, then| {
1296                when.method(POST).path("/register-signatures");
1297                then.status(201)
1298                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1299            });
1300
1301            assert!(
1302                Version::parse(aggregator_version).unwrap()
1303                    > Version::parse(signer_version).unwrap()
1304            );
1305
1306            client
1307                .register_signature(
1308                    &SignedEntityType::dummy(),
1309                    &single_signature,
1310                    &ProtocolMessage::default(),
1311                )
1312                .await
1313                .expect("Should not fail");
1314
1315            assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1316        }
1317
1318        #[tokio::test]
1319        async fn test_register_signature_ok_202_log_warning_if_api_version_mismatch() {
1320            let aggregator_version = "2.0.0";
1321            let signer_version = "1.0.0";
1322            let single_signature = fake_data::single_signature((1..5).collect());
1323            let (server, mut client) = setup_server_and_client();
1324            let (logger, log_inspector) = TestLogger::memory();
1325            let version_provider = version_provider_with_open_api_version(signer_version);
1326            client.api_version_provider = Arc::new(version_provider);
1327            client.logger = logger;
1328            let _server_mock = server.mock(|when, then| {
1329                when.method(POST).path("/register-signatures");
1330                then.status(202)
1331                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1332            });
1333
1334            assert!(
1335                Version::parse(aggregator_version).unwrap()
1336                    > Version::parse(signer_version).unwrap()
1337            );
1338
1339            client
1340                .register_signature(
1341                    &SignedEntityType::dummy(),
1342                    &single_signature,
1343                    &ProtocolMessage::default(),
1344                )
1345                .await
1346                .unwrap();
1347
1348            assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1349        }
1350
1351        #[tokio::test]
1352        async fn test_register_signature_ok_410_log_warning_if_api_version_mismatch() {
1353            let aggregator_version = "2.0.0";
1354            let signer_version = "1.0.0";
1355            let single_signature = fake_data::single_signature((1..5).collect());
1356            let (server, mut client) = setup_server_and_client();
1357            let (logger, log_inspector) = TestLogger::memory();
1358            let version_provider = version_provider_with_open_api_version(signer_version);
1359            client.api_version_provider = Arc::new(version_provider);
1360            client.logger = logger;
1361            let _server_mock = server.mock(|when, then| {
1362                when.method(POST).path("/register-signatures");
1363                then.status(410)
1364                    .body(
1365                        serde_json::to_vec(&ClientError::new(
1366                            "already_aggregated".to_string(),
1367                            "too late".to_string(),
1368                        ))
1369                        .unwrap(),
1370                    )
1371                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1372            });
1373
1374            assert!(
1375                Version::parse(aggregator_version).unwrap()
1376                    > Version::parse(signer_version).unwrap()
1377            );
1378
1379            client
1380                .register_signature(
1381                    &SignedEntityType::dummy(),
1382                    &single_signature,
1383                    &ProtocolMessage::default(),
1384                )
1385                .await
1386                .unwrap();
1387
1388            assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1389        }
1390    }
1391}