// mithril_common/test_utils/test_http_server.rs

//! Defines an HTTP server for tests that can be configured using warp filters.
2
// Based on the test HTTP server from the reqwest test suite:
// https://github.com/seanmonstar/reqwest/blob/master/tests/support/server.rs
5
6use std::{net::SocketAddr, sync::mpsc as std_mpsc, thread};
7use tokio::{runtime, sync::oneshot};
8use warp::{Filter, Reply};
9
10/// A HTTP server for test
11pub struct TestHttpServer {
12    address: SocketAddr,
13    shutdown_tx: Option<oneshot::Sender<()>>,
14}
15
16impl TestHttpServer {
17    /// Get the test server address
18    pub fn address(&self) -> SocketAddr {
19        self.address
20    }
21
22    /// Get the server url
23    pub fn url(&self) -> String {
24        format!("http://{}", self.address)
25    }
26}
27
28impl Drop for TestHttpServer {
29    fn drop(&mut self) {
30        if let Some(tx) = self.shutdown_tx.take() {
31            let _ = tx.send(());
32        }
33    }
34}
35
36/// Spawn a [TestHttpServer] using the given warp filters
37pub fn test_http_server<F>(filters: F) -> TestHttpServer
38where
39    F: Filter + Clone + Send + Sync + 'static,
40    F::Extract: Reply,
41{
42    test_http_server_with_socket_address(filters, ([127, 0, 0, 1], 0).into())
43}
44
45/// Spawn a [TestHttpServer] using the given warp filters
46pub fn test_http_server_with_socket_address<F>(
47    filters: F,
48    socket_addr: SocketAddr,
49) -> TestHttpServer
50where
51    F: Filter + Clone + Send + Sync + 'static,
52    F::Extract: Reply,
53{
54    //Spawn new runtime in thread to prevent reactor execution context conflict
55    thread::spawn(move || {
56        let rt = runtime::Builder::new_current_thread()
57            .enable_all()
58            .build()
59            .expect("new rt");
60        let (shutdown_tx, shutdown_rx) = oneshot::channel();
61        let (address, server) = rt.block_on(async move {
62            warp::serve(filters).bind_with_graceful_shutdown(socket_addr, async {
63                shutdown_rx.await.ok();
64            })
65        });
66
67        let (panic_tx, _) = std_mpsc::channel();
68        let thread_name = format!(
69            "test({})-support-server",
70            thread::current().name().unwrap_or("<unknown>")
71        );
72        thread::Builder::new()
73            .name(thread_name)
74            .spawn(move || {
75                rt.block_on(server);
76                let _ = panic_tx.send(());
77            })
78            .expect("thread spawn");
79
80        TestHttpServer {
81            address,
82            shutdown_tx: Some(shutdown_tx),
83        }
84    })
85    .join()
86    .unwrap()
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::StatusCode;
    use serde::Deserialize;

    const GREETING: &str = "Hello Mithril !";

    /// Issues a GET request on `path` (appended to the server base URL) and returns the
    /// response.
    async fn get(server: &TestHttpServer, path: &str) -> reqwest::Response {
        reqwest::get(format!("{}{path}", server.url()))
            .await
            .expect("should run")
    }

    #[tokio::test]
    async fn test_server_simple_http() {
        let server = test_http_server(warp::any().map(|| GREETING));

        let body = get(&server, "").await.text().await.unwrap();

        assert_eq!(body, GREETING)
    }

    #[tokio::test]
    async fn test_server_simple_json() {
        #[derive(Debug, Eq, PartialEq, Deserialize)]
        struct Test {
            content: String,
        }

        let server = test_http_server(warp::any().map(|| r#"{"content":"Hello Mithril !"}"#));

        let parsed: Test = get(&server, "").await.json().await.unwrap();

        assert_eq!(
            parsed,
            Test {
                content: GREETING.to_string()
            }
        )
    }

    #[tokio::test]
    async fn test_server_specific_route() {
        let server = test_http_server(warp::path("hello").map(|| GREETING));

        let body = get(&server, "/hello").await.text().await.unwrap();

        assert_eq!(body, GREETING);
    }

    #[tokio::test]
    async fn test_server_unbind_route_yield_404() {
        let server = test_http_server(warp::path("hello").map(|| GREETING));

        let response = get(&server, "/unbind").await;

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }
}