mithril_aggregator/http_server/routes/
middlewares.rs1use serde::de::DeserializeOwned;
2use slog::{Logger, debug};
3use std::convert::Infallible;
4use std::sync::Arc;
5use warp::{Filter, Rejection};
6
7use mithril_common::api_version::APIVersionProvider;
8use mithril_common::{MITHRIL_CLIENT_TYPE_HEADER, MITHRIL_ORIGIN_TAG_HEADER};
9
10use crate::database::repository::SignerGetter;
11use crate::dependency_injection::EpochServiceWrapper;
12use crate::event_store::{EventMessage, TransmitterService};
13use crate::http_server::routes::http_server_child_logger;
14use crate::http_server::routes::router::{RouterConfig, RouterState};
15use crate::services::{CertifierService, LegacyProverService, MessageService, SignedEntityService};
16use crate::{
17 MetricsService, SignerRegisterer, SingleSignatureAuthenticator, VerificationKeyStorer,
18};
19
20pub(crate) fn json_with_max_length<T: DeserializeOwned + Send>(
22 max_length: u64,
23) -> impl Filter<Extract = (T,), Error = Rejection> + Copy {
24 warp::body::content_length_limit(max_length).and(warp::body::json())
25}
26
27pub(crate) fn extract_config<D: Clone + Send>(
29 state: &RouterState,
30 extract: fn(&RouterConfig) -> D,
31) -> impl Filter<Extract = (D,), Error = Infallible> + Clone + use<D> {
32 let config_value = extract(&state.configuration);
33 warp::any().map(move || config_value.clone())
34}
35
36pub(crate) fn with_logger(
38 router_state: &RouterState,
39) -> impl Filter<Extract = (Logger,), Error = Infallible> + Clone + use<> {
40 let logger = http_server_child_logger(&router_state.dependencies.root_logger);
41 warp::any().map(move || logger.clone())
42}
43
44pub(crate) fn log_route_call(
48 router_state: &RouterState,
49) -> warp::log::Log<impl Fn(warp::log::Info<'_>) + Clone + use<>> {
50 let logger = http_server_child_logger(&router_state.dependencies.root_logger);
51 warp::log::custom(move |info| {
52 debug!(
53 logger,
54 "{} {} {}",
55 info.method(),
56 info.path(),
57 info.status()
58 )
59 })
60}
61
62pub fn with_signer_registerer(
64 router_state: &RouterState,
65) -> impl Filter<Extract = (Arc<dyn SignerRegisterer>,), Error = Infallible> + Clone + use<> {
66 let signer_register = router_state.dependencies.signer_registerer.clone();
67 warp::any().map(move || signer_register.clone())
68}
69
70pub fn with_signer_getter(
72 router_state: &RouterState,
73) -> impl Filter<Extract = (Arc<dyn SignerGetter>,), Error = Infallible> + Clone + use<> {
74 let signer_getter = router_state.dependencies.signer_getter.clone();
75 warp::any().map(move || signer_getter.clone())
76}
77
78pub fn with_event_transmitter(
80 router_state: &RouterState,
81) -> impl Filter<Extract = (Arc<TransmitterService<EventMessage>>,), Error = Infallible> + Clone + use<>
82{
83 let event_transmitter = router_state.dependencies.event_transmitter.clone();
84 warp::any().map(move || event_transmitter.clone())
85}
86
87pub fn with_certifier_service(
89 router_state: &RouterState,
90) -> impl Filter<Extract = (Arc<dyn CertifierService>,), Error = Infallible> + Clone + use<> {
91 let certifier_service = router_state.dependencies.certifier_service.clone();
92 warp::any().map(move || certifier_service.clone())
93}
94
95pub fn with_epoch_service(
97 router_state: &RouterState,
98) -> impl Filter<Extract = (EpochServiceWrapper,), Error = Infallible> + Clone + use<> {
99 let epoch_service = router_state.dependencies.epoch_service.clone();
100 warp::any().map(move || epoch_service.clone())
101}
102
103pub fn with_signed_entity_service(
105 router_state: &RouterState,
106) -> impl Filter<Extract = (Arc<dyn SignedEntityService>,), Error = Infallible> + Clone + use<> {
107 let signed_entity_service = router_state.dependencies.signed_entity_service.clone();
108 warp::any().map(move || signed_entity_service.clone())
109}
110
111pub fn with_verification_key_store(
113 router_state: &RouterState,
114) -> impl Filter<Extract = (Arc<dyn VerificationKeyStorer>,), Error = Infallible> + Clone + use<> {
115 let verification_key_store = router_state.dependencies.verification_key_store.clone();
116 warp::any().map(move || verification_key_store.clone())
117}
118
119pub fn with_api_version_provider(
121 router_state: &RouterState,
122) -> impl Filter<Extract = (Arc<APIVersionProvider>,), Error = Infallible> + Clone + use<> {
123 let api_version_provider = router_state.dependencies.api_version_provider.clone();
124 warp::any().map(move || api_version_provider.clone())
125}
126
127pub fn with_http_message_service(
129 router_state: &RouterState,
130) -> impl Filter<Extract = (Arc<dyn MessageService>,), Error = Infallible> + Clone + use<> {
131 let message_service = router_state.dependencies.message_service.clone();
132 warp::any().map(move || message_service.clone())
133}
134
135pub fn with_prover_service(
137 router_state: &RouterState,
138) -> impl Filter<Extract = (Arc<dyn LegacyProverService>,), Error = Infallible> + Clone + use<> {
139 let prover_service = router_state.dependencies.legacy_prover_service.clone();
140 warp::any().map(move || prover_service.clone())
141}
142
143pub fn with_single_signature_authenticator(
145 router_state: &RouterState,
146) -> impl Filter<Extract = (Arc<SingleSignatureAuthenticator>,), Error = Infallible> + Clone + use<>
147{
148 let single_signer_authenticator = router_state.dependencies.single_signer_authenticator.clone();
149 warp::any().map(move || single_signer_authenticator.clone())
150}
151
152pub fn with_metrics_service(
154 router_state: &RouterState,
155) -> impl Filter<Extract = (Arc<MetricsService>,), Error = Infallible> + Clone + use<> {
156 let metrics_service = router_state.dependencies.metrics_service.clone();
157 warp::any().map(move || metrics_service.clone())
158}
159
160pub fn with_origin_tag(
162 router_state: &RouterState,
163) -> impl Filter<Extract = (Option<String>,), Error = warp::reject::Rejection> + Clone + use<> {
164 let white_list = router_state.configuration.origin_tag_white_list.clone();
165
166 warp::header::optional::<String>(MITHRIL_ORIGIN_TAG_HEADER).map(move |name: Option<String>| {
167 name.filter(|tag| white_list.contains(tag)).or(Some("NA".to_string()))
168 })
169}
170
171pub fn with_client_type()
172-> impl Filter<Extract = (Option<String>,), Error = warp::reject::Rejection> + Clone {
173 let authorized_client_types = ["CLI", "WASM", "LIBRARY", "NA"];
174
175 warp::header::optional::<String>(MITHRIL_CLIENT_TYPE_HEADER).map(
176 move |client_type: Option<String>| {
177 client_type
178 .filter(|ct| authorized_client_types.contains(&ct.as_str()))
179 .or(Some("NA".to_string()))
180 },
181 )
182}
183
184pub fn with_client_metadata(
185 router_state: &RouterState,
186) -> impl Filter<Extract = (ClientMetadata,), Error = warp::reject::Rejection> + Clone + use<> {
187 with_origin_tag(router_state).and(with_client_type()).map(
188 |origin_tag: Option<String>, client_type: Option<String>| ClientMetadata {
189 origin_tag,
190 client_type,
191 },
192 )
193}
194
/// Metadata describing the client that issued a request.
///
/// When produced by `with_client_metadata`, a missing or unrecognized header value is
/// reported as `Some("NA")` rather than `None`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientMetadata {
    /// Origin tag sent by the client, kept only if it is in the configured white list.
    pub origin_tag: Option<String>,
    /// Client type sent by the client (`CLI`, `WASM`, `LIBRARY` or `NA`).
    pub client_type: Option<String>,
}
199
200pub mod validators {
201 use crate::http_server::validators::{HashKind, ProverHashValidator};
202
203 use super::*;
204
205 pub fn with_prover_transactions_hash_validator(
207 router_state: &RouterState,
208 ) -> impl Filter<Extract = (ProverHashValidator,), Error = Infallible> + Clone + use<> {
209 let max_hashes = router_state
210 .configuration
211 .cardano_prover_max_hashes_allowed_by_request;
212
213 warp::any().map(move || ProverHashValidator::new(HashKind::Transaction, max_hashes))
214 }
215
216 pub fn with_prover_block_hash_validator(
218 router_state: &RouterState,
219 ) -> impl Filter<Extract = (ProverHashValidator,), Error = Infallible> + Clone + use<> {
220 let max_hashes = router_state
221 .configuration
222 .cardano_prover_max_hashes_allowed_by_request;
223
224 warp::any().map(move || ProverHashValidator::new(HashKind::Block, max_hashes))
225 }
226}
227
#[cfg(test)]
mod tests {
    use serde_json::Value;
    use std::convert::Infallible;
    use warp::{
        Filter,
        http::{Method, Response, StatusCode},
        hyper::body::Bytes,
        test::request,
    };

    use mithril_common::test::double::Dummy;

    use crate::http_server::routes::reply;
    use crate::initialize_dependencies;

    use super::*;

    /// Replies with the extracted optional string serialized as JSON.
    async fn route_handler(value: Option<String>) -> Result<impl warp::Reply, Infallible> {
        Ok(reply::json(&value, StatusCode::OK))
    }

    /// Decodes a JSON string body; returns `None` when the body is not a JSON string.
    fn get_body(response: Response<Bytes>) -> Option<String> {
        let result: Value = serde_json::from_slice(response.body()).unwrap();
        result.as_str().map(str::to_string)
    }

    mod origin_tag {
        use super::*;
        use std::{collections::HashSet, path::PathBuf};

        use mithril_common::temp_dir;

        fn route_with_origin_tag(
            router_state: &RouterState,
        ) -> impl Filter<Extract = (impl warp::Reply + use<>,), Error = warp::Rejection> + Clone + use<>
        {
            warp::path!("route")
                .and(warp::get())
                .and(with_origin_tag(router_state))
                .and_then(route_handler)
        }

        async fn router_state_with_origin_whitelist(
            temp_dir: PathBuf,
            tags: &[&str],
        ) -> RouterState {
            let origin_tag_white_list: HashSet<String> =
                tags.iter().map(ToString::to_string).collect();

            RouterState::new(
                Arc::new(initialize_dependencies(temp_dir).await),
                RouterConfig {
                    origin_tag_white_list,
                    ..RouterConfig::dummy()
                },
            )
        }

        #[tokio::test]
        async fn test_origin_tag_with_value_in_white_list_return_the_tag() {
            let router_state =
                router_state_with_origin_whitelist(temp_dir!(), &["CLIENT_TAG"]).await;

            let response = request()
                .header(MITHRIL_ORIGIN_TAG_HEADER, "CLIENT_TAG")
                .method(Method::GET.as_str())
                .path("/route")
                .reply(&route_with_origin_tag(&router_state))
                .await;

            assert_eq!(Some("CLIENT_TAG".to_string()), get_body(response));
        }

        #[tokio::test]
        async fn test_origin_tag_with_value_not_in_white_list_return_na() {
            let router_state =
                router_state_with_origin_whitelist(temp_dir!(), &["CLIENT_TAG"]).await;

            let response = request()
                .header(MITHRIL_ORIGIN_TAG_HEADER, "UNKNOWN_TAG")
                .method(Method::GET.as_str())
                .path("/route")
                .reply(&route_with_origin_tag(&router_state))
                .await;

            assert_eq!(Some("NA".to_string()), get_body(response));
        }

        #[tokio::test]
        async fn test_without_origin_tag() {
            let router_state =
                router_state_with_origin_whitelist(temp_dir!(), &["CLIENT_TAG"]).await;

            let response = request()
                .method(Method::GET.as_str())
                .path("/route")
                .reply(&route_with_origin_tag(&router_state))
                .await;

            assert_eq!(Some("NA".to_string()), get_body(response));
        }
    }

    mod client_type {
        use super::*;

        async fn request_with_client_type(client_type_header_value: &str) -> Response<Bytes> {
            request()
                .method(Method::GET.as_str())
                .header(MITHRIL_CLIENT_TYPE_HEADER, client_type_header_value)
                .path("/route")
                .reply(&route_with_client_type())
                .await
        }

        fn route_with_client_type()
        -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
            warp::path!("route")
                .and(warp::get())
                .and(with_client_type())
                .and_then(route_handler)
        }

        #[tokio::test]
        async fn test_with_client_type_use_na_as_default_value_if_header_not_set() {
            let response = request()
                .method(Method::GET.as_str())
                .path("/route")
                .reply(&route_with_client_type())
                .await;

            assert_eq!(Some("NA".to_string()), get_body(response));
        }

        #[tokio::test]
        async fn test_with_client_type_only_authorize_specific_values() {
            let response: Response<Bytes> = request_with_client_type("CLI").await;
            assert_eq!(Some("CLI".to_string()), get_body(response));

            let response: Response<Bytes> = request_with_client_type("WASM").await;
            assert_eq!(Some("WASM".to_string()), get_body(response));

            let response: Response<Bytes> = request_with_client_type("LIBRARY").await;
            assert_eq!(Some("LIBRARY".to_string()), get_body(response));

            let response: Response<Bytes> = request_with_client_type("NA").await;
            assert_eq!(Some("NA".to_string()), get_body(response));

            let response: Response<Bytes> = request_with_client_type("UNKNOWN").await;
            assert_eq!(Some("NA".to_string()), get_body(response));
        }
    }

    mod json_with_max_length {
        use warp::test::RequestBuilder;

        use super::*;

        #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
        struct TestPayload {
            d: String,
        }

        impl TestPayload {
            fn new(payload: &str) -> Self {
                TestPayload {
                    d: payload.to_string(),
                }
            }

            /// Length in bytes of the JSON serialization of this payload.
            fn json_bytes_len(&self) -> usize {
                serde_json::to_string(self).unwrap().len()
            }
        }

        fn route_with_json_limit(
            max_length: u64,
        ) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
            warp::path!("json-route")
                .and(warp::post())
                .and(json_with_max_length(max_length))
                .then(|json: TestPayload| async move { reply::json(&json, StatusCode::OK) })
        }

        fn test_request(json: &TestPayload) -> RequestBuilder {
            request().method(Method::POST.as_str()).json(json).path("/json-route")
        }

        #[tokio::test]
        async fn test_accepts_request_with_content_length_below_threshold() {
            let payload = TestPayload::new("test");
            let threshold = payload.json_bytes_len() + 1;
            let response = test_request(&payload)
                .reply(&route_with_json_limit(threshold as u64))
                .await;

            assert_eq!(response.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn test_accepts_request_with_content_length_equal_to_threshold() {
            let payload = TestPayload::new("test");
            let threshold = payload.json_bytes_len();
            let response = test_request(&payload)
                .reply(&route_with_json_limit(threshold as u64))
                .await;

            assert_eq!(response.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn test_rejects_request_with_content_length_too_large() {
            let payload = TestPayload::new("test");
            let threshold = payload.json_bytes_len() - 1;
            let response = test_request(&payload)
                .reply(&route_with_json_limit(threshold as u64))
                .await;

            assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
        }

        #[tokio::test]
        async fn test_rejects_request_missing_content_length_header() {
            let payload = TestPayload::new("test");
            let response = test_request(&payload)
                .header("content-length", "")
                .reply(&route_with_json_limit(100))
                .await;

            assert_eq!(response.status(), StatusCode::LENGTH_REQUIRED);
        }
    }
}