mithril_signer/services/aggregator_client.rs

1use anyhow::anyhow;
2use async_trait::async_trait;
3use reqwest::header::{self, HeaderValue};
4use reqwest::{self, Client, Proxy, RequestBuilder, Response, StatusCode};
5use semver::Version;
6use slog::{Logger, debug, error, warn};
7use std::{io, sync::Arc, time::Duration};
8use thiserror::Error;
9
10use mithril_common::{
11    MITHRIL_API_VERSION_HEADER, MITHRIL_SIGNER_VERSION_HEADER, StdError,
12    api_version::APIVersionProvider,
13    entities::{
14        ClientError, Epoch, ProtocolMessage, ServerError, SignedEntityType, Signer, SingleSignature,
15    },
16    logging::LoggerExtensions,
17    messages::{
18        AggregatorFeaturesMessage, EpochSettingsMessage, TryFromMessageAdapter, TryToMessageAdapter,
19    },
20};
21
22use crate::entities::SignerEpochSettings;
23use crate::message_adapters::{
24    FromEpochSettingsAdapter, ToRegisterSignatureMessageAdapter, ToRegisterSignerMessageAdapter,
25};
26
27const JSON_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/json");
28
/// Warning logged when the aggregator advertises a more recent API version than this signer.
const API_VERSION_MISMATCH_WARNING_MESSAGE: &'static str = concat!(
    "OpenAPI version may be incompatible, ",
    "please update your Mithril node to the latest version."
);
31
32/// Error structure for the Aggregator Client.
33#[derive(Error, Debug)]
34pub enum AggregatorClientError {
35    /// The aggregator host has returned a technical error.
36    #[error("remote server technical error")]
37    RemoteServerTechnical(#[source] StdError),
38
39    /// The aggregator host responded it cannot fulfill our request.
40    #[error("remote server logical error")]
41    RemoteServerLogical(#[source] StdError),
42
43    /// Could not reach aggregator.
44    #[error("remote server unreachable")]
45    RemoteServerUnreachable(#[source] StdError),
46
47    /// Unhandled status code
48    #[error("unhandled status code: {0}, response text: {1}")]
49    UnhandledStatusCode(StatusCode, String),
50
51    /// Could not parse response.
52    #[error("json parsing failed")]
53    JsonParseFailed(#[source] StdError),
54
55    /// Mostly network errors.
56    #[error("Input/Output error")]
57    IOError(#[from] io::Error),
58
59    /// HTTP client creation error
60    #[error("HTTP client creation failed")]
61    HTTPClientCreation(#[source] StdError),
62
63    /// Proxy creation error
64    #[error("proxy creation failed")]
65    ProxyCreation(#[source] StdError),
66
67    /// Adapter error
68    #[error("adapter failed")]
69    Adapter(#[source] StdError),
70
71    /// No signer registration round opened yet
72    #[error("a signer registration round is not opened yet, please try again later")]
73    RegistrationRoundNotYetOpened(#[source] StdError),
74}
75
76impl AggregatorClientError {
77    /// Create an `AggregatorClientError` from a response.
78    ///
79    /// This method is meant to be used after handling domain-specific cases leaving only
80    /// 4xx or 5xx status codes.
81    /// Otherwise, it will return an `UnhandledStatusCode` error.
82    pub async fn from_response(response: Response) -> Self {
83        let error_code = response.status();
84
85        if error_code.is_client_error() {
86            let root_cause = Self::get_root_cause(response).await;
87            Self::RemoteServerLogical(anyhow!(root_cause))
88        } else if error_code.is_server_error() {
89            let root_cause = Self::get_root_cause(response).await;
90            match error_code.as_u16() {
91                550 => Self::RegistrationRoundNotYetOpened(anyhow!(root_cause)),
92                _ => Self::RemoteServerTechnical(anyhow!(root_cause)),
93            }
94        } else {
95            let response_text = response.text().await.unwrap_or_default();
96            Self::UnhandledStatusCode(error_code, response_text)
97        }
98    }
99
100    async fn get_root_cause(response: Response) -> String {
101        let error_code = response.status();
102        let canonical_reason = error_code.canonical_reason().unwrap_or_default().to_lowercase();
103        let is_json = response
104            .headers()
105            .get(header::CONTENT_TYPE)
106            .is_some_and(|ct| JSON_CONTENT_TYPE == ct);
107
108        if is_json {
109            let json_value: serde_json::Value = response.json().await.unwrap_or_default();
110
111            if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
112                format!(
113                    "{}: {}: {}",
114                    canonical_reason, client_error.label, client_error.message
115                )
116            } else if let Ok(server_error) =
117                serde_json::from_value::<ServerError>(json_value.clone())
118            {
119                format!("{}: {}", canonical_reason, server_error.message)
120            } else if json_value.is_null() {
121                canonical_reason.to_string()
122            } else {
123                format!("{canonical_reason}: {json_value}")
124            }
125        } else {
126            let response_text = response.text().await.unwrap_or_default();
127            format!("{canonical_reason}: {response_text}")
128        }
129    }
130}
131
132/// Trait for mocking and testing a `AggregatorClient`
133#[cfg_attr(test, mockall::automock)]
134#[async_trait]
135pub trait AggregatorClient: Sync + Send {
136    /// Retrieves epoch settings from the aggregator
137    async fn retrieve_epoch_settings(
138        &self,
139    ) -> Result<Option<SignerEpochSettings>, AggregatorClientError>;
140
141    /// Registers signer with the aggregator.
142    async fn register_signer(
143        &self,
144        epoch: Epoch,
145        signer: &Signer,
146    ) -> Result<(), AggregatorClientError>;
147
148    /// Registers single signature with the aggregator.
149    async fn register_signature(
150        &self,
151        signed_entity_type: &SignedEntityType,
152        signature: &SingleSignature,
153        protocol_message: &ProtocolMessage,
154    ) -> Result<(), AggregatorClientError>;
155
156    /// Retrieves aggregator features message from the aggregator
157    async fn retrieve_aggregator_features(
158        &self,
159    ) -> Result<AggregatorFeaturesMessage, AggregatorClientError>;
160}
161
162/// AggregatorHTTPClient is a http client for an aggregator
163pub struct AggregatorHTTPClient {
164    aggregator_endpoint: String,
165    relay_endpoint: Option<String>,
166    api_version_provider: Arc<APIVersionProvider>,
167    timeout_duration: Option<Duration>,
168    logger: Logger,
169}
170
171impl AggregatorHTTPClient {
172    /// AggregatorHTTPClient factory
173    pub fn new(
174        aggregator_endpoint: String,
175        relay_endpoint: Option<String>,
176        api_version_provider: Arc<APIVersionProvider>,
177        timeout_duration: Option<Duration>,
178        logger: Logger,
179    ) -> Self {
180        let logger = logger.new_with_component_name::<Self>();
181        debug!(logger, "New AggregatorHTTPClient created");
182        Self {
183            aggregator_endpoint,
184            relay_endpoint,
185            api_version_provider,
186            timeout_duration,
187            logger,
188        }
189    }
190
191    fn prepare_http_client(&self) -> Result<Client, AggregatorClientError> {
192        let client = match &self.relay_endpoint {
193            Some(relay_endpoint) => Client::builder()
194                .proxy(
195                    Proxy::all(relay_endpoint)
196                        .map_err(|e| AggregatorClientError::ProxyCreation(anyhow!(e)))?,
197                )
198                .build()
199                .map_err(|e| AggregatorClientError::HTTPClientCreation(anyhow!(e)))?,
200            None => Client::new(),
201        };
202
203        Ok(client)
204    }
205
206    /// Forge a client request adding protocol version in the headers.
207    pub fn prepare_request_builder(&self, request_builder: RequestBuilder) -> RequestBuilder {
208        let request_builder = request_builder
209            .header(
210                MITHRIL_API_VERSION_HEADER,
211                self.api_version_provider
212                    .compute_current_version()
213                    .unwrap()
214                    .to_string(),
215            )
216            .header(MITHRIL_SIGNER_VERSION_HEADER, env!("CARGO_PKG_VERSION"));
217
218        if let Some(duration) = self.timeout_duration {
219            request_builder.timeout(duration)
220        } else {
221            request_builder
222        }
223    }
224
225    /// Check API version mismatch and log a warning if the aggregator's version is more recent.
226    fn warn_if_api_version_mismatch(&self, response: &Response) {
227        let aggregator_version = response
228            .headers()
229            .get(MITHRIL_API_VERSION_HEADER)
230            .and_then(|v| v.to_str().ok())
231            .and_then(|s| Version::parse(s).ok());
232
233        let signer_version = self.api_version_provider.compute_current_version();
234
235        match (aggregator_version, signer_version) {
236            (Some(aggregator), Ok(signer)) if signer < aggregator => {
237                warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE;
238                    "aggregator_version" => %aggregator,
239                    "signer_version" => %signer,
240                );
241            }
242            (Some(_), Err(error)) => {
243                error!(
244                    self.logger,
245                    "Failed to compute the current signer API version";
246                    "error" => error.to_string()
247                );
248            }
249            _ => {}
250        }
251    }
252}
253
254#[async_trait]
255impl AggregatorClient for AggregatorHTTPClient {
256    async fn retrieve_epoch_settings(
257        &self,
258    ) -> Result<Option<SignerEpochSettings>, AggregatorClientError> {
259        debug!(self.logger, "Retrieve epoch settings");
260        let url = format!("{}/epoch-settings", self.aggregator_endpoint);
261        let response = self
262            .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
263            .send()
264            .await;
265
266        match response {
267            Ok(response) => match response.status() {
268                StatusCode::OK => {
269                    self.warn_if_api_version_mismatch(&response);
270                    match response.json::<EpochSettingsMessage>().await {
271                        Ok(message) => {
272                            let epoch_settings = FromEpochSettingsAdapter::try_adapt(message)
273                                .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
274                            Ok(Some(epoch_settings))
275                        }
276                        Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
277                    }
278                }
279                _ => Err(AggregatorClientError::from_response(response).await),
280            },
281            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
282        }
283    }
284
285    async fn register_signer(
286        &self,
287        epoch: Epoch,
288        signer: &Signer,
289    ) -> Result<(), AggregatorClientError> {
290        debug!(self.logger, "Register signer");
291        let url = format!("{}/register-signer", self.aggregator_endpoint);
292        let register_signer_message =
293            ToRegisterSignerMessageAdapter::try_adapt((epoch, signer.to_owned()))
294                .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
295        let response = self
296            .prepare_request_builder(self.prepare_http_client()?.post(url.clone()))
297            .json(&register_signer_message)
298            .send()
299            .await;
300
301        match response {
302            Ok(response) => match response.status() {
303                StatusCode::CREATED => {
304                    self.warn_if_api_version_mismatch(&response);
305
306                    Ok(())
307                }
308                _ => Err(AggregatorClientError::from_response(response).await),
309            },
310            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
311        }
312    }
313
314    async fn register_signature(
315        &self,
316        signed_entity_type: &SignedEntityType,
317        signature: &SingleSignature,
318        protocol_message: &ProtocolMessage,
319    ) -> Result<(), AggregatorClientError> {
320        debug!(self.logger, "Register signature");
321        let url = format!("{}/register-signatures", self.aggregator_endpoint);
322        let register_single_signature_message = ToRegisterSignatureMessageAdapter::try_adapt((
323            signed_entity_type.to_owned(),
324            signature.to_owned(),
325            protocol_message,
326        ))
327        .map_err(|e| AggregatorClientError::Adapter(anyhow!(e)))?;
328        let response = self
329            .prepare_request_builder(self.prepare_http_client()?.post(url.clone()))
330            .json(&register_single_signature_message)
331            .send()
332            .await;
333
334        match response {
335            Ok(response) => match response.status() {
336                StatusCode::CREATED | StatusCode::ACCEPTED => {
337                    self.warn_if_api_version_mismatch(&response);
338
339                    Ok(())
340                }
341                StatusCode::GONE => {
342                    self.warn_if_api_version_mismatch(&response);
343                    let root_cause = AggregatorClientError::get_root_cause(response).await;
344                    debug!(self.logger, "Message already certified or expired"; "details" => &root_cause);
345
346                    Ok(())
347                }
348                _ => Err(AggregatorClientError::from_response(response).await),
349            },
350            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
351        }
352    }
353
354    async fn retrieve_aggregator_features(
355        &self,
356    ) -> Result<AggregatorFeaturesMessage, AggregatorClientError> {
357        debug!(self.logger, "Retrieve aggregator features message");
358        let url = format!("{}/", self.aggregator_endpoint);
359        let response = self
360            .prepare_request_builder(self.prepare_http_client()?.get(url.clone()))
361            .send()
362            .await;
363
364        match response {
365            Ok(response) => match response.status() {
366                StatusCode::OK => {
367                    self.warn_if_api_version_mismatch(&response);
368
369                    Ok(response
370                        .json::<AggregatorFeaturesMessage>()
371                        .await
372                        .map_err(|e| AggregatorClientError::JsonParseFailed(anyhow!(e)))?)
373                }
374                _ => Err(AggregatorClientError::from_response(response).await),
375            },
376            Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
377        }
378    }
379}
380
#[cfg(test)]
pub(crate) mod dumb {
    use mithril_common::test::double::Dummy;
    use tokio::sync::RwLock;

    use super::*;

    /// In-memory [`AggregatorClient`] for test services.
    ///
    /// It never contacts an aggregator host. The data it returns is driven by the tester, and
    /// its internal state (such as the last registered signer) can be inspected afterward.
    pub struct DumbAggregatorClient {
        epoch_settings: RwLock<Option<SignerEpochSettings>>,
        last_registered_signer: RwLock<Option<Signer>>,
        aggregator_features: RwLock<AggregatorFeaturesMessage>,
    }

    impl DumbAggregatorClient {
        /// Get the signer given to the most recent `register_signer` call, if any.
        pub async fn get_last_registered_signer(&self) -> Option<Signer> {
            let last_registered_signer = self.last_registered_signer.read().await;
            last_registered_signer.clone()
        }

        /// Replace the message returned by `retrieve_aggregator_features`.
        pub async fn set_aggregator_features(
            &self,
            aggregator_features: AggregatorFeaturesMessage,
        ) {
            *self.aggregator_features.write().await = aggregator_features;
        }
    }

    impl Default for DumbAggregatorClient {
        fn default() -> Self {
            Self {
                epoch_settings: RwLock::new(Some(SignerEpochSettings::dummy())),
                last_registered_signer: RwLock::new(None),
                aggregator_features: RwLock::new(AggregatorFeaturesMessage::dummy()),
            }
        }
    }

    #[async_trait]
    impl AggregatorClient for DumbAggregatorClient {
        async fn retrieve_epoch_settings(
            &self,
        ) -> Result<Option<SignerEpochSettings>, AggregatorClientError> {
            Ok(self.epoch_settings.read().await.clone())
        }

        /// Remember the signer so that `get_last_registered_signer` can return it.
        async fn register_signer(
            &self,
            _epoch: Epoch,
            signer: &Signer,
        ) -> Result<(), AggregatorClientError> {
            *self.last_registered_signer.write().await = Some(signer.clone());

            Ok(())
        }

        /// Accept any signature without recording it.
        async fn register_signature(
            &self,
            _signed_entity_type: &SignedEntityType,
            _signature: &SingleSignature,
            _protocol_message: &ProtocolMessage,
        ) -> Result<(), AggregatorClientError> {
            Ok(())
        }

        async fn retrieve_aggregator_features(
            &self,
        ) -> Result<AggregatorFeaturesMessage, AggregatorClientError> {
            Ok(self.aggregator_features.read().await.clone())
        }
    }
}
463
464#[cfg(test)]
465mod tests {
466    use std::collections::HashMap;
467
468    use http::response::Builder as HttpResponseBuilder;
469    use httpmock::prelude::*;
470    use semver::Version;
471    use serde_json::json;
472
473    use mithril_common::entities::Epoch;
474    use mithril_common::messages::TryFromMessageAdapter;
475    use mithril_common::test::{
476        double::{Dummy, DummyApiVersionDiscriminantSource, fake_data},
477        logging::MemoryDrainForTestInspector,
478    };
479
480    use crate::test_tools::TestLogger;
481
482    use super::*;
483
484    macro_rules! assert_is_error {
485        ($error:expr, $error_type:pat) => {
486            assert!(
487                matches!($error, $error_type),
488                "Expected {} error, got '{:?}'.",
489                stringify!($error_type),
490                $error
491            );
492        };
493    }
494
495    fn setup_client<U: Into<String>>(server_url: U) -> AggregatorHTTPClient {
496        let discriminant_source = DummyApiVersionDiscriminantSource::new("dummy");
497        let api_version_provider = APIVersionProvider::new(Arc::new(discriminant_source));
498
499        AggregatorHTTPClient::new(
500            server_url.into(),
501            None,
502            Arc::new(api_version_provider),
503            None,
504            TestLogger::stdout(),
505        )
506    }
507
508    fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
509        let server = MockServer::start();
510        let aggregator_endpoint = server.url("");
511        let client = setup_client(&aggregator_endpoint);
512
513        (server, client)
514    }
515
516    fn set_returning_500(server: &MockServer) {
517        server.mock(|_, then| {
518            then.status(500).body("an error occurred");
519        });
520    }
521
522    fn set_unparsable_json(server: &MockServer) {
523        server.mock(|_, then| {
524            then.status(200).body("this is not a json");
525        });
526    }
527
528    fn build_text_response<T: Into<String>>(status_code: StatusCode, body: T) -> Response {
529        HttpResponseBuilder::new()
530            .status(status_code)
531            .body(body.into())
532            .unwrap()
533            .into()
534    }
535
536    fn build_json_response<T: serde::Serialize>(status_code: StatusCode, body: &T) -> Response {
537        HttpResponseBuilder::new()
538            .status(status_code)
539            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
540            .body(serde_json::to_string(&body).unwrap())
541            .unwrap()
542            .into()
543    }
544
545    macro_rules! assert_error_text_contains {
546        ($error: expr, $expect_contains: expr) => {
547            let error = &$error;
548            assert!(
549                error.contains($expect_contains),
550                "Expected error message to contain '{}'\ngot '{error:?}'",
551                $expect_contains,
552            );
553        };
554    }
555
556    #[tokio::test]
557    async fn test_aggregator_features_ok_200() {
558        let (server, client) = setup_server_and_client();
559        let message_expected = AggregatorFeaturesMessage::dummy();
560        let _server_mock = server.mock(|when, then| {
561            when.path("/");
562            then.status(200).body(json!(message_expected).to_string());
563        });
564
565        let message = client.retrieve_aggregator_features().await.unwrap();
566
567        assert_eq!(message_expected, message);
568    }
569
570    #[tokio::test]
571    async fn test_aggregator_features_ko_500() {
572        let (server, client) = setup_server_and_client();
573        set_returning_500(&server);
574
575        let error = client.retrieve_aggregator_features().await.unwrap_err();
576
577        assert_is_error!(error, AggregatorClientError::RemoteServerTechnical(_));
578    }
579
580    #[tokio::test]
581    async fn test_aggregator_features_ko_json_serialization() {
582        let (server, client) = setup_server_and_client();
583        set_unparsable_json(&server);
584
585        let error = client.retrieve_aggregator_features().await.unwrap_err();
586
587        assert_is_error!(error, AggregatorClientError::JsonParseFailed(_));
588    }
589
590    #[tokio::test]
591    async fn test_aggregator_features_timeout() {
592        let (server, mut client) = setup_server_and_client();
593        client.timeout_duration = Some(Duration::from_millis(10));
594        let _server_mock = server.mock(|when, then| {
595            when.path("/");
596            then.delay(Duration::from_millis(100));
597        });
598
599        let error = client.retrieve_aggregator_features().await.unwrap_err();
600
601        assert_is_error!(error, AggregatorClientError::RemoteServerUnreachable(_));
602    }
603
604    #[tokio::test]
605    async fn test_epoch_settings_ok_200() {
606        let (server, client) = setup_server_and_client();
607        let epoch_settings_expected = EpochSettingsMessage::dummy();
608        let _server_mock = server.mock(|when, then| {
609            when.path("/epoch-settings");
610            then.status(200).body(json!(epoch_settings_expected).to_string());
611        });
612
613        let epoch_settings = client.retrieve_epoch_settings().await;
614        epoch_settings.as_ref().expect("unexpected error");
615        assert_eq!(
616            FromEpochSettingsAdapter::try_adapt(epoch_settings_expected).unwrap(),
617            epoch_settings.unwrap().unwrap()
618        );
619    }
620
621    #[tokio::test]
622    async fn test_epoch_settings_ko_500() {
623        let (server, client) = setup_server_and_client();
624        let _server_mock = server.mock(|when, then| {
625            when.path("/epoch-settings");
626            then.status(500).body("an error occurred");
627        });
628
629        match client.retrieve_epoch_settings().await.unwrap_err() {
630            AggregatorClientError::RemoteServerTechnical(_) => (),
631            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
632        };
633    }
634
635    #[tokio::test]
636    async fn test_epoch_settings_timeout() {
637        let (server, mut client) = setup_server_and_client();
638        client.timeout_duration = Some(Duration::from_millis(10));
639        let _server_mock = server.mock(|when, then| {
640            when.path("/epoch-settings");
641            then.delay(Duration::from_millis(100));
642        });
643
644        let error = client
645            .retrieve_epoch_settings()
646            .await
647            .expect_err("retrieve_epoch_settings should fail");
648
649        assert!(
650            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
651            "unexpected error type: {error:?}"
652        );
653    }
654
655    #[tokio::test]
656    async fn test_register_signer_ok_201() {
657        let epoch = Epoch(1);
658        let single_signers = fake_data::signers(1);
659        let single_signer = single_signers.first().unwrap();
660        let (server, client) = setup_server_and_client();
661        let _server_mock = server.mock(|when, then| {
662            when.method(POST).path("/register-signer");
663            then.status(201);
664        });
665
666        let register_signer = client.register_signer(epoch, single_signer).await;
667        register_signer.expect("unexpected error");
668    }
669
670    #[tokio::test]
671    async fn test_register_signer_ko_400() {
672        let epoch = Epoch(1);
673        let single_signers = fake_data::signers(1);
674        let single_signer = single_signers.first().unwrap();
675        let (server, client) = setup_server_and_client();
676        let _server_mock = server.mock(|when, then| {
677            when.method(POST).path("/register-signer");
678            then.status(400).body(
679                serde_json::to_vec(&ClientError::new(
680                    "error".to_string(),
681                    "an error".to_string(),
682                ))
683                .unwrap(),
684            );
685        });
686
687        match client.register_signer(epoch, single_signer).await.unwrap_err() {
688            AggregatorClientError::RemoteServerLogical(_) => (),
689            err => {
690                panic!(
691                    "Expected a AggregatorClientError::RemoteServerLogical error, got '{err:?}'."
692                )
693            }
694        };
695    }
696
697    #[tokio::test]
698    async fn test_register_signer_ko_500() {
699        let epoch = Epoch(1);
700        let single_signers = fake_data::signers(1);
701        let single_signer = single_signers.first().unwrap();
702        let (server, client) = setup_server_and_client();
703        let _server_mock = server.mock(|when, then| {
704            when.method(POST).path("/register-signer");
705            then.status(500).body("an error occurred");
706        });
707
708        match client.register_signer(epoch, single_signer).await.unwrap_err() {
709            AggregatorClientError::RemoteServerTechnical(_) => (),
710            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
711        };
712    }
713
714    #[tokio::test]
715    async fn test_register_signer_timeout() {
716        let epoch = Epoch(1);
717        let single_signers = fake_data::signers(1);
718        let single_signer = single_signers.first().unwrap();
719        let (server, mut client) = setup_server_and_client();
720        client.timeout_duration = Some(Duration::from_millis(10));
721        let _server_mock = server.mock(|when, then| {
722            when.method(POST).path("/register-signer");
723            then.delay(Duration::from_millis(100));
724        });
725
726        let error = client
727            .register_signer(epoch, single_signer)
728            .await
729            .expect_err("register_signer should fail");
730
731        assert!(
732            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
733            "unexpected error type: {error:?}"
734        );
735    }
736
737    #[tokio::test]
738    async fn test_register_signature_ok_201() {
739        let single_signature = fake_data::single_signature((1..5).collect());
740        let (server, client) = setup_server_and_client();
741        let _server_mock = server.mock(|when, then| {
742            when.method(POST).path("/register-signatures");
743            then.status(201);
744        });
745
746        let register_signature = client
747            .register_signature(
748                &SignedEntityType::dummy(),
749                &single_signature,
750                &ProtocolMessage::default(),
751            )
752            .await;
753        register_signature.expect("unexpected error");
754    }
755
756    #[tokio::test]
757    async fn test_register_signature_ok_202() {
758        let single_signature = fake_data::single_signature((1..5).collect());
759        let (server, client) = setup_server_and_client();
760        let _server_mock = server.mock(|when, then| {
761            when.method(POST).path("/register-signatures");
762            then.status(202);
763        });
764
765        let register_signature = client
766            .register_signature(
767                &SignedEntityType::dummy(),
768                &single_signature,
769                &ProtocolMessage::default(),
770            )
771            .await;
772        register_signature.expect("unexpected error");
773    }
774
775    #[tokio::test]
776    async fn test_register_signature_ko_400() {
777        let single_signature = fake_data::single_signature((1..5).collect());
778        let (server, client) = setup_server_and_client();
779        let _server_mock = server.mock(|when, then| {
780            when.method(POST).path("/register-signatures");
781            then.status(400).body(
782                serde_json::to_vec(&ClientError::new(
783                    "error".to_string(),
784                    "an error".to_string(),
785                ))
786                .unwrap(),
787            );
788        });
789
790        match client
791            .register_signature(
792                &SignedEntityType::dummy(),
793                &single_signature,
794                &ProtocolMessage::default(),
795            )
796            .await
797            .unwrap_err()
798        {
799            AggregatorClientError::RemoteServerLogical(_) => (),
800            e => panic!("Expected Aggregator::RemoteServerLogical error, got '{e:?}'."),
801        };
802    }
803
804    #[tokio::test]
805    async fn test_register_signature_ok_410_log_response_body() {
806        let (logger, log_inspector) = TestLogger::memory();
807
808        let single_signature = fake_data::single_signature((1..5).collect());
809        let (server, mut client) = setup_server_and_client();
810        client.logger = logger;
811        let _server_mock = server.mock(|when, then| {
812            when.method(POST).path("/register-signatures");
813            then.status(410).body(
814                serde_json::to_vec(&ClientError::new(
815                    "already_aggregated".to_string(),
816                    "too late".to_string(),
817                ))
818                .unwrap(),
819            );
820        });
821
822        client
823            .register_signature(
824                &SignedEntityType::dummy(),
825                &single_signature,
826                &ProtocolMessage::default(),
827            )
828            .await
829            .expect("Should not fail when status is 410 (GONE)");
830
831        assert!(log_inspector.contains_log("already_aggregated"));
832        assert!(log_inspector.contains_log("too late"));
833    }
834
835    #[tokio::test]
836    async fn test_register_signature_ko_409() {
837        let single_signature = fake_data::single_signature((1..5).collect());
838        let (server, client) = setup_server_and_client();
839        let _server_mock = server.mock(|when, then| {
840            when.method(POST).path("/register-signatures");
841            then.status(409);
842        });
843
844        match client
845            .register_signature(
846                &SignedEntityType::dummy(),
847                &single_signature,
848                &ProtocolMessage::default(),
849            )
850            .await
851            .unwrap_err()
852        {
853            AggregatorClientError::RemoteServerLogical(_) => (),
854            e => panic!("Expected Aggregator::RemoteServerLogical error, got '{e:?}'."),
855        }
856    }
857
858    #[tokio::test]
859    async fn test_register_signature_ko_500() {
860        let single_signature = fake_data::single_signature((1..5).collect());
861        let (server, client) = setup_server_and_client();
862        let _server_mock = server.mock(|when, then| {
863            when.method(POST).path("/register-signatures");
864            then.status(500).body("an error occurred");
865        });
866
867        match client
868            .register_signature(
869                &SignedEntityType::dummy(),
870                &single_signature,
871                &ProtocolMessage::default(),
872            )
873            .await
874            .unwrap_err()
875        {
876            AggregatorClientError::RemoteServerTechnical(_) => (),
877            e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."),
878        };
879    }
880
881    #[tokio::test]
882    async fn test_register_signature_timeout() {
883        let single_signature = fake_data::single_signature((1..5).collect());
884        let (server, mut client) = setup_server_and_client();
885        client.timeout_duration = Some(Duration::from_millis(10));
886        let _server_mock = server.mock(|when, then| {
887            when.method(POST).path("/register-signatures");
888            then.delay(Duration::from_millis(100));
889        });
890
891        let error = client
892            .register_signature(
893                &SignedEntityType::dummy(),
894                &single_signature,
895                &ProtocolMessage::default(),
896            )
897            .await
898            .expect_err("register_signature should fail");
899
900        assert!(
901            matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
902            "unexpected error type: {error:?}"
903        );
904    }
905
906    #[tokio::test]
907    async fn test_4xx_errors_are_handled_as_remote_server_logical() {
908        let response = build_text_response(StatusCode::BAD_REQUEST, "error text");
909        let handled_error = AggregatorClientError::from_response(response).await;
910
911        assert!(
912            matches!(
913                handled_error,
914                AggregatorClientError::RemoteServerLogical(..)
915            ),
916            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
917        );
918    }
919
920    #[tokio::test]
921    async fn test_5xx_errors_are_handled_as_remote_server_technical() {
922        let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text");
923        let handled_error = AggregatorClientError::from_response(response).await;
924
925        assert!(
926            matches!(
927                handled_error,
928                AggregatorClientError::RemoteServerTechnical(..)
929            ),
930            "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'",
931        );
932    }
933
934    #[tokio::test]
935    async fn test_550_error_is_handled_as_registration_round_not_yet_opened() {
936        let response = build_text_response(StatusCode::from_u16(550).unwrap(), "Not yet available");
937        let handled_error = AggregatorClientError::from_response(response).await;
938
939        assert!(
940            matches!(
941                handled_error,
942                AggregatorClientError::RegistrationRoundNotYetOpened(..)
943            ),
944            "Expected error to be RegistrationRoundNotYetOpened\ngot '{handled_error:?}'",
945        );
946    }
947
948    #[tokio::test]
949    async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text()
950     {
951        let response = build_text_response(StatusCode::OK, "ok text");
952        let handled_error = AggregatorClientError::from_response(response).await;
953
954        assert!(
955            matches!(
956                handled_error,
957                AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text")
958            ),
959            "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'",
960        );
961    }
962
963    #[tokio::test]
964    async fn test_root_cause_of_non_json_response_contains_response_plain_text() {
965        let error_text = "An error occurred; please try again later.";
966        let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text);
967
968        assert_error_text_contains!(
969            AggregatorClientError::get_root_cause(response).await,
970            "expectation failed: An error occurred; please try again later."
971        );
972    }
973
974    #[tokio::test]
975    async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message()
976     {
977        let client_error = ClientError::new("label", "message");
978        let response = build_json_response(StatusCode::BAD_REQUEST, &client_error);
979
980        assert_error_text_contains!(
981            AggregatorClientError::get_root_cause(response).await,
982            "bad request: label: message"
983        );
984    }
985
986    #[tokio::test]
987    async fn test_root_cause_of_json_formatted_server_error_response_contains_error_label_and_message()
988     {
989        let server_error = ServerError::new("message");
990        let response = build_json_response(StatusCode::BAD_REQUEST, &server_error);
991
992        assert_error_text_contains!(
993            AggregatorClientError::get_root_cause(response).await,
994            "bad request: message"
995        );
996    }
997
998    #[tokio::test]
999    async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() {
1000        let response = build_json_response(
1001            StatusCode::INTERNAL_SERVER_ERROR,
1002            &json!({ "second": "unknown", "first": "foreign" }),
1003        );
1004
1005        assert_error_text_contains!(
1006            AggregatorClientError::get_root_cause(response).await,
1007            r#"internal server error: {"first":"foreign","second":"unknown"}"#
1008        );
1009    }
1010
1011    #[tokio::test]
1012    async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() {
1013        let response = HttpResponseBuilder::new()
1014            .status(StatusCode::BAD_REQUEST)
1015            .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE)
1016            .body(r#"{"invalid":"unexpected dot", "key": "value".}"#)
1017            .unwrap()
1018            .into();
1019
1020        let root_cause = AggregatorClientError::get_root_cause(response).await;
1021
1022        assert_error_text_contains!(root_cause, "bad request");
1023        assert!(
1024            !root_cause.contains("bad request: "),
1025            "Expected error message should not contain additional information \ngot '{root_cause:?}'"
1026        );
1027    }
1028
1029    #[tokio::test]
1030    async fn test_sends_accept_encoding_header() {
1031        let (server, client) = setup_server_and_client();
1032        server.mock(|when, then| {
1033            when.matches(|req| {
1034                let headers = req.headers.clone().expect("HTTP headers not found");
1035                let accept_encoding_header = headers
1036                    .iter()
1037                    .find(|(name, _values)| name.to_lowercase() == "accept-encoding")
1038                    .expect("Accept-Encoding header not found");
1039
1040                let header_value = accept_encoding_header.clone().1;
1041                ["gzip", "br", "deflate", "zstd"]
1042                    .iter()
1043                    .all(|&value| header_value.contains(value))
1044            });
1045
1046            then.status(201);
1047        });
1048
1049        client
1050            .register_signature(
1051                &SignedEntityType::dummy(),
1052                &fake_data::single_signature((1..5).collect()),
1053                &ProtocolMessage::default(),
1054            )
1055            .await
1056            .expect("Should succeed with Accept-Encoding header");
1057    }
1058
1059    mod warn_if_api_version_mismatch {
1060        use mithril_common::test::api_version_extensions::ApiVersionProviderTestExtension;
1061
1062        use super::*;
1063
1064        fn version_provider_with_open_api_version<V: Into<String>>(
1065            version: V,
1066        ) -> APIVersionProvider {
1067            let mut version_provider = version_provider_without_open_api_version();
1068            let mut open_api_versions = HashMap::new();
1069            open_api_versions.insert(
1070                "openapi.yaml".to_string(),
1071                Version::parse(&version.into()).unwrap(),
1072            );
1073            version_provider.update_open_api_versions(open_api_versions);
1074
1075            version_provider
1076        }
1077
1078        fn version_provider_without_open_api_version() -> APIVersionProvider {
1079            let mut version_provider =
1080                APIVersionProvider::new(Arc::new(DummyApiVersionDiscriminantSource::new("dummy")));
1081            version_provider.update_open_api_versions(HashMap::new());
1082
1083            version_provider
1084        }
1085
1086        fn build_fake_response_with_header<K: Into<String>, V: Into<String>>(
1087            key: K,
1088            value: V,
1089        ) -> Response {
1090            HttpResponseBuilder::new()
1091                .header(key.into(), value.into())
1092                .body("whatever")
1093                .unwrap()
1094                .into()
1095        }
1096
1097        fn assert_api_version_warning_logged<A: Into<String>, S: Into<String>>(
1098            log_inspector: &MemoryDrainForTestInspector,
1099            aggregator_version: A,
1100            signer_version: S,
1101        ) {
1102            assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1103            assert!(
1104                log_inspector
1105                    .contains_log(&format!("aggregator_version={}", aggregator_version.into()))
1106            );
1107            assert!(
1108                log_inspector.contains_log(&format!("signer_version={}", signer_version.into()))
1109            );
1110        }
1111
1112        #[test]
1113        fn test_logs_warning_when_aggregator_api_version_is_newer() {
1114            let aggregator_version = "2.0.0";
1115            let signer_version = "1.0.0";
1116            let (logger, log_inspector) = TestLogger::memory();
1117            let version_provider = version_provider_with_open_api_version(signer_version);
1118            let mut client = setup_client("whatever");
1119            client.api_version_provider = Arc::new(version_provider);
1120            client.logger = logger;
1121            let response =
1122                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1123
1124            assert!(
1125                Version::parse(aggregator_version).unwrap()
1126                    > Version::parse(signer_version).unwrap()
1127            );
1128
1129            client.warn_if_api_version_mismatch(&response);
1130
1131            assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1132        }
1133
1134        #[test]
1135        fn test_no_warning_logged_when_versions_match() {
1136            let version = "1.0.0";
1137            let (logger, log_inspector) = TestLogger::memory();
1138            let version_provider = version_provider_with_open_api_version(version);
1139            let mut client = setup_client("whatever");
1140            client.api_version_provider = Arc::new(version_provider);
1141            client.logger = logger;
1142            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);
1143
1144            client.warn_if_api_version_mismatch(&response);
1145
1146            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1147        }
1148
1149        #[test]
1150        fn test_no_warning_logged_when_aggregator_api_version_is_older() {
1151            let aggregator_version = "1.0.0";
1152            let signer_version = "2.0.0";
1153            let (logger, log_inspector) = TestLogger::memory();
1154            let version_provider = version_provider_with_open_api_version(signer_version);
1155            let mut client = setup_client("whatever");
1156            client.api_version_provider = Arc::new(version_provider);
1157            client.logger = logger;
1158            let response =
1159                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1160
1161            assert!(
1162                Version::parse(aggregator_version).unwrap()
1163                    < Version::parse(signer_version).unwrap()
1164            );
1165
1166            client.warn_if_api_version_mismatch(&response);
1167
1168            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1169        }
1170
1171        #[test]
1172        fn test_does_not_log_or_fail_when_header_is_missing() {
1173            let (logger, log_inspector) = TestLogger::memory();
1174            let mut client = setup_client("whatever");
1175            client.logger = logger;
1176            let response =
1177                build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever");
1178
1179            client.warn_if_api_version_mismatch(&response);
1180
1181            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1182        }
1183
1184        #[test]
1185        fn test_does_not_log_or_fail_when_header_is_not_a_version() {
1186            let (logger, log_inspector) = TestLogger::memory();
1187            let mut client = setup_client("whatever");
1188            client.logger = logger;
1189            let response =
1190                build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version");
1191
1192            client.warn_if_api_version_mismatch(&response);
1193
1194            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1195        }
1196
1197        #[test]
1198        fn test_logs_error_when_signer_version_cannot_be_computed() {
1199            let (logger, log_inspector) = TestLogger::memory();
1200            let version_provider = version_provider_without_open_api_version();
1201            let mut client = setup_client("whatever");
1202            client.api_version_provider = Arc::new(version_provider);
1203            client.logger = logger;
1204            let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");
1205
1206            client.warn_if_api_version_mismatch(&response);
1207
1208            assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE));
1209        }
1210
1211        #[tokio::test]
1212        async fn test_aggregator_features_ok_200_log_warning_if_api_version_mismatch() {
1213            let aggregator_version = "2.0.0";
1214            let signer_version = "1.0.0";
1215            let (server, mut client) = setup_server_and_client();
1216            let (logger, log_inspector) = TestLogger::memory();
1217            let version_provider = version_provider_with_open_api_version(signer_version);
1218            client.api_version_provider = Arc::new(version_provider);
1219            client.logger = logger;
1220
1221            let message_expected = AggregatorFeaturesMessage::dummy();
1222            let _server_mock = server.mock(|when, then| {
1223                when.path("/");
1224                then.status(200)
1225                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version)
1226                    .body(json!(message_expected).to_string());
1227            });
1228
1229            assert!(
1230                Version::parse(aggregator_version).unwrap()
1231                    > Version::parse(signer_version).unwrap()
1232            );
1233
1234            client.retrieve_aggregator_features().await.unwrap();
1235
1236            assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1237        }
1238
1239        #[tokio::test]
1240        async fn test_epoch_settings_ok_200_log_warning_if_api_version_mismatch() {
1241            let aggregator_version = "2.0.0";
1242            let signer_version = "1.0.0";
1243            let (server, mut client) = setup_server_and_client();
1244            let (logger, log_inspector) = TestLogger::memory();
1245            let version_provider = version_provider_with_open_api_version(signer_version);
1246            client.api_version_provider = Arc::new(version_provider);
1247            client.logger = logger;
1248
1249            let epoch_settings_expected = EpochSettingsMessage::dummy();
1250            let _server_mock = server.mock(|when, then| {
1251                when.path("/epoch-settings");
1252                then.status(200)
1253                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version)
1254                    .body(json!(epoch_settings_expected).to_string());
1255            });
1256
1257            assert!(
1258                Version::parse(aggregator_version).unwrap()
1259                    > Version::parse(signer_version).unwrap()
1260            );
1261
1262            client.retrieve_epoch_settings().await.unwrap();
1263
1264            assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1265        }
1266
1267        #[tokio::test]
1268        async fn test_register_signer_ok_201_log_warning_if_api_version_mismatch() {
1269            let aggregator_version = "2.0.0";
1270            let signer_version = "1.0.0";
1271            let epoch = Epoch(1);
1272            let single_signers = fake_data::signers(1);
1273            let single_signer = single_signers.first().unwrap();
1274            let (server, mut client) = setup_server_and_client();
1275            let (logger, log_inspector) = TestLogger::memory();
1276            let version_provider = version_provider_with_open_api_version(signer_version);
1277            client.api_version_provider = Arc::new(version_provider);
1278            client.logger = logger;
1279            let _server_mock = server.mock(|when, then| {
1280                when.method(POST).path("/register-signer");
1281                then.status(201)
1282                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1283            });
1284
1285            assert!(
1286                Version::parse(aggregator_version).unwrap()
1287                    > Version::parse(signer_version).unwrap()
1288            );
1289
1290            client.register_signer(epoch, single_signer).await.unwrap();
1291
1292            assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1293        }
1294
1295        #[tokio::test]
1296        async fn test_register_signature_ok_201_log_warning_if_api_version_mismatch() {
1297            let aggregator_version = "2.0.0";
1298            let signer_version = "1.0.0";
1299            let single_signature = fake_data::single_signature((1..5).collect());
1300            let (server, mut client) = setup_server_and_client();
1301            let (logger, log_inspector) = TestLogger::memory();
1302            let version_provider = version_provider_with_open_api_version(signer_version);
1303            client.api_version_provider = Arc::new(version_provider);
1304            client.logger = logger;
1305            let _server_mock = server.mock(|when, then| {
1306                when.method(POST).path("/register-signatures");
1307                then.status(201)
1308                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1309            });
1310
1311            assert!(
1312                Version::parse(aggregator_version).unwrap()
1313                    > Version::parse(signer_version).unwrap()
1314            );
1315
1316            client
1317                .register_signature(
1318                    &SignedEntityType::dummy(),
1319                    &single_signature,
1320                    &ProtocolMessage::default(),
1321                )
1322                .await
1323                .expect("Should not fail");
1324
1325            assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1326        }
1327
1328        #[tokio::test]
1329        async fn test_register_signature_ok_202_log_warning_if_api_version_mismatch() {
1330            let aggregator_version = "2.0.0";
1331            let signer_version = "1.0.0";
1332            let single_signature = fake_data::single_signature((1..5).collect());
1333            let (server, mut client) = setup_server_and_client();
1334            let (logger, log_inspector) = TestLogger::memory();
1335            let version_provider = version_provider_with_open_api_version(signer_version);
1336            client.api_version_provider = Arc::new(version_provider);
1337            client.logger = logger;
1338            let _server_mock = server.mock(|when, then| {
1339                when.method(POST).path("/register-signatures");
1340                then.status(202)
1341                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1342            });
1343
1344            assert!(
1345                Version::parse(aggregator_version).unwrap()
1346                    > Version::parse(signer_version).unwrap()
1347            );
1348
1349            client
1350                .register_signature(
1351                    &SignedEntityType::dummy(),
1352                    &single_signature,
1353                    &ProtocolMessage::default(),
1354                )
1355                .await
1356                .unwrap();
1357
1358            assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1359        }
1360
1361        #[tokio::test]
1362        async fn test_register_signature_ok_410_log_warning_if_api_version_mismatch() {
1363            let aggregator_version = "2.0.0";
1364            let signer_version = "1.0.0";
1365            let single_signature = fake_data::single_signature((1..5).collect());
1366            let (server, mut client) = setup_server_and_client();
1367            let (logger, log_inspector) = TestLogger::memory();
1368            let version_provider = version_provider_with_open_api_version(signer_version);
1369            client.api_version_provider = Arc::new(version_provider);
1370            client.logger = logger;
1371            let _server_mock = server.mock(|when, then| {
1372                when.method(POST).path("/register-signatures");
1373                then.status(410)
1374                    .body(
1375                        serde_json::to_vec(&ClientError::new(
1376                            "already_aggregated".to_string(),
1377                            "too late".to_string(),
1378                        ))
1379                        .unwrap(),
1380                    )
1381                    .header(MITHRIL_API_VERSION_HEADER, aggregator_version);
1382            });
1383
1384            assert!(
1385                Version::parse(aggregator_version).unwrap()
1386                    > Version::parse(signer_version).unwrap()
1387            );
1388
1389            client
1390                .register_signature(
1391                    &SignedEntityType::dummy(),
1392                    &single_signature,
1393                    &ProtocolMessage::default(),
1394                )
1395                .await
1396                .unwrap();
1397
1398            assert_api_version_warning_logged(&log_inspector, aggregator_version, signer_version);
1399        }
1400    }
1401}