mithril_aggregator_client/
builder.rs

1use anyhow::Context;
2use reqwest::{Client, IntoUrl, Url};
3use slog::{Logger, o};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use mithril_common::StdResult;
9use mithril_common::api_version::APIVersionProvider;
10
11use crate::client::AggregatorHttpClient;
12
13/// A builder of [AggregatorHttpClient]
14pub struct AggregatorClientBuilder {
15    aggregator_url_result: reqwest::Result<Url>,
16    api_version_provider: Option<Arc<APIVersionProvider>>,
17    additional_headers: Option<HashMap<String, String>>,
18    timeout_duration: Option<Duration>,
19    #[cfg(not(target_family = "wasm"))]
20    relay_endpoint: Option<String>,
21    logger: Option<Logger>,
22}
23
24impl AggregatorClientBuilder {
25    /// Constructs a new `AggregatorClientBuilder`.
26    //
27    // This is the same as `AggregatorClient::builder()`.
28    pub fn new<U: IntoUrl>(aggregator_url: U) -> Self {
29        Self {
30            aggregator_url_result: aggregator_url.into_url(),
31            api_version_provider: None,
32            additional_headers: None,
33            timeout_duration: None,
34            #[cfg(not(target_family = "wasm"))]
35            relay_endpoint: None,
36            logger: None,
37        }
38    }
39
40    /// Set the [Logger] to use.
41    pub fn with_logger(mut self, logger: Logger) -> Self {
42        self.logger = Some(logger);
43        self
44    }
45
46    /// Set the [APIVersionProvider] to use.
47    pub fn with_api_version_provider(
48        mut self,
49        api_version_provider: Arc<APIVersionProvider>,
50    ) -> Self {
51        self.api_version_provider = Some(api_version_provider);
52        self
53    }
54
55    /// Set a timeout to enforce on each request
56    pub fn with_timeout(mut self, timeout: Duration) -> Self {
57        self.timeout_duration = Some(timeout);
58        self
59    }
60
61    /// Add a set of http headers that will be sent on client requests
62    pub fn with_headers(mut self, custom_headers: HashMap<String, String>) -> Self {
63        self.additional_headers = Some(custom_headers);
64        self
65    }
66
67    /// Set the address of the relay
68    ///
69    /// _Not available on wasm platforms_
70    #[cfg(not(target_family = "wasm"))]
71    pub fn with_relay_endpoint(mut self, relay_endpoint: Option<String>) -> Self {
72        self.relay_endpoint = relay_endpoint;
73        self
74    }
75
76    /// Returns an [AggregatorHttpClient] based on the builder configuration
77    pub fn build(self) -> StdResult<AggregatorHttpClient> {
78        let aggregator_endpoint =
79            enforce_trailing_slash(self.aggregator_url_result.with_context(
80                || "Invalid aggregator endpoint, it must be a correctly formed url",
81            )?);
82        let logger = self.logger.unwrap_or_else(|| Logger::root(slog::Discard, o!()));
83        let api_version_provider = self.api_version_provider.unwrap_or_default();
84        let additional_headers = self.additional_headers.unwrap_or_default();
85        #[cfg(not(target_family = "wasm"))]
86        let mut client_builder = Client::builder();
87        #[cfg(target_family = "wasm")]
88        let client_builder = Client::builder();
89
90        #[cfg(not(target_family = "wasm"))]
91        if let Some(relay_endpoint) = self.relay_endpoint {
92            use reqwest::Proxy;
93
94            client_builder = client_builder
95                .proxy(Proxy::all(relay_endpoint).with_context(|| "Relay proxy creation failed")?)
96        }
97
98        Ok(AggregatorHttpClient {
99            aggregator_endpoint,
100            api_version_provider,
101            additional_headers: (&additional_headers)
102                .try_into()
103                .with_context(|| format!("Invalid headers: '{additional_headers:?}'"))?,
104            timeout_duration: self.timeout_duration,
105            client: client_builder
106                .build()
107                .with_context(|| "HTTP client creation failed")?,
108            logger,
109        })
110    }
111}
112
113fn enforce_trailing_slash(url: Url) -> Url {
114    // Trailing slash is significant because url::join
115    // (https://docs.rs/url/latest/url/struct.Url.html#method.join) will remove
116    // the 'path' part of the url if it doesn't end with a trailing slash.
117    if url.as_str().ends_with('/') {
118        url
119    } else {
120        let mut url = url.clone();
121        url.set_path(&format!("{}/", url.path()));
122        url
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn enforce_trailing_slash_for_aggregator_url() {
        let url_without_trailing_slash = Url::parse("http://localhost:8080").unwrap();
        let url_with_trailing_slash = Url::parse("http://localhost:8080/").unwrap();

        assert_eq!(
            url_with_trailing_slash,
            enforce_trailing_slash(url_without_trailing_slash)
        );
        assert_eq!(
            url_with_trailing_slash,
            enforce_trailing_slash(url_with_trailing_slash.clone())
        );
    }

    #[test]
    fn enforce_trailing_slash_keeps_aggregator_url_sub_path() {
        let url_without_trailing_slash = Url::parse("http://localhost:8080/aggregator").unwrap();
        let url_with_trailing_slash = Url::parse("http://localhost:8080/aggregator/").unwrap();

        assert_eq!(
            url_with_trailing_slash,
            enforce_trailing_slash(url_without_trailing_slash)
        );
        assert_eq!(
            url_with_trailing_slash,
            enforce_trailing_slash(url_with_trailing_slash.clone())
        );
        // The point of the trailing slash: joining keeps the `aggregator` path segment.
        assert_eq!(
            "http://localhost:8080/aggregator/certificates",
            url_with_trailing_slash.join("certificates").unwrap().as_str()
        );
    }
}