mithril_aggregator_client/query/post/post_increment_snapshot_statistic.rs

1use async_trait::async_trait;
2use reqwest::StatusCode;
3
4use mithril_common::messages::SnapshotDownloadMessage;
5
6use crate::AggregatorHttpClientResult;
7use crate::query::{AggregatorQuery, QueryContext, QueryMethod};
8
9/// Query to notify the aggregator that a snapshot has been downloaded.
10pub struct PostIncrementSnapshotDownloadStatisticQuery {
11    message: SnapshotDownloadMessage,
12}
13
14impl PostIncrementSnapshotDownloadStatisticQuery {
15    /// Instantiate a new query that will notify the aggregator that a snapshot has been downloaded.
16    pub fn new(message: SnapshotDownloadMessage) -> Self {
17        Self { message }
18    }
19}
20
21#[cfg_attr(target_family = "wasm", async_trait(?Send))]
22#[cfg_attr(not(target_family = "wasm"), async_trait)]
23impl AggregatorQuery for PostIncrementSnapshotDownloadStatisticQuery {
24    type Response = ();
25    type Body = SnapshotDownloadMessage;
26
27    fn method() -> QueryMethod {
28        QueryMethod::Post
29    }
30
31    fn route(&self) -> String {
32        "statistics/snapshot".to_string()
33    }
34
35    fn body(&self) -> Option<Self::Body> {
36        Some(self.message.clone())
37    }
38
39    async fn handle_response(
40        &self,
41        context: QueryContext,
42    ) -> AggregatorHttpClientResult<Self::Response> {
43        match context.response.status() {
44            StatusCode::CREATED => Ok(()),
45            _ => Err(context.unhandled_status_code().await),
46        }
47    }
48}
49
// Exercises the query against a mocked aggregator HTTP server.
#[cfg(test)]
mod tests {
    use httpmock::Method::POST;

    use mithril_common::entities::ClientError;
    use mithril_common::test::double::Dummy;

    use crate::AggregatorHttpClientError;
    use crate::test::{assert_error_matches, setup_server_and_client};

    use super::*;

    #[tokio::test]
    async fn test_increment_snapshot_download_statistics_ok_201() {
        let message = SnapshotDownloadMessage::dummy();
        let (server, client) = setup_server_and_client();
        // Only a POST on the statistics route whose body is the serialized message matches.
        let _server_mock = server.mock(|when, then| {
            when.method(POST)
                .path("/statistics/snapshot")
                .body(serde_json::to_string(&message).unwrap());
            then.status(201);
        });

        let query = PostIncrementSnapshotDownloadStatisticQuery::new(message);
        let outcome = client.send(query).await;

        outcome.expect("unexpected error");
    }

    #[tokio::test]
    async fn test_increment_snapshot_download_statistics_ko_400() {
        let (server, client) = setup_server_and_client();
        // A 400 carrying a `ClientError` payload must surface as a logical remote error.
        let _server_mock = server.mock(|when, then| {
            when.method(POST).any_request();
            then.status(400)
                .body(serde_json::to_vec(&ClientError::dummy()).unwrap());
        });

        let query =
            PostIncrementSnapshotDownloadStatisticQuery::new(SnapshotDownloadMessage::dummy());
        let failure = client.send(query).await.unwrap_err();

        assert_error_matches!(failure, AggregatorHttpClientError::RemoteServerLogical(_));
    }

    #[tokio::test]
    async fn test_increment_snapshot_download_statistics_ko_500() {
        let (server, client) = setup_server_and_client();
        // A 500 with a plain-text body must surface as a technical remote error.
        let _server_mock = server.mock(|when, then| {
            when.method(POST).any_request();
            then.status(500).body("an error occurred");
        });

        let query =
            PostIncrementSnapshotDownloadStatisticQuery::new(SnapshotDownloadMessage::dummy());
        let failure = client.send(query).await.unwrap_err();

        assert_error_matches!(failure, AggregatorHttpClientError::RemoteServerTechnical(_));
    }
}