mithril_common/test_utils/
test_http_server.rs

1//! Define a HttpServer for test that can be configured using warp filters.
2
3// Base code from the httpserver in reqwest tests:
4// https://github.com/seanmonstar/reqwest/blob/master/tests/support/server.rs
5
6use std::{net::SocketAddr, sync::mpsc as std_mpsc, thread, time::Duration};
7use tokio::{runtime, sync::oneshot};
8use warp::{Filter, Reply};
9
10/// A HTTP server for test
11pub struct TestHttpServer {
12    address: SocketAddr,
13    panic_rx: std_mpsc::Receiver<()>,
14    shutdown_tx: Option<oneshot::Sender<()>>,
15}
16
17impl TestHttpServer {
18    /// Get the test server address
19    pub fn address(&self) -> SocketAddr {
20        self.address
21    }
22
23    /// Get the server url
24    pub fn url(&self) -> String {
25        format!("http://{}", self.address)
26    }
27}
28
29impl Drop for TestHttpServer {
30    fn drop(&mut self) {
31        if let Some(tx) = self.shutdown_tx.take() {
32            let _ = tx.send(());
33        }
34
35        if !::std::thread::panicking() {
36            self.panic_rx
37                .recv_timeout(Duration::from_secs(3))
38                .expect("test server should not panic");
39        }
40    }
41}
42
43/// Spawn a [TestHttpServer] using the given warp filters
44pub fn test_http_server<F>(filters: F) -> TestHttpServer
45where
46    F: Filter + Clone + Send + Sync + 'static,
47    F::Extract: Reply,
48{
49    test_http_server_with_socket_address(filters, ([127, 0, 0, 1], 0).into())
50}
51
52/// Spawn a [TestHttpServer] using the given warp filters
53pub fn test_http_server_with_socket_address<F>(
54    filters: F,
55    socket_addr: SocketAddr,
56) -> TestHttpServer
57where
58    F: Filter + Clone + Send + Sync + 'static,
59    F::Extract: Reply,
60{
61    //Spawn new runtime in thread to prevent reactor execution context conflict
62    thread::spawn(move || {
63        let rt = runtime::Builder::new_current_thread()
64            .enable_all()
65            .build()
66            .expect("new rt");
67        let (shutdown_tx, shutdown_rx) = oneshot::channel();
68        let (address, server) = rt.block_on(async move {
69            warp::serve(filters).bind_with_graceful_shutdown(socket_addr, async {
70                shutdown_rx.await.ok();
71            })
72        });
73
74        let (panic_tx, panic_rx) = std_mpsc::channel();
75        let thread_name = format!(
76            "test({})-support-server",
77            thread::current().name().unwrap_or("<unknown>")
78        );
79        thread::Builder::new()
80            .name(thread_name)
81            .spawn(move || {
82                rt.block_on(server);
83                let _ = panic_tx.send(());
84            })
85            .expect("thread spawn");
86
87        TestHttpServer {
88            address,
89            panic_rx,
90            shutdown_tx: Some(shutdown_tx),
91        }
92    })
93    .join()
94    .unwrap()
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::StatusCode;
    use serde::Deserialize;

    /// Body returned by the plain text routes of these tests
    const GREETING: &str = "Hello Mithril !";

    /// Send a GET request to `url`, panicking if the request cannot be executed
    async fn get(url: String) -> reqwest::Response {
        reqwest::get(url).await.expect("should run")
    }

    #[tokio::test]
    async fn test_server_simple_http() {
        let server = test_http_server(warp::any().map(|| GREETING));

        let response = get(server.url()).await;
        let body = response.text().await.unwrap();

        assert_eq!(body, GREETING)
    }

    #[tokio::test]
    async fn test_server_simple_json() {
        #[derive(Debug, Eq, PartialEq, Deserialize)]
        struct Test {
            content: String,
        }

        let server = test_http_server(warp::any().map(|| r#"{"content":"Hello Mithril !"}"#));

        let response = get(server.url()).await;
        let decoded = response.json::<Test>().await.unwrap();

        let expected = Test {
            content: "Hello Mithril !".to_string(),
        };
        assert_eq!(decoded, expected)
    }

    #[tokio::test]
    async fn test_server_specific_route() {
        let server = test_http_server(warp::path("hello").map(|| GREETING));
        let target = format!("{}/hello", server.url());

        let response = get(target).await;
        let body = response.text().await.unwrap();

        assert_eq!(body, GREETING);
    }

    #[tokio::test]
    async fn test_server_unbind_route_yield_404() {
        let server = test_http_server(warp::path("hello").map(|| GREETING));
        let target = format!("{}/unbind", server.url());

        let response = get(target).await;

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }
}