mithril_aggregator/http_server/routes/
middlewares.rs

1use slog::{Logger, debug};
2use std::convert::Infallible;
3use std::sync::Arc;
4use warp::Filter;
5
6use mithril_common::api_version::APIVersionProvider;
7use mithril_common::{MITHRIL_CLIENT_TYPE_HEADER, MITHRIL_ORIGIN_TAG_HEADER};
8
9use crate::database::repository::SignerGetter;
10use crate::dependency_injection::EpochServiceWrapper;
11use crate::event_store::{EventMessage, TransmitterService};
12use crate::http_server::routes::http_server_child_logger;
13use crate::http_server::routes::router::{RouterConfig, RouterState};
14use crate::services::{CertifierService, MessageService, ProverService, SignedEntityService};
15use crate::{
16    MetricsService, SignerRegisterer, SingleSignatureAuthenticator, VerificationKeyStorer,
17};
18
19/// Extract a value from the configuration
20pub fn extract_config<D: Clone + Send>(
21    state: &RouterState,
22    extract: fn(&RouterConfig) -> D,
23) -> impl Filter<Extract = (D,), Error = Infallible> + Clone + use<D> {
24    let config_value = extract(&state.configuration);
25    warp::any().map(move || config_value.clone())
26}
27
28/// With logger middleware
29pub(crate) fn with_logger(
30    router_state: &RouterState,
31) -> impl Filter<Extract = (Logger,), Error = Infallible> + Clone + use<> {
32    let logger = http_server_child_logger(&router_state.dependencies.root_logger);
33    warp::any().map(move || logger.clone())
34}
35
36/// Log to apply each time a route is called
37///
38/// Example of log produced: `POST /aggregator/register-signatures 202 Accepted`
39pub(crate) fn log_route_call(
40    router_state: &RouterState,
41) -> warp::log::Log<impl Fn(warp::log::Info<'_>) + Clone + use<>> {
42    let logger = http_server_child_logger(&router_state.dependencies.root_logger);
43    warp::log::custom(move |info| {
44        debug!(
45            logger,
46            "{} {} {}",
47            info.method(),
48            info.path(),
49            info.status()
50        )
51    })
52}
53
54/// With signer registerer middleware
55pub fn with_signer_registerer(
56    router_state: &RouterState,
57) -> impl Filter<Extract = (Arc<dyn SignerRegisterer>,), Error = Infallible> + Clone + use<> {
58    let signer_register = router_state.dependencies.signer_registerer.clone();
59    warp::any().map(move || signer_register.clone())
60}
61
62/// With signer getter middleware
63pub fn with_signer_getter(
64    router_state: &RouterState,
65) -> impl Filter<Extract = (Arc<dyn SignerGetter>,), Error = Infallible> + Clone + use<> {
66    let signer_getter = router_state.dependencies.signer_getter.clone();
67    warp::any().map(move || signer_getter.clone())
68}
69
70/// With Event transmitter middleware
71pub fn with_event_transmitter(
72    router_state: &RouterState,
73) -> impl Filter<Extract = (Arc<TransmitterService<EventMessage>>,), Error = Infallible> + Clone + use<>
74{
75    let event_transmitter = router_state.dependencies.event_transmitter.clone();
76    warp::any().map(move || event_transmitter.clone())
77}
78
79/// With certifier service middleware
80pub fn with_certifier_service(
81    router_state: &RouterState,
82) -> impl Filter<Extract = (Arc<dyn CertifierService>,), Error = Infallible> + Clone + use<> {
83    let certifier_service = router_state.dependencies.certifier_service.clone();
84    warp::any().map(move || certifier_service.clone())
85}
86
87/// With epoch service middleware
88pub fn with_epoch_service(
89    router_state: &RouterState,
90) -> impl Filter<Extract = (EpochServiceWrapper,), Error = Infallible> + Clone + use<> {
91    let epoch_service = router_state.dependencies.epoch_service.clone();
92    warp::any().map(move || epoch_service.clone())
93}
94
95/// With signed entity service
96pub fn with_signed_entity_service(
97    router_state: &RouterState,
98) -> impl Filter<Extract = (Arc<dyn SignedEntityService>,), Error = Infallible> + Clone + use<> {
99    let signed_entity_service = router_state.dependencies.signed_entity_service.clone();
100    warp::any().map(move || signed_entity_service.clone())
101}
102
103/// With verification key store
104pub fn with_verification_key_store(
105    router_state: &RouterState,
106) -> impl Filter<Extract = (Arc<dyn VerificationKeyStorer>,), Error = Infallible> + Clone + use<> {
107    let verification_key_store = router_state.dependencies.verification_key_store.clone();
108    warp::any().map(move || verification_key_store.clone())
109}
110
111/// With API version provider
112pub fn with_api_version_provider(
113    router_state: &RouterState,
114) -> impl Filter<Extract = (Arc<APIVersionProvider>,), Error = Infallible> + Clone + use<> {
115    let api_version_provider = router_state.dependencies.api_version_provider.clone();
116    warp::any().map(move || api_version_provider.clone())
117}
118
119/// With Message service
120pub fn with_http_message_service(
121    router_state: &RouterState,
122) -> impl Filter<Extract = (Arc<dyn MessageService>,), Error = Infallible> + Clone + use<> {
123    let message_service = router_state.dependencies.message_service.clone();
124    warp::any().map(move || message_service.clone())
125}
126
127/// With Prover service
128pub fn with_prover_service(
129    router_state: &RouterState,
130) -> impl Filter<Extract = (Arc<dyn ProverService>,), Error = Infallible> + Clone + use<> {
131    let prover_service = router_state.dependencies.prover_service.clone();
132    warp::any().map(move || prover_service.clone())
133}
134
135/// With Single Signature Authenticator
136pub fn with_single_signature_authenticator(
137    router_state: &RouterState,
138) -> impl Filter<Extract = (Arc<SingleSignatureAuthenticator>,), Error = Infallible> + Clone + use<>
139{
140    let single_signer_authenticator = router_state.dependencies.single_signer_authenticator.clone();
141    warp::any().map(move || single_signer_authenticator.clone())
142}
143
144/// With Metrics service
145pub fn with_metrics_service(
146    router_state: &RouterState,
147) -> impl Filter<Extract = (Arc<MetricsService>,), Error = Infallible> + Clone + use<> {
148    let metrics_service = router_state.dependencies.metrics_service.clone();
149    warp::any().map(move || metrics_service.clone())
150}
151
152/// With origin tag of the request
153pub fn with_origin_tag(
154    router_state: &RouterState,
155) -> impl Filter<Extract = (Option<String>,), Error = warp::reject::Rejection> + Clone + use<> {
156    let white_list = router_state.configuration.origin_tag_white_list.clone();
157
158    warp::header::optional::<String>(MITHRIL_ORIGIN_TAG_HEADER).map(move |name: Option<String>| {
159        name.filter(|tag| white_list.contains(tag)).or(Some("NA".to_string()))
160    })
161}
162
163pub fn with_client_type()
164-> impl Filter<Extract = (Option<String>,), Error = warp::reject::Rejection> + Clone {
165    let authorized_client_types = ["CLI", "WASM", "LIBRARY", "NA"];
166
167    warp::header::optional::<String>(MITHRIL_CLIENT_TYPE_HEADER).map(
168        move |client_type: Option<String>| {
169            client_type
170                .filter(|ct| authorized_client_types.contains(&ct.as_str()))
171                .or(Some("NA".to_string()))
172        },
173    )
174}
175
176pub fn with_client_metadata(
177    router_state: &RouterState,
178) -> impl Filter<Extract = (ClientMetadata,), Error = warp::reject::Rejection> + Clone + use<> {
179    with_origin_tag(router_state).and(with_client_type()).map(
180        |origin_tag: Option<String>, client_type: Option<String>| ClientMetadata {
181            origin_tag,
182            client_type,
183        },
184    )
185}
186
/// Metadata describing the client that issued a request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientMetadata {
    /// Origin tag attached to the request.
    pub origin_tag: Option<String>,
    /// Type of the client that issued the request.
    pub client_type: Option<String>,
}
191
192pub mod validators {
193    use crate::http_server::validators::ProverTransactionsHashValidator;
194
195    use super::*;
196
197    /// With Prover Transactions Hash Validator
198    pub fn with_prover_transactions_hash_validator(
199        router_state: &RouterState,
200    ) -> impl Filter<Extract = (ProverTransactionsHashValidator,), Error = Infallible> + Clone + use<>
201    {
202        let max_hashes = router_state
203            .configuration
204            .cardano_transactions_prover_max_hashes_allowed_by_request;
205
206        warp::any().map(move || ProverTransactionsHashValidator::new(max_hashes))
207    }
208}
209
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use serde_json::Value;
    use warp::{
        Filter,
        http::{Method, Response, StatusCode},
        hyper::body::Bytes,
        test::request,
    };

    use crate::http_server::routes::reply;
    use crate::initialize_dependencies;

    use super::*;

    /// Handler replying with the extracted value serialized as JSON.
    async fn route_handler(value: Option<String>) -> Result<impl warp::Reply, Infallible> {
        Ok(reply::json(&value, StatusCode::OK))
    }

    /// Read back the JSON body of a response, returning its content when it is a string.
    fn get_body(response: Response<Bytes>) -> Option<String> {
        let body: Value = serde_json::from_slice(response.body()).unwrap();

        body.as_str().map(str::to_owned)
    }

    mod origin_tag {
        use std::{collections::HashSet, path::PathBuf};

        use mithril_common::temp_dir;

        use super::*;

        /// `GET /route` replying with the value extracted by `with_origin_tag`.
        fn route_with_origin_tag(
            router_state: &RouterState,
        ) -> impl Filter<Extract = (impl warp::Reply + use<>,), Error = warp::Rejection> + Clone + use<>
        {
            let origin_tag = with_origin_tag(router_state);

            warp::path!("route")
                .and(warp::get())
                .and(origin_tag)
                .and_then(route_handler)
        }

        /// Router state whose origin tag white list contains exactly the given `tags`.
        async fn router_state_with_origin_whitelist(
            temp_dir: PathBuf,
            tags: &[&str],
        ) -> RouterState {
            let origin_tag_white_list: HashSet<String> =
                tags.iter().map(|tag| tag.to_string()).collect();
            let dependencies = Arc::new(initialize_dependencies(temp_dir).await);
            let configuration = RouterConfig {
                origin_tag_white_list,
                ..RouterConfig::dummy()
            };

            RouterState::new(dependencies, configuration)
        }

        #[tokio::test]
        async fn test_origin_tag_with_value_in_white_list_return_the_tag() {
            let state = router_state_with_origin_whitelist(temp_dir!(), &["CLIENT_TAG"]).await;
            let route = route_with_origin_tag(&state);

            let response = request()
                .header(MITHRIL_ORIGIN_TAG_HEADER, "CLIENT_TAG")
                .method(Method::GET.as_str())
                .path("/route")
                .reply(&route)
                .await;

            assert_eq!(get_body(response).as_deref(), Some("CLIENT_TAG"));
        }

        #[tokio::test]
        async fn test_origin_tag_with_value_not_in_white_list_return_na() {
            let state = router_state_with_origin_whitelist(temp_dir!(), &["CLIENT_TAG"]).await;
            let route = route_with_origin_tag(&state);

            let response = request()
                .header(MITHRIL_ORIGIN_TAG_HEADER, "UNKNOWN_TAG")
                .method(Method::GET.as_str())
                .path("/route")
                .reply(&route)
                .await;

            assert_eq!(get_body(response).as_deref(), Some("NA"));
        }

        #[tokio::test]
        async fn test_without_origin_tag() {
            let state = router_state_with_origin_whitelist(temp_dir!(), &["CLIENT_TAG"]).await;
            let route = route_with_origin_tag(&state);

            let response = request()
                .method(Method::GET.as_str())
                .path("/route")
                .reply(&route)
                .await;

            assert_eq!(get_body(response).as_deref(), Some("NA"));
        }
    }

    mod client_type {
        use super::*;

        /// `GET /route` replying with the value extracted by `with_client_type`.
        fn route_with_client_type()
        -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
            warp::path!("route")
                .and(warp::get())
                .and(with_client_type())
                .and_then(route_handler)
        }

        /// Call the route with the client type header set to the given value.
        async fn request_with_client_type(client_type_header_value: &str) -> Response<Bytes> {
            let route = route_with_client_type();

            request()
                .method(Method::GET.as_str())
                .header(MITHRIL_CLIENT_TYPE_HEADER, client_type_header_value)
                .path("/route")
                .reply(&route)
                .await
        }

        #[tokio::test]
        async fn test_with_client_type_use_na_as_default_value_if_header_not_set() {
            let route = route_with_client_type();

            let response = request()
                .method(Method::GET.as_str())
                .path("/route")
                .reply(&route)
                .await;

            assert_eq!(get_body(response).as_deref(), Some("NA"));
        }

        #[tokio::test]
        async fn test_with_client_type_only_authorize_specific_values() {
            // (header value sent, client type expected in the reply)
            let expectations = [
                ("CLI", "CLI"),
                ("WASM", "WASM"),
                ("LIBRARY", "LIBRARY"),
                ("NA", "NA"),
                ("UNKNOWN", "NA"),
            ];

            for (sent, expected) in expectations {
                let response: Response<Bytes> = request_with_client_type(sent).await;

                assert_eq!(get_body(response).as_deref(), Some(expected));
            }
        }
    }
}