mithril_aggregator/http_server/routes/
middlewares.rs

1use slog::{Logger, debug};
2use std::convert::Infallible;
3use std::sync::Arc;
4use warp::Filter;
5
6use mithril_common::api_version::APIVersionProvider;
7use mithril_common::{MITHRIL_CLIENT_TYPE_HEADER, MITHRIL_ORIGIN_TAG_HEADER};
8
9use crate::database::repository::SignerGetter;
10use crate::dependency_injection::EpochServiceWrapper;
11use crate::event_store::{EventMessage, TransmitterService};
12use crate::http_server::routes::http_server_child_logger;
13use crate::http_server::routes::router::{RouterConfig, RouterState};
14use crate::services::{CertifierService, MessageService, ProverService, SignedEntityService};
15use crate::{
16    MetricsService, SignerRegisterer, SingleSignatureAuthenticator, VerificationKeyStorer,
17};
18
19/// Extract a value from the configuration
20pub fn extract_config<D: Clone + Send>(
21    state: &RouterState,
22    extract: fn(&RouterConfig) -> D,
23) -> impl Filter<Extract = (D,), Error = Infallible> + Clone + use<D> {
24    let config_value = extract(&state.configuration);
25    warp::any().map(move || config_value.clone())
26}
27
28/// With logger middleware
29pub(crate) fn with_logger(
30    router_state: &RouterState,
31) -> impl Filter<Extract = (Logger,), Error = Infallible> + Clone + use<> {
32    let logger = http_server_child_logger(&router_state.dependencies.root_logger);
33    warp::any().map(move || logger.clone())
34}
35
36/// Log to apply each time a route is called
37///
38/// Example of log produced: `POST /aggregator/register-signatures 202 Accepted`
39pub(crate) fn log_route_call(
40    router_state: &RouterState,
41) -> warp::log::Log<impl Fn(warp::log::Info<'_>) + Clone + use<>> {
42    let logger = http_server_child_logger(&router_state.dependencies.root_logger);
43    warp::log::custom(move |info| {
44        debug!(
45            logger,
46            "{} {} {}",
47            info.method(),
48            info.path(),
49            info.status()
50        )
51    })
52}
53
54/// With signer registerer middleware
55pub fn with_signer_registerer(
56    router_state: &RouterState,
57) -> impl Filter<Extract = (Arc<dyn SignerRegisterer>,), Error = Infallible> + Clone + use<> {
58    let signer_register = router_state.dependencies.signer_registerer.clone();
59    warp::any().map(move || signer_register.clone())
60}
61
62/// With signer getter middleware
63pub fn with_signer_getter(
64    router_state: &RouterState,
65) -> impl Filter<Extract = (Arc<dyn SignerGetter>,), Error = Infallible> + Clone + use<> {
66    let signer_getter = router_state.dependencies.signer_getter.clone();
67    warp::any().map(move || signer_getter.clone())
68}
69
70/// With Event transmitter middleware
71pub fn with_event_transmitter(
72    router_state: &RouterState,
73) -> impl Filter<Extract = (Arc<TransmitterService<EventMessage>>,), Error = Infallible> + Clone + use<>
74{
75    let event_transmitter = router_state.dependencies.event_transmitter.clone();
76    warp::any().map(move || event_transmitter.clone())
77}
78
79/// With certifier service middleware
80pub fn with_certifier_service(
81    router_state: &RouterState,
82) -> impl Filter<Extract = (Arc<dyn CertifierService>,), Error = Infallible> + Clone + use<> {
83    let certifier_service = router_state.dependencies.certifier_service.clone();
84    warp::any().map(move || certifier_service.clone())
85}
86
87/// With epoch service middleware
88pub fn with_epoch_service(
89    router_state: &RouterState,
90) -> impl Filter<Extract = (EpochServiceWrapper,), Error = Infallible> + Clone + use<> {
91    let epoch_service = router_state.dependencies.epoch_service.clone();
92    warp::any().map(move || epoch_service.clone())
93}
94
95/// With signed entity service
96pub fn with_signed_entity_service(
97    router_state: &RouterState,
98) -> impl Filter<Extract = (Arc<dyn SignedEntityService>,), Error = Infallible> + Clone + use<> {
99    let signed_entity_service = router_state.dependencies.signed_entity_service.clone();
100    warp::any().map(move || signed_entity_service.clone())
101}
102
103/// With verification key store
104pub fn with_verification_key_store(
105    router_state: &RouterState,
106) -> impl Filter<Extract = (Arc<dyn VerificationKeyStorer>,), Error = Infallible> + Clone + use<> {
107    let verification_key_store = router_state.dependencies.verification_key_store.clone();
108    warp::any().map(move || verification_key_store.clone())
109}
110
111/// With API version provider
112pub fn with_api_version_provider(
113    router_state: &RouterState,
114) -> impl Filter<Extract = (Arc<APIVersionProvider>,), Error = Infallible> + Clone + use<> {
115    let api_version_provider = router_state.dependencies.api_version_provider.clone();
116    warp::any().map(move || api_version_provider.clone())
117}
118
119/// With Message service
120pub fn with_http_message_service(
121    router_state: &RouterState,
122) -> impl Filter<Extract = (Arc<dyn MessageService>,), Error = Infallible> + Clone + use<> {
123    let message_service = router_state.dependencies.message_service.clone();
124    warp::any().map(move || message_service.clone())
125}
126
127/// With Prover service
128pub fn with_prover_service(
129    router_state: &RouterState,
130) -> impl Filter<Extract = (Arc<dyn ProverService>,), Error = Infallible> + Clone + use<> {
131    let prover_service = router_state.dependencies.prover_service.clone();
132    warp::any().map(move || prover_service.clone())
133}
134
135/// With Single Signature Authenticator
136pub fn with_single_signature_authenticator(
137    router_state: &RouterState,
138) -> impl Filter<Extract = (Arc<SingleSignatureAuthenticator>,), Error = Infallible> + Clone + use<>
139{
140    let single_signer_authenticator = router_state.dependencies.single_signer_authenticator.clone();
141    warp::any().map(move || single_signer_authenticator.clone())
142}
143
144/// With Metrics service
145pub fn with_metrics_service(
146    router_state: &RouterState,
147) -> impl Filter<Extract = (Arc<MetricsService>,), Error = Infallible> + Clone + use<> {
148    let metrics_service = router_state.dependencies.metrics_service.clone();
149    warp::any().map(move || metrics_service.clone())
150}
151
152/// With origin tag of the request
153pub fn with_origin_tag(
154    router_state: &RouterState,
155) -> impl Filter<Extract = (Option<String>,), Error = warp::reject::Rejection> + Clone + use<> {
156    let white_list = router_state.configuration.origin_tag_white_list.clone();
157
158    warp::header::optional::<String>(MITHRIL_ORIGIN_TAG_HEADER).map(move |name: Option<String>| {
159        name.filter(|tag| white_list.contains(tag)).or(Some("NA".to_string()))
160    })
161}
162
163pub fn with_client_type()
164-> impl Filter<Extract = (Option<String>,), Error = warp::reject::Rejection> + Clone {
165    let authorized_client_types = ["CLI", "WASM", "LIBRARY", "NA"];
166
167    warp::header::optional::<String>(MITHRIL_CLIENT_TYPE_HEADER).map(
168        move |client_type: Option<String>| {
169            client_type
170                .filter(|ct| authorized_client_types.contains(&ct.as_str()))
171                .or(Some("NA".to_string()))
172        },
173    )
174}
175
176pub fn with_client_metadata(
177    router_state: &RouterState,
178) -> impl Filter<Extract = (ClientMetadata,), Error = warp::reject::Rejection> + Clone + use<> {
179    with_origin_tag(router_state).and(with_client_type()).map(
180        |origin_tag: Option<String>, client_type: Option<String>| ClientMetadata {
181            origin_tag,
182            client_type,
183        },
184    )
185}
186
/// Metadata describing the client that issued a request
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientMetadata {
    /// Origin tag of the request
    pub origin_tag: Option<String>,
    /// Client type of the request
    pub client_type: Option<String>,
}
191
192pub mod validators {
193    use crate::http_server::validators::ProverTransactionsHashValidator;
194
195    use super::*;
196
197    /// With Prover Transactions Hash Validator
198    pub fn with_prover_transactions_hash_validator(
199        router_state: &RouterState,
200    ) -> impl Filter<Extract = (ProverTransactionsHashValidator,), Error = Infallible> + Clone + use<>
201    {
202        let max_hashes = router_state
203            .configuration
204            .cardano_transactions_prover_max_hashes_allowed_by_request;
205
206        warp::any().map(move || ProverTransactionsHashValidator::new(max_hashes))
207    }
208}
209
#[cfg(test)]
mod tests {
    use serde_json::Value;
    use std::convert::Infallible;
    use warp::{
        Filter,
        http::{Method, Response, StatusCode},
        hyper::body::Bytes,
        test::request,
    };

    use mithril_common::test::double::Dummy;

    use crate::http_server::routes::reply;
    use crate::initialize_dependencies;

    use super::*;

    /// Handler replying with the extracted value serialized as JSON.
    async fn route_handler(value: Option<String>) -> Result<impl warp::Reply, Infallible> {
        Ok(reply::json(&value, StatusCode::OK))
    }

    /// Parse the response body as JSON and return its content when it is a JSON string.
    fn get_body(response: Response<Bytes>) -> Option<String> {
        let json: Value = serde_json::from_slice(response.body()).unwrap();
        json.as_str().map(String::from)
    }

    mod origin_tag {
        use super::*;
        use std::{collections::HashSet, path::PathBuf};

        use mithril_common::temp_dir;

        /// Test route replying with the value extracted by `with_origin_tag`.
        fn route_with_origin_tag(
            router_state: &RouterState,
        ) -> impl Filter<Extract = (impl warp::Reply + use<>,), Error = warp::Rejection> + Clone + use<>
        {
            warp::path!("route")
                .and(warp::get())
                .and(with_origin_tag(router_state))
                .and_then(route_handler)
        }

        /// Build a router state whose origin tag white list contains the given tags.
        async fn router_state_with_origin_whitelist(
            temp_dir: PathBuf,
            tags: &[&str],
        ) -> RouterState {
            let origin_tag_white_list: HashSet<String> =
                tags.iter().map(|tag| tag.to_string()).collect();
            let dependencies = Arc::new(initialize_dependencies(temp_dir).await);
            let configuration = RouterConfig {
                origin_tag_white_list,
                ..RouterConfig::dummy()
            };

            RouterState::new(dependencies, configuration)
        }

        /// Call the test route, optionally setting the origin tag header, and return the
        /// string found in the response body.
        async fn call_route(router_state: &RouterState, origin_tag: Option<&str>) -> Option<String> {
            let mut builder = request().method(Method::GET.as_str()).path("/route");
            if let Some(tag) = origin_tag {
                builder = builder.header(MITHRIL_ORIGIN_TAG_HEADER, tag);
            }

            get_body(builder.reply(&route_with_origin_tag(router_state)).await)
        }

        #[tokio::test]
        async fn test_origin_tag_with_value_in_white_list_return_the_tag() {
            let router_state =
                router_state_with_origin_whitelist(temp_dir!(), &["CLIENT_TAG"]).await;

            let body = call_route(&router_state, Some("CLIENT_TAG")).await;

            assert_eq!(Some("CLIENT_TAG".to_string()), body);
        }

        #[tokio::test]
        async fn test_origin_tag_with_value_not_in_white_list_return_na() {
            let router_state =
                router_state_with_origin_whitelist(temp_dir!(), &["CLIENT_TAG"]).await;

            let body = call_route(&router_state, Some("UNKNOWN_TAG")).await;

            assert_eq!(Some("NA".to_string()), body);
        }

        #[tokio::test]
        async fn test_without_origin_tag() {
            let router_state =
                router_state_with_origin_whitelist(temp_dir!(), &["CLIENT_TAG"]).await;

            let body = call_route(&router_state, None).await;

            assert_eq!(Some("NA".to_string()), body);
        }
    }

    mod client_type {
        use super::*;

        /// Call the test route with the client type header set to the given value.
        async fn request_with_client_type(client_type_header_value: &str) -> Response<Bytes> {
            let builder = request()
                .header(MITHRIL_CLIENT_TYPE_HEADER, client_type_header_value)
                .method(Method::GET.as_str())
                .path("/route");

            builder.reply(&route_with_client_type()).await
        }

        /// Test route replying with the value extracted by `with_client_type`.
        fn route_with_client_type()
        -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
            warp::path!("route")
                .and(warp::get())
                .and(with_client_type())
                .and_then(route_handler)
        }

        #[tokio::test]
        async fn test_with_client_type_use_na_as_default_value_if_header_not_set() {
            let builder = request().method(Method::GET.as_str()).path("/route");

            let response = builder.reply(&route_with_client_type()).await;

            assert_eq!(Some("NA".to_string()), get_body(response));
        }

        #[tokio::test]
        async fn test_with_client_type_only_authorize_specific_values() {
            // (header value sent, value expected in the response body)
            let expectations = [
                ("CLI", "CLI"),
                ("WASM", "WASM"),
                ("LIBRARY", "LIBRARY"),
                ("NA", "NA"),
                ("UNKNOWN", "NA"),
            ];

            for (header_value, expected) in expectations {
                let response: Response<Bytes> = request_with_client_type(header_value).await;
                assert_eq!(Some(expected.to_string()), get_body(response));
            }
        }
    }
}