mithril_aggregator/http_server/routes/middlewares.rs

1use serde::de::DeserializeOwned;
2use slog::{Logger, debug};
3use std::convert::Infallible;
4use std::sync::Arc;
5use warp::{Filter, Rejection};
6
7use mithril_common::api_version::APIVersionProvider;
8use mithril_common::{MITHRIL_CLIENT_TYPE_HEADER, MITHRIL_ORIGIN_TAG_HEADER};
9
10use crate::database::repository::SignerGetter;
11use crate::dependency_injection::EpochServiceWrapper;
12use crate::event_store::{EventMessage, TransmitterService};
13use crate::http_server::routes::http_server_child_logger;
14use crate::http_server::routes::router::{RouterConfig, RouterState};
15use crate::services::{CertifierService, LegacyProverService, MessageService, SignedEntityService};
16use crate::{
17    MetricsService, SignerRegisterer, SingleSignatureAuthenticator, VerificationKeyStorer,
18};
19
20/// Extract a value from the body with a maximum length limit
21pub(crate) fn json_with_max_length<T: DeserializeOwned + Send>(
22    max_length: u64,
23) -> impl Filter<Extract = (T,), Error = Rejection> + Copy {
24    warp::body::content_length_limit(max_length).and(warp::body::json())
25}
26
27/// Extract a value from the configuration
28pub(crate) fn extract_config<D: Clone + Send>(
29    state: &RouterState,
30    extract: fn(&RouterConfig) -> D,
31) -> impl Filter<Extract = (D,), Error = Infallible> + Clone + use<D> {
32    let config_value = extract(&state.configuration);
33    warp::any().map(move || config_value.clone())
34}
35
36/// With logger middleware
37pub(crate) fn with_logger(
38    router_state: &RouterState,
39) -> impl Filter<Extract = (Logger,), Error = Infallible> + Clone + use<> {
40    let logger = http_server_child_logger(&router_state.dependencies.root_logger);
41    warp::any().map(move || logger.clone())
42}
43
44/// Log to apply each time a route is called
45///
46/// Example of log produced: `POST /aggregator/register-signatures 202 Accepted`
47pub(crate) fn log_route_call(
48    router_state: &RouterState,
49) -> warp::log::Log<impl Fn(warp::log::Info<'_>) + Clone + use<>> {
50    let logger = http_server_child_logger(&router_state.dependencies.root_logger);
51    warp::log::custom(move |info| {
52        debug!(
53            logger,
54            "{} {} {}",
55            info.method(),
56            info.path(),
57            info.status()
58        )
59    })
60}
61
62/// With signer registerer middleware
63pub fn with_signer_registerer(
64    router_state: &RouterState,
65) -> impl Filter<Extract = (Arc<dyn SignerRegisterer>,), Error = Infallible> + Clone + use<> {
66    let signer_register = router_state.dependencies.signer_registerer.clone();
67    warp::any().map(move || signer_register.clone())
68}
69
70/// With signer getter middleware
71pub fn with_signer_getter(
72    router_state: &RouterState,
73) -> impl Filter<Extract = (Arc<dyn SignerGetter>,), Error = Infallible> + Clone + use<> {
74    let signer_getter = router_state.dependencies.signer_getter.clone();
75    warp::any().map(move || signer_getter.clone())
76}
77
78/// With Event transmitter middleware
79pub fn with_event_transmitter(
80    router_state: &RouterState,
81) -> impl Filter<Extract = (Arc<TransmitterService<EventMessage>>,), Error = Infallible> + Clone + use<>
82{
83    let event_transmitter = router_state.dependencies.event_transmitter.clone();
84    warp::any().map(move || event_transmitter.clone())
85}
86
87/// With certifier service middleware
88pub fn with_certifier_service(
89    router_state: &RouterState,
90) -> impl Filter<Extract = (Arc<dyn CertifierService>,), Error = Infallible> + Clone + use<> {
91    let certifier_service = router_state.dependencies.certifier_service.clone();
92    warp::any().map(move || certifier_service.clone())
93}
94
95/// With epoch service middleware
96pub fn with_epoch_service(
97    router_state: &RouterState,
98) -> impl Filter<Extract = (EpochServiceWrapper,), Error = Infallible> + Clone + use<> {
99    let epoch_service = router_state.dependencies.epoch_service.clone();
100    warp::any().map(move || epoch_service.clone())
101}
102
103/// With signed entity service
104pub fn with_signed_entity_service(
105    router_state: &RouterState,
106) -> impl Filter<Extract = (Arc<dyn SignedEntityService>,), Error = Infallible> + Clone + use<> {
107    let signed_entity_service = router_state.dependencies.signed_entity_service.clone();
108    warp::any().map(move || signed_entity_service.clone())
109}
110
111/// With verification key store
112pub fn with_verification_key_store(
113    router_state: &RouterState,
114) -> impl Filter<Extract = (Arc<dyn VerificationKeyStorer>,), Error = Infallible> + Clone + use<> {
115    let verification_key_store = router_state.dependencies.verification_key_store.clone();
116    warp::any().map(move || verification_key_store.clone())
117}
118
119/// With API version provider
120pub fn with_api_version_provider(
121    router_state: &RouterState,
122) -> impl Filter<Extract = (Arc<APIVersionProvider>,), Error = Infallible> + Clone + use<> {
123    let api_version_provider = router_state.dependencies.api_version_provider.clone();
124    warp::any().map(move || api_version_provider.clone())
125}
126
127/// With Message service
128pub fn with_http_message_service(
129    router_state: &RouterState,
130) -> impl Filter<Extract = (Arc<dyn MessageService>,), Error = Infallible> + Clone + use<> {
131    let message_service = router_state.dependencies.message_service.clone();
132    warp::any().map(move || message_service.clone())
133}
134
135/// With Prover service
136pub fn with_prover_service(
137    router_state: &RouterState,
138) -> impl Filter<Extract = (Arc<dyn LegacyProverService>,), Error = Infallible> + Clone + use<> {
139    let prover_service = router_state.dependencies.legacy_prover_service.clone();
140    warp::any().map(move || prover_service.clone())
141}
142
143/// With Single Signature Authenticator
144pub fn with_single_signature_authenticator(
145    router_state: &RouterState,
146) -> impl Filter<Extract = (Arc<SingleSignatureAuthenticator>,), Error = Infallible> + Clone + use<>
147{
148    let single_signer_authenticator = router_state.dependencies.single_signer_authenticator.clone();
149    warp::any().map(move || single_signer_authenticator.clone())
150}
151
152/// With Metrics service
153pub fn with_metrics_service(
154    router_state: &RouterState,
155) -> impl Filter<Extract = (Arc<MetricsService>,), Error = Infallible> + Clone + use<> {
156    let metrics_service = router_state.dependencies.metrics_service.clone();
157    warp::any().map(move || metrics_service.clone())
158}
159
160/// With origin tag of the request
161pub fn with_origin_tag(
162    router_state: &RouterState,
163) -> impl Filter<Extract = (Option<String>,), Error = warp::reject::Rejection> + Clone + use<> {
164    let white_list = router_state.configuration.origin_tag_white_list.clone();
165
166    warp::header::optional::<String>(MITHRIL_ORIGIN_TAG_HEADER).map(move |name: Option<String>| {
167        name.filter(|tag| white_list.contains(tag)).or(Some("NA".to_string()))
168    })
169}
170
171pub fn with_client_type()
172-> impl Filter<Extract = (Option<String>,), Error = warp::reject::Rejection> + Clone {
173    let authorized_client_types = ["CLI", "WASM", "LIBRARY", "NA"];
174
175    warp::header::optional::<String>(MITHRIL_CLIENT_TYPE_HEADER).map(
176        move |client_type: Option<String>| {
177            client_type
178                .filter(|ct| authorized_client_types.contains(&ct.as_str()))
179                .or(Some("NA".to_string()))
180        },
181    )
182}
183
184pub fn with_client_metadata(
185    router_state: &RouterState,
186) -> impl Filter<Extract = (ClientMetadata,), Error = warp::reject::Rejection> + Clone + use<> {
187    with_origin_tag(router_state).and(with_client_type()).map(
188        |origin_tag: Option<String>, client_type: Option<String>| ClientMetadata {
189            origin_tag,
190            client_type,
191        },
192    )
193}
194
/// Metadata describing the client that issued a request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientMetadata {
    /// Origin tag of the request (when extracted by `with_client_metadata`: the header value
    /// if white-listed, `"NA"` otherwise)
    pub origin_tag: Option<String>,
    /// Client type of the request (when extracted by `with_client_metadata`: the header value
    /// if authorized, `"NA"` otherwise)
    pub client_type: Option<String>,
}
199
200pub mod validators {
201    use crate::http_server::validators::{HashKind, ProverHashValidator};
202
203    use super::*;
204
205    /// With Prover Transactions Hash Validator
206    pub fn with_prover_transactions_hash_validator(
207        router_state: &RouterState,
208    ) -> impl Filter<Extract = (ProverHashValidator,), Error = Infallible> + Clone + use<> {
209        let max_hashes = router_state
210            .configuration
211            .cardano_prover_max_hashes_allowed_by_request;
212
213        warp::any().map(move || ProverHashValidator::new(HashKind::Transaction, max_hashes))
214    }
215
216    /// With Prover Block Hash Validator
217    pub fn with_prover_block_hash_validator(
218        router_state: &RouterState,
219    ) -> impl Filter<Extract = (ProverHashValidator,), Error = Infallible> + Clone + use<> {
220        let max_hashes = router_state
221            .configuration
222            .cardano_prover_max_hashes_allowed_by_request;
223
224        warp::any().map(move || ProverHashValidator::new(HashKind::Block, max_hashes))
225    }
226}
227
#[cfg(test)]
mod tests {
    use serde_json::Value;
    use std::convert::Infallible;
    use warp::{
        Filter,
        http::{Method, Response, StatusCode},
        hyper::body::Bytes,
        test::request,
    };

    use mithril_common::test::double::Dummy;

    use crate::http_server::routes::reply;
    use crate::initialize_dependencies;

    use super::*;

    /// Route handler echoing the extracted value back as a JSON body.
    async fn route_handler(value: Option<String>) -> Result<impl warp::Reply, Infallible> {
        Ok(reply::json(&value, StatusCode::OK))
    }

    /// Decode the JSON body of the response, returning `None` when it is not a JSON string.
    ///
    /// Panics if the body is not valid JSON.
    fn get_body(response: Response<Bytes>) -> Option<String> {
        let result: &Value = &serde_json::from_slice(response.body()).unwrap();
        result.as_str().map(|s| s.to_string())
    }

    mod origin_tag {
        use super::*;
        use std::{collections::HashSet, path::PathBuf};

        use mithril_common::temp_dir;

        fn route_with_origin_tag(
            router_state: &RouterState,
        ) -> impl Filter<Extract = (impl warp::Reply + use<>,), Error = warp::Rejection> + Clone + use<>
        {
            warp::path!("route")
                .and(warp::get())
                .and(with_origin_tag(router_state))
                .and_then(route_handler)
        }

        async fn router_state_with_origin_whitelist(
            temp_dir: PathBuf,
            tags: &[&str],
        ) -> RouterState {
            let origin_tag_white_list: HashSet<String> =
                tags.iter().map(ToString::to_string).collect();

            RouterState::new(
                Arc::new(initialize_dependencies(temp_dir).await),
                RouterConfig {
                    origin_tag_white_list,
                    ..RouterConfig::dummy()
                },
            )
        }

        #[tokio::test]
        async fn test_origin_tag_with_value_in_white_list_return_the_tag() {
            let router_state =
                router_state_with_origin_whitelist(temp_dir!(), &["CLIENT_TAG"]).await;

            let response = request()
                .header(MITHRIL_ORIGIN_TAG_HEADER, "CLIENT_TAG")
                .method(Method::GET.as_str())
                .path("/route")
                .reply(&route_with_origin_tag(&router_state))
                .await;

            assert_eq!(Some("CLIENT_TAG".to_string()), get_body(response));
        }

        #[tokio::test]
        async fn test_origin_tag_with_value_not_in_white_list_return_na() {
            let router_state =
                router_state_with_origin_whitelist(temp_dir!(), &["CLIENT_TAG"]).await;

            let response = request()
                .header(MITHRIL_ORIGIN_TAG_HEADER, "UNKNOWN_TAG")
                .method(Method::GET.as_str())
                .path("/route")
                .reply(&route_with_origin_tag(&router_state))
                .await;

            assert_eq!(Some("NA".to_string()), get_body(response));
        }

        #[tokio::test]
        async fn test_without_origin_tag() {
            let router_state =
                router_state_with_origin_whitelist(temp_dir!(), &["CLIENT_TAG"]).await;

            let response = request()
                .method(Method::GET.as_str())
                .path("/route")
                .reply(&route_with_origin_tag(&router_state))
                .await;

            assert_eq!(Some("NA".to_string()), get_body(response));
        }
    }

    mod client_type {
        use super::*;

        async fn request_with_client_type(client_type_header_value: &str) -> Response<Bytes> {
            request()
                .method(Method::GET.as_str())
                .header(MITHRIL_CLIENT_TYPE_HEADER, client_type_header_value)
                .path("/route")
                .reply(&route_with_client_type())
                .await
        }

        fn route_with_client_type()
        -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
            warp::path!("route")
                .and(warp::get())
                .and(with_client_type())
                .and_then(route_handler)
        }

        #[tokio::test]
        async fn test_with_client_type_use_na_as_default_value_if_header_not_set() {
            let response = request()
                .method(Method::GET.as_str())
                .path("/route")
                .reply(&route_with_client_type())
                .await;

            assert_eq!(Some("NA".to_string()), get_body(response));
        }

        #[tokio::test]
        async fn test_with_client_type_only_authorize_specific_values() {
            let response: Response<Bytes> = request_with_client_type("CLI").await;
            assert_eq!(Some("CLI".to_string()), get_body(response));

            let response: Response<Bytes> = request_with_client_type("WASM").await;
            assert_eq!(Some("WASM".to_string()), get_body(response));

            let response: Response<Bytes> = request_with_client_type("LIBRARY").await;
            assert_eq!(Some("LIBRARY".to_string()), get_body(response));

            let response: Response<Bytes> = request_with_client_type("NA").await;
            assert_eq!(Some("NA".to_string()), get_body(response));

            let response: Response<Bytes> = request_with_client_type("UNKNOWN").await;
            assert_eq!(Some("NA".to_string()), get_body(response));
        }
    }

    mod json_with_max_length {
        use warp::test::RequestBuilder;

        use super::*;

        #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
        struct TestPayload {
            d: String,
        }

        impl TestPayload {
            fn new(payload: &str) -> Self {
                TestPayload {
                    d: payload.to_string(),
                }
            }

            fn json_bytes_len(&self) -> usize {
                serde_json::to_string(self).unwrap().len()
            }
        }

        fn route_with_json_limit(
            max_length: u64,
        ) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
            warp::path!("json-route")
                .and(warp::post())
                .and(json_with_max_length(max_length))
                .then(|json: TestPayload| async move { reply::json(&json, StatusCode::OK) })
        }

        fn test_request(json: &TestPayload) -> RequestBuilder {
            // Note: `.json` set both `content-type` and `content-length` headers
            request().method(Method::POST.as_str()).json(json).path("/json-route")
        }

        #[tokio::test]
        async fn test_accepts_request_with_content_length_equal_to_threshold() {
            let payload = TestPayload::new("test");
            // The limit is inclusive: a body exactly `threshold` bytes long must be accepted.
            let threshold = payload.json_bytes_len();
            let response = test_request(&payload)
                .reply(&route_with_json_limit(threshold as u64))
                .await;

            assert_eq!(response.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn test_rejects_request_with_content_length_too_large() {
            let payload = TestPayload::new("test");
            let threshold = payload.json_bytes_len() - 1;
            let response = test_request(&payload)
                .reply(&route_with_json_limit(threshold as u64))
                .await;

            assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
        }

        #[tokio::test]
        async fn test_rejects_request_missing_content_length_header() {
            let payload = TestPayload::new("test");
            let response = test_request(&payload)
                // Replace the content-length header set by `.json` with an empty value: it
                // carries no usable length, so the request must be handled as if it had none.
                .header("content-length", "")
                .reply(&route_with_json_limit(100))
                .await;

            assert_eq!(response.status(), StatusCode::LENGTH_REQUIRED);
        }
    }
}