mithril_aggregator_client/
builder.rs1use anyhow::Context;
2use reqwest::{Client, IntoUrl, Url};
3use slog::{Logger, o};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use mithril_common::StdResult;
9use mithril_common::api_version::APIVersionProvider;
10
11use crate::client::AggregatorHttpClient;
12
13pub struct AggregatorClientBuilder {
15 aggregator_url_result: reqwest::Result<Url>,
16 api_version_provider: Option<Arc<APIVersionProvider>>,
17 additional_headers: Option<HashMap<String, String>>,
18 timeout_duration: Option<Duration>,
19 #[cfg(not(target_family = "wasm"))]
20 relay_endpoint: Option<String>,
21 logger: Option<Logger>,
22}
23
24impl AggregatorClientBuilder {
25 pub fn new<U: IntoUrl>(aggregator_url: U) -> Self {
29 Self {
30 aggregator_url_result: aggregator_url.into_url(),
31 api_version_provider: None,
32 additional_headers: None,
33 timeout_duration: None,
34 #[cfg(not(target_family = "wasm"))]
35 relay_endpoint: None,
36 logger: None,
37 }
38 }
39
40 pub fn with_logger(mut self, logger: Logger) -> Self {
42 self.logger = Some(logger);
43 self
44 }
45
46 pub fn with_api_version_provider(
48 mut self,
49 api_version_provider: Arc<APIVersionProvider>,
50 ) -> Self {
51 self.api_version_provider = Some(api_version_provider);
52 self
53 }
54
55 pub fn with_timeout(mut self, timeout: Duration) -> Self {
57 self.timeout_duration = Some(timeout);
58 self
59 }
60
61 pub fn with_headers(mut self, custom_headers: HashMap<String, String>) -> Self {
63 self.additional_headers = Some(custom_headers);
64 self
65 }
66
67 #[cfg(not(target_family = "wasm"))]
71 pub fn with_relay_endpoint(mut self, relay_endpoint: Option<String>) -> Self {
72 self.relay_endpoint = relay_endpoint;
73 self
74 }
75
76 pub fn build(self) -> StdResult<AggregatorHttpClient> {
78 let aggregator_endpoint =
79 enforce_trailing_slash(self.aggregator_url_result.with_context(
80 || "Invalid aggregator endpoint, it must be a correctly formed url",
81 )?);
82 let logger = self.logger.unwrap_or_else(|| Logger::root(slog::Discard, o!()));
83 let api_version_provider = self.api_version_provider.unwrap_or_default();
84 let additional_headers = self.additional_headers.unwrap_or_default();
85 #[cfg(not(target_family = "wasm"))]
86 let mut client_builder = Client::builder();
87 #[cfg(target_family = "wasm")]
88 let client_builder = Client::builder();
89
90 #[cfg(not(target_family = "wasm"))]
91 if let Some(relay_endpoint) = self.relay_endpoint {
92 use reqwest::Proxy;
93
94 client_builder = client_builder
95 .proxy(Proxy::all(relay_endpoint).with_context(|| "Relay proxy creation failed")?)
96 }
97
98 Ok(AggregatorHttpClient {
99 aggregator_endpoint,
100 api_version_provider,
101 additional_headers: (&additional_headers)
102 .try_into()
103 .with_context(|| format!("Invalid headers: '{additional_headers:?}'"))?,
104 timeout_duration: self.timeout_duration,
105 client: client_builder
106 .build()
107 .with_context(|| "HTTP client creation failed")?,
108 logger,
109 })
110 }
111}
112
113fn enforce_trailing_slash(url: Url) -> Url {
114 if url.as_str().ends_with('/') {
118 url
119 } else {
120 let mut url = url.clone();
121 url.set_path(&format!("{}/", url.path()));
122 url
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Both a url lacking the trailing slash and one already having it must
    /// come out as the same slash-terminated url.
    #[test]
    fn enforce_trailing_slash_for_aggregator_url() {
        let expected = Url::parse("http://localhost:8080/").unwrap();
        let raw_inputs = ["http://localhost:8080", "http://localhost:8080/"];

        for raw_input in raw_inputs {
            let input = Url::parse(raw_input).unwrap();

            assert_eq!(expected, enforce_trailing_slash(input));
        }
    }
}