1use axum::{
2 Router,
3 body::Body,
4 extract::State,
5 http::{Response, StatusCode},
6 response::IntoResponse,
7 routing::get,
8};
9use slog::{Logger, error, info, warn};
10use std::net::SocketAddr;
11use std::sync::Arc;
12use tokio::net::TcpListener;
13use tokio::sync::watch::Receiver;
14
15use mithril_common::StdResult;
16use mithril_common::logging::LoggerExtensions;
17
18pub trait MetricsServiceExporter: Send + Sync {
20 fn export_metrics(&self) -> StdResult<String>;
22}
23
24#[derive(Debug)]
26pub enum MetricsServerError {
27 Internal(anyhow::Error),
29}
30
31impl IntoResponse for MetricsServerError {
33 fn into_response(self) -> Response<Body> {
34 match self {
35 Self::Internal(e) => {
36 (StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {e:?}")).into_response()
37 }
38 }
39 }
40}
41
42pub struct MetricsServer {
44 tcp_listener: TcpListener,
45 axum_app: Router,
46 address: SocketAddr,
47 logger: Logger,
48}
49
50pub struct MetricsServerBuilder<T: MetricsServiceExporter> {
52 server_port: u16,
53 server_ip: String,
54 metrics_service: Arc<T>,
55 logger: Logger,
56}
57
58struct RouterState<T: MetricsServiceExporter> {
59 metrics_service: Arc<T>,
60 logger: Logger,
61}
62
63impl MetricsServer {
64 pub fn build<T: MetricsServiceExporter + 'static>(
66 server_ip: &str,
67 server_port: u16,
68 metrics_service: Arc<T>,
69 logger: Logger,
70 ) -> MetricsServerBuilder<T> {
71 MetricsServerBuilder::new(server_ip, server_port, metrics_service, logger)
72 }
73
74 pub fn address(&self) -> SocketAddr {
76 self.address
77 }
78
79 pub async fn serve(self, shutdown_rx: Receiver<()>) -> StdResult<()> {
81 let serve_logger = self.logger;
82 let mut shutdown_rx = shutdown_rx;
83 axum::serve(self.tcp_listener, self.axum_app)
84 .with_graceful_shutdown(async move {
85 shutdown_rx.changed().await.ok();
86 warn!(
87 serve_logger,
88 "shutting down HTTP server after receiving signal"
89 );
90 })
91 .await?;
92
93 Ok(())
94 }
95}
96
97impl<T: MetricsServiceExporter + 'static> MetricsServerBuilder<T> {
98 pub fn new(server_ip: &str, server_port: u16, metrics_service: Arc<T>, logger: Logger) -> Self {
100 Self {
101 server_port,
102 server_ip: server_ip.to_string(),
103 metrics_service,
104 logger: logger.new_with_component_name::<Self>(),
105 }
106 }
107
108 pub async fn bind(self) -> StdResult<MetricsServer> {
110 info!(
111 self.logger,
112 "Starting HTTP server for metrics on port {}", self.server_port
113 );
114
115 let router_state = Arc::new(RouterState {
116 metrics_service: self.metrics_service,
117 logger: self.logger.clone(),
118 });
119 let axum_app = Router::new()
120 .route(
121 "/metrics",
122 get(|State(state): State<Arc<RouterState<T>>>| async move {
123 state.metrics_service.export_metrics().map_err(|e| {
124 error!(state.logger, "Error exporting metrics"; "error" => ?e);
125 MetricsServerError::Internal(e)
126 })
127 }),
128 )
129 .with_state(router_state);
130 let tcp_listener =
131 TcpListener::bind(format!("{}:{}", self.server_ip, self.server_port)).await?;
132 let address = tcp_listener.local_addr()?;
133
134 Ok(MetricsServer {
135 tcp_listener,
136 axum_app,
137 address,
138 logger: self.logger,
139 })
140 }
141
142 pub async fn serve(self, shutdown_rx: Receiver<()>) -> StdResult<()> {
144 let server = self.bind().await?;
145 server.serve(shutdown_rx).await
146 }
147}
148
149#[cfg(test)]
150mod tests {
151 use anyhow::anyhow;
152 use reqwest::StatusCode;
153 use std::time::Duration;
154 use tokio::{sync::watch, task::yield_now, time::sleep};
155
156 use crate::helper::test_tools::TestLogger;
157
158 use super::*;
159
160 pub struct PseudoMetricsService {}
161
162 impl PseudoMetricsService {
163 pub fn new() -> Self {
164 Self {}
165 }
166 }
167
168 impl MetricsServiceExporter for PseudoMetricsService {
169 fn export_metrics(&self) -> StdResult<String> {
170 Ok("pseudo metrics".to_string())
171 }
172 }
173
174 #[tokio::test]
175 async fn test_metrics_server() {
176 let logger = TestLogger::stdout();
177 let (shutdown_tx, shutdown_rx) = watch::channel(());
178 let metrics_service = Arc::new(PseudoMetricsService::new());
179 let metrics_server = MetricsServer::build(
180 "127.0.0.1",
181 0, metrics_service.clone(),
183 logger,
184 )
185 .bind()
186 .await
187 .unwrap();
188 let metrics_server_address = metrics_server.address();
189
190 let exported_metrics_test = tokio::spawn(async move {
191 yield_now().await;
193
194 let response = reqwest::get(format!("http://{metrics_server_address}/metrics"))
195 .await
196 .unwrap();
197
198 assert_eq!(StatusCode::OK, response.status());
199 assert_eq!("pseudo metrics", response.text().await.unwrap());
200 });
201
202 let res = tokio::select!(
203 res = metrics_server.serve(shutdown_rx) => Err(anyhow!("Metrics server exited with value '{res:?}'")),
204 _res = sleep(Duration::from_secs(1)) => Err(anyhow!("Timeout: The test should have already completed.")),
205 res = exported_metrics_test => res.map_err(|e| e.into()),
206 );
207
208 shutdown_tx.send(()).unwrap();
209 res.unwrap();
210 }
211}