1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
//! Define a HttpServer for test that can be configured using warp filters.

// Base code from the httpserver in reqwest tests:
// https://github.com/seanmonstar/reqwest/blob/master/tests/support/server.rs

use std::{net::SocketAddr, sync::mpsc as std_mpsc, thread, time::Duration};
use tokio::{runtime, sync::oneshot};
use warp::{Filter, Reply};

/// An HTTP server for tests.
///
/// Dropping the value asks the server to shut down and then waits for the server thread
/// to report that it finished (see the `Drop` implementation).
pub struct TestHttpServer {
    /// Socket address returned when the server was bound.
    address: SocketAddr,
    /// Receives a message once the server thread has finished running the server future;
    /// the sending half is dropped without a message if that thread panics.
    panic_rx: std_mpsc::Receiver<()>,
    /// Sender used to trigger the graceful shutdown; taken (left as `None`) on drop.
    shutdown_tx: Option<oneshot::Sender<()>>,
}

impl TestHttpServer {
    /// Socket address the test server was bound to.
    pub fn address(&self) -> SocketAddr {
        let Self { address, .. } = self;
        *address
    }

    /// Base url of the test server: `http://<address>`, without a trailing slash.
    pub fn url(&self) -> String {
        let host = self.address().to_string();
        String::from("http://") + &host
    }
}

impl Drop for TestHttpServer {
    fn drop(&mut self) {
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(());
        }

        if !::std::thread::panicking() {
            self.panic_rx
                .recv_timeout(Duration::from_secs(3))
                .expect("test server should not panic");
        }
    }
}

/// Spawn a [TestHttpServer] using the given warp filters
pub fn test_http_server<F>(filters: F) -> TestHttpServer
where
    F: Filter + Clone + Send + Sync + 'static,
    F::Extract: Reply,
{
    test_http_server_with_socket_address(filters, ([127, 0, 0, 1], 0).into())
}

/// Spawn a [TestHttpServer] using the given warp filters, bound to `socket_addr`.
///
/// The server runs on a dedicated current-thread Tokio runtime driven by its own thread,
/// named `test(<caller thread name>)-support-server`. It keeps running until the returned
/// [TestHttpServer] is dropped, which triggers the graceful shutdown.
///
/// # Panics
///
/// Panics if the runtime cannot be built, if the server thread cannot be spawned, or if the
/// setup thread panics (per warp's documentation, binding panics when `socket_addr` cannot
/// be bound).
pub fn test_http_server_with_socket_address<F>(
    filters: F,
    socket_addr: SocketAddr,
) -> TestHttpServer
where
    F: Filter + Clone + Send + Sync + 'static,
    F::Extract: Reply,
{
    // Capture the caller's thread name *before* moving to the setup thread below: that
    // thread is anonymous, so reading the name from inside it would always give "<unknown>".
    let thread_name = format!(
        "test({})-support-server",
        thread::current().name().unwrap_or("<unknown>")
    );

    // Spawn the new runtime in its own thread to prevent a reactor execution context
    // conflict when the caller is itself running inside a runtime.
    thread::spawn(move || {
        let rt = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt");
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        // Binding is done inside `block_on`, presumably because it needs the runtime
        // context; the returned `server` future is only driven later, on the server thread.
        let (address, server) = rt.block_on(async move {
            warp::serve(filters).bind_with_graceful_shutdown(socket_addr, async {
                // Shut down on an explicit signal as well as when the sender is dropped.
                shutdown_rx.await.ok();
            })
        });

        let (panic_tx, panic_rx) = std_mpsc::channel();
        thread::Builder::new()
            .name(thread_name)
            .spawn(move || {
                rt.block_on(server);
                // Only reached when the server completed without panicking; a panic drops
                // `panic_tx` unsent, which the `Drop` implementation reports as a failure.
                let _ = panic_tx.send(());
            })
            .expect("thread spawn");

        TestHttpServer {
            address,
            panic_rx,
            shutdown_tx: Some(shutdown_tx),
        }
    })
    .join()
    .expect("test http server setup thread should not panic")
}

#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::StatusCode;
    use serde::Deserialize;

    /// A catch-all route answers with its plain-text body on the server base url.
    #[tokio::test]
    async fn test_server_simple_http() {
        let expected: &'static str = "Hello Mithril !";
        let server = test_http_server(warp::any().map(move || expected));

        let response = reqwest::get(server.url()).await.expect("should run");
        let body = response.text().await.unwrap();

        assert_eq!(body, expected)
    }

    /// A catch-all route returning a JSON document can be deserialized by the client.
    #[tokio::test]
    async fn test_server_simple_json() {
        #[derive(Debug, Eq, PartialEq, Deserialize)]
        struct Test {
            content: String,
        }

        let server =
            test_http_server(warp::any().map(move || r#"{"content":"Hello Mithril !"}"#));

        let response = reqwest::get(server.url()).await.expect("should run");
        let parsed = response.json::<Test>().await.unwrap();

        let expected = Test {
            content: "Hello Mithril !".to_string(),
        };
        assert_eq!(parsed, expected)
    }

    /// A filter bound to the `hello` path answers requests made on `/hello`.
    #[tokio::test]
    async fn test_server_specific_route() {
        let expected: &'static str = "Hello Mithril !";
        let server = test_http_server(warp::path("hello").map(move || expected));

        let endpoint = format!("{}/hello", server.url());
        let response = reqwest::get(endpoint).await.expect("should run");
        let body = response.text().await.unwrap();

        assert_eq!(body, expected);
    }

    /// A request on a path that no filter handles yields a `404 Not Found`.
    #[tokio::test]
    async fn test_server_unbind_route_yield_404() {
        let expected: &'static str = "Hello Mithril !";
        let server = test_http_server(warp::path("hello").map(move || expected));

        let endpoint = format!("{}/unbind", server.url());
        let response = reqwest::get(endpoint).await.expect("should run");

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }
}