// mithril_common/test_utils/test_http_server.rs
//
// Test HTTP server configurable with `warp` filters.

use std::{net::SocketAddr, panic, sync::mpsc as std_mpsc, thread, time::Duration};
use tokio::{runtime, sync::oneshot};
use warp::{Filter, Reply};
9
10pub struct TestHttpServer {
12 address: SocketAddr,
13 panic_rx: std_mpsc::Receiver<()>,
14 shutdown_tx: Option<oneshot::Sender<()>>,
15}
16
17impl TestHttpServer {
18 pub fn address(&self) -> SocketAddr {
20 self.address
21 }
22
23 pub fn url(&self) -> String {
25 format!("http://{}", self.address)
26 }
27}
28
29impl Drop for TestHttpServer {
30 fn drop(&mut self) {
31 if let Some(tx) = self.shutdown_tx.take() {
32 let _ = tx.send(());
33 }
34
35 if !::std::thread::panicking() {
36 self.panic_rx
37 .recv_timeout(Duration::from_secs(3))
38 .expect("test server should not panic");
39 }
40 }
41}
42
43pub fn test_http_server<F>(filters: F) -> TestHttpServer
45where
46 F: Filter + Clone + Send + Sync + 'static,
47 F::Extract: Reply,
48{
49 test_http_server_with_socket_address(filters, ([127, 0, 0, 1], 0).into())
50}
51
52pub fn test_http_server_with_socket_address<F>(
54 filters: F,
55 socket_addr: SocketAddr,
56) -> TestHttpServer
57where
58 F: Filter + Clone + Send + Sync + 'static,
59 F::Extract: Reply,
60{
61 thread::spawn(move || {
63 let rt = runtime::Builder::new_current_thread()
64 .enable_all()
65 .build()
66 .expect("new rt");
67 let (shutdown_tx, shutdown_rx) = oneshot::channel();
68 let (address, server) = rt.block_on(async move {
69 warp::serve(filters).bind_with_graceful_shutdown(socket_addr, async {
70 shutdown_rx.await.ok();
71 })
72 });
73
74 let (panic_tx, panic_rx) = std_mpsc::channel();
75 let thread_name = format!(
76 "test({})-support-server",
77 thread::current().name().unwrap_or("<unknown>")
78 );
79 thread::Builder::new()
80 .name(thread_name)
81 .spawn(move || {
82 rt.block_on(server);
83 let _ = panic_tx.send(());
84 })
85 .expect("thread spawn");
86
87 TestHttpServer {
88 address,
89 panic_rx,
90 shutdown_tx: Some(shutdown_tx),
91 }
92 })
93 .join()
94 .unwrap()
95}
96
97#[cfg(test)]
98mod tests {
99 use super::*;
100 use reqwest::StatusCode;
101 use serde::Deserialize;
102
103 #[tokio::test]
104 async fn test_server_simple_http() {
105 let expected: &'static str = "Hello Mithril !";
106 let routes = warp::any().map(move || expected);
107 let server = test_http_server(routes);
108
109 let result = reqwest::get(server.url())
110 .await
111 .expect("should run")
112 .text()
113 .await
114 .unwrap();
115
116 assert_eq!(result, expected)
117 }
118
119 #[tokio::test]
120 async fn test_server_simple_json() {
121 #[derive(Debug, Eq, PartialEq, Deserialize)]
122 struct Test {
123 content: String,
124 }
125
126 let routes = warp::any().map(move || r#"{"content":"Hello Mithril !"}"#);
127 let server = test_http_server(routes);
128
129 let result = reqwest::get(server.url())
130 .await
131 .expect("should run")
132 .json::<Test>()
133 .await
134 .unwrap();
135
136 assert_eq!(
137 result,
138 Test {
139 content: "Hello Mithril !".to_string()
140 }
141 )
142 }
143
144 #[tokio::test]
145 async fn test_server_specific_route() {
146 let expected: &'static str = "Hello Mithril !";
147 let routes = warp::path("hello").map(move || expected);
148 let server = test_http_server(routes);
149
150 let result = reqwest::get(format!("{}/hello", server.url()))
151 .await
152 .expect("should run")
153 .text()
154 .await
155 .unwrap();
156
157 assert_eq!(result, expected);
158 }
159
160 #[tokio::test]
161 async fn test_server_unbind_route_yield_404() {
162 let expected: &'static str = "Hello Mithril !";
163 let routes = warp::path("hello").map(move || expected);
164 let server = test_http_server(routes);
165
166 let result = reqwest::get(format!("{}/unbind", server.url()))
167 .await
168 .expect("should run");
169
170 assert_eq!(result.status(), StatusCode::NOT_FOUND);
171 }
172}