use anyhow::{anyhow, Context};
use async_recursion::async_recursion;
use async_trait::async_trait;
use reqwest::{Response, StatusCode, Url};
use semver::Version;
use slog::{debug, Logger};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::RwLock;
#[cfg(test)]
use mockall::automock;
use mithril_common::MITHRIL_API_VERSION_HEADER;
use crate::{MithrilError, MithrilResult};
/// Error returned by an [`AggregatorClient`].
///
/// Every variant wraps the underlying cause as its `#[source]`.
#[derive(Error, Debug)]
pub enum AggregatorClientError {
    /// The remote server failed for a technical reason
    /// (e.g. it answered with a status code the client does not handle).
    #[error("remote server technical error")]
    RemoteServerTechnical(#[source] MithrilError),

    /// The remote server rejected the request for a logical reason
    /// (e.g. the requested resource was not found).
    #[error("remote server logical error")]
    RemoteServerLogical(#[source] MithrilError),

    /// The client and the server could not agree on an API version.
    #[error("API version mismatch")]
    ApiVersionMismatch(#[source] MithrilError),

    /// The local HTTP stack failed (e.g. invalid URL, request that could not be sent,
    /// unreadable response body).
    #[error("HTTP subsystem error")]
    SubsystemError(#[source] MithrilError),
}
/// A request that can be sent to an Aggregator.
///
/// Each variant maps to a route relative to the aggregator endpoint (see
/// `AggregatorRequest::route`), and possibly to a body (see `AggregatorRequest::get_body`).
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum AggregatorRequest {
    /// Get the certificate identified by `hash`.
    GetCertificate { hash: String },

    /// List the certificates.
    ListCertificates,

    /// Get the Mithril stake distribution identified by `hash`.
    GetMithrilStakeDistribution { hash: String },

    /// List the Mithril stake distributions.
    ListMithrilStakeDistributions,

    /// Get the snapshot identified by `digest`.
    GetSnapshot { digest: String },

    /// List the snapshots.
    ListSnapshots,

    /// Increment the statistics of a snapshot; `snapshot` is sent as the request body.
    IncrementSnapshotStatistic { snapshot: String },

    /// Get the proofs of the given transactions.
    #[cfg(feature = "unstable")]
    GetTransactionsProofs { transactions_hashes: Vec<String> },

    /// Get the Cardano transaction snapshot identified by `hash`.
    #[cfg(feature = "unstable")]
    GetCardanoTransactionSnapshot { hash: String },

    /// List the Cardano transaction snapshots.
    #[cfg(feature = "unstable")]
    ListCardanoTransactionSnapshots,
}
impl AggregatorRequest {
pub fn route(&self) -> String {
match self {
AggregatorRequest::GetCertificate { hash } => {
format!("certificate/{hash}")
}
AggregatorRequest::ListCertificates => "certificates".to_string(),
AggregatorRequest::GetMithrilStakeDistribution { hash } => {
format!("artifact/mithril-stake-distribution/{hash}")
}
AggregatorRequest::ListMithrilStakeDistributions => {
"artifact/mithril-stake-distributions".to_string()
}
AggregatorRequest::GetSnapshot { digest } => {
format!("artifact/snapshot/{}", digest)
}
AggregatorRequest::ListSnapshots => "artifact/snapshots".to_string(),
AggregatorRequest::IncrementSnapshotStatistic { snapshot: _ } => {
"statistics/snapshot".to_string()
}
#[cfg(feature = "unstable")]
AggregatorRequest::GetTransactionsProofs {
transactions_hashes,
} => format!(
"proof/cardano-transaction?transaction_hashes={}",
transactions_hashes.join(",")
),
#[cfg(feature = "unstable")]
AggregatorRequest::GetCardanoTransactionSnapshot { hash } => {
format!("artifact/cardano-transaction/{hash}")
}
#[cfg(feature = "unstable")]
AggregatorRequest::ListCardanoTransactionSnapshots => {
"artifact/cardano-transactions".to_string()
}
}
}
pub fn get_body(&self) -> Option<String> {
match self {
AggregatorRequest::IncrementSnapshotStatistic { snapshot } => {
Some(snapshot.to_string())
}
_ => None,
}
}
}
/// Client able to perform [`AggregatorRequest`]s against a Mithril Aggregator.
///
/// On wasm targets the futures are not required to be `Send`.
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
pub trait AggregatorClient: Sync + Send {
    /// Performs `request` as a read and returns the raw response content.
    async fn get_content(&self, request: AggregatorRequest)
        -> Result<String, AggregatorClientError>;

    /// Performs `request` as a write (sending its body, if any) and returns the raw response
    /// content.
    async fn post_content(&self, request: AggregatorRequest)
        -> Result<String, AggregatorClientError>;
}
/// HTTP implementation of [`AggregatorClient`], built on a `reqwest` client.
pub struct AggregatorHTTPClient {
    /// Underlying HTTP client used for every request.
    http_client: reqwest::Client,
    /// Base URL of the aggregator; request routes are joined to it.
    aggregator_endpoint: Url,
    /// Ordered list of API versions; the first one is sent with each request.
    /// Shared behind a lock because it shrinks when the aggregator rejects a version.
    api_versions: Arc<RwLock<Vec<Version>>>,
    /// Logger used to trace outgoing requests.
    logger: Logger,
}
impl AggregatorHTTPClient {
    /// Constructs a new `AggregatorHTTPClient`.
    ///
    /// The path of `aggregator_endpoint` is normalized to always end with a `/`, so that routes
    /// joined with [`Url::join`] are appended to it instead of replacing its last segment.
    ///
    /// `api_versions` is the ordered list of API versions the client may use. The first one is
    /// sent with every request. It is discarded in favor of the next one when the aggregator
    /// answers `412 Precondition Failed`.
    ///
    /// # Errors
    /// Fails if the underlying HTTP client cannot be built.
    pub fn new(
        mut aggregator_endpoint: Url,
        api_versions: Vec<Version>,
        logger: Logger,
    ) -> MithrilResult<Self> {
        let http_client = reqwest::ClientBuilder::new()
            .build()
            .with_context(|| "Building http client for Aggregator client failed")?;

        // Inspect the path rather than the full URL string. With a query string or a fragment,
        // a check on `as_str()` is wrong both ways: a trailing `/` in the query hides a missing
        // one in the path, and a query after `.../path/` leads to a doubled `//`.
        if !aggregator_endpoint.path().ends_with('/') {
            let path_with_trailing_slash = format!("{}/", aggregator_endpoint.path());
            aggregator_endpoint.set_path(&path_with_trailing_slash);
        }

        Ok(Self {
            http_client,
            aggregator_endpoint,
            api_versions: Arc::new(RwLock::new(api_versions)),
            logger,
        })
    }

    /// Returns the API version currently in use (the first of the list), if any.
    async fn compute_current_api_version(&self) -> Option<Version> {
        self.api_versions.read().await.first().cloned()
    }

    /// Returns the current API version, formatted for the `MITHRIL_API_VERSION_HEADER` header.
    ///
    /// # Errors
    /// Returns a `SubsystemError` when no API version is known. The constructor accepts an
    /// empty list, and this avoids a panic in that case.
    async fn current_api_version_header(&self) -> Result<String, AggregatorClientError> {
        self.compute_current_api_version()
            .await
            .map(|version| version.to_string())
            .ok_or_else(|| {
                AggregatorClientError::SubsystemError(anyhow!(
                    "No API version available to perform a request against the Aggregator"
                ))
            })
    }

    /// Discards the current API version so that the next request uses the following one.
    ///
    /// The last remaining version is never discarded. Returns the discarded version, or `None`
    /// when fewer than two versions are known.
    async fn discard_current_api_version(&self) -> Option<Version> {
        // Check and remove under a single write lock. Checking the length under a read lock and
        // removing later under another lock lets two concurrent `412` handlers both pass the
        // check and empty the list.
        let mut api_versions = self.api_versions.write().await;
        if api_versions.len() < 2 {
            return None;
        }
        Some(api_versions.remove(0))
    }

    /// Performs a GET request on `url`, sending the current API version in the
    /// `MITHRIL_API_VERSION_HEADER` header.
    ///
    /// On `412 Precondition Failed` the current API version is discarded and the request is
    /// retried with the next one, as long as at least one version would remain.
    ///
    /// # Errors
    /// - `SubsystemError` if no API version is known or the request cannot be sent,
    /// - `ApiVersionMismatch` if the aggregator rejects the last known API version,
    /// - `RemoteServerLogical` on `404 Not Found`,
    /// - `RemoteServerTechnical` on any other status than `200 OK`.
    #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_recursion)]
    async fn get(&self, url: Url) -> Result<Response, AggregatorClientError> {
        debug!(self.logger, "GET url='{url}'.");
        let current_api_version = self.current_api_version_header().await?;
        debug!(
            self.logger,
            "Prepare request with version: {current_api_version}"
        );

        let response = self
            .http_client
            .get(url.clone())
            .header(MITHRIL_API_VERSION_HEADER, current_api_version.as_str())
            .send()
            .await
            .map_err(|e| {
                AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
                    "Cannot perform a GET against the Aggregator HTTP server (url='{url}')"
                )))
            })?;

        match response.status() {
            StatusCode::OK => Ok(response),
            StatusCode::PRECONDITION_FAILED => {
                if self.discard_current_api_version().await.is_some() {
                    return self.get(url).await;
                }
                Err(Self::handle_api_error(&response, &current_api_version))
            }
            StatusCode::NOT_FOUND => Err(AggregatorClientError::RemoteServerLogical(anyhow!(
                "Url='{url}' not found"
            ))),
            status_code => Err(AggregatorClientError::RemoteServerTechnical(anyhow!(
                "Unhandled error {status_code}"
            ))),
        }
    }

    /// Performs a POST request on `url` with `json` as body, sending the current API version in
    /// the `MITHRIL_API_VERSION_HEADER` header.
    ///
    /// On `412 Precondition Failed` the current API version is discarded and the request is
    /// retried with the next one, as long as at least one version would remain.
    ///
    /// # Errors
    /// - `SubsystemError` if no API version is known or the request cannot be sent,
    /// - `ApiVersionMismatch` if the aggregator rejects the last known API version,
    /// - `RemoteServerLogical` on `404 Not Found`,
    /// - `RemoteServerTechnical` on any other status than `200 OK` / `201 Created`.
    #[cfg_attr(target_family = "wasm", async_recursion(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_recursion)]
    async fn post(&self, url: Url, json: &str) -> Result<Response, AggregatorClientError> {
        debug!(self.logger, "POST url='{url}' json='{json}'.");
        let current_api_version = self.current_api_version_header().await?;
        debug!(
            self.logger,
            "Prepare request with version: {current_api_version}"
        );

        let response = self
            .http_client
            .post(url.clone())
            .body(json.to_owned())
            .header(MITHRIL_API_VERSION_HEADER, current_api_version.as_str())
            .send()
            .await
            .map_err(|e| {
                // `format!` is required: a plain literal would not interpolate `json`/`url`.
                AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
                    "Error while POSTing data '{json}' to URL='{url}'."
                )))
            })?;

        match response.status() {
            StatusCode::OK | StatusCode::CREATED => Ok(response),
            StatusCode::PRECONDITION_FAILED => {
                if self.discard_current_api_version().await.is_some() {
                    return self.post(url, json).await;
                }
                Err(Self::handle_api_error(&response, &current_api_version))
            }
            StatusCode::NOT_FOUND => Err(AggregatorClientError::RemoteServerLogical(anyhow!(
                "Url='{url}' not found"
            ))),
            status_code => Err(AggregatorClientError::RemoteServerTechnical(anyhow!(
                "Unhandled error {status_code}"
            ))),
        }
    }

    /// Builds the `ApiVersionMismatch` error returned when the aggregator rejected
    /// `sent_api_version`.
    ///
    /// The server's version is included when the response carries the
    /// `MITHRIL_API_VERSION_HEADER` header.
    fn handle_api_error(response: &Response, sent_api_version: &str) -> AggregatorClientError {
        match response.headers().get(MITHRIL_API_VERSION_HEADER) {
            Some(server_version) => AggregatorClientError::ApiVersionMismatch(anyhow!(
                "server version: '{}', signer version: '{}'",
                // The header comes from the remote server, so it may contain bytes that are not
                // visible ASCII. Decode lossily instead of panicking.
                String::from_utf8_lossy(server_version.as_bytes()),
                sent_api_version
            )),
            None => AggregatorClientError::ApiVersionMismatch(anyhow!(
                "version precondition failed, sent version '{}'.",
                sent_api_version
            )),
        }
    }

    /// Joins `endpoint` to the aggregator base URL.
    ///
    /// # Errors
    /// Returns a `SubsystemError` if the resulting URL is invalid.
    fn get_url_for_route(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
        self.aggregator_endpoint
            .join(endpoint)
            .with_context(|| {
                format!(
                    "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
                    self.aggregator_endpoint
                )
            })
            .map_err(AggregatorClientError::SubsystemError)
    }
}
#[cfg_attr(test, automock)]
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl AggregatorClient for AggregatorHTTPClient {
async fn get_content(
&self,
request: AggregatorRequest,
) -> Result<String, AggregatorClientError> {
let response = self.get(self.get_url_for_route(&request.route())?).await?;
let content = format!("{response:?}");
response.text().await.map_err(|e| {
AggregatorClientError::SubsystemError(anyhow!(e).context(format!(
"Could not find a JSON body in the response '{content}'."
)))
})
}
async fn post_content(
&self,
request: AggregatorRequest,
) -> Result<String, AggregatorClientError> {
let response = self
.post(
self.get_url_for_route(&request.route())?,
&request.get_body().unwrap_or_default(),
)
.await?;
response.text().await.map_err(|e| {
AggregatorClientError::SubsystemError(
anyhow!(e).context("Could not find a text body in the response."),
)
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The endpoint must always end with `/` so that routes are appended to it by `Url::join`.
    #[test]
    fn always_append_trailing_slash_at_build() {
        for (expected, url) in [
            ("http://www.test.net/", "http://www.test.net/"),
            ("http://www.test.net/", "http://www.test.net"),
            (
                "http://www.test.net/aggregator/",
                "http://www.test.net/aggregator/",
            ),
            (
                "http://www.test.net/aggregator/",
                "http://www.test.net/aggregator",
            ),
        ] {
            let url = Url::parse(url).unwrap();
            let client = AggregatorHTTPClient::new(url, vec![], crate::test_utils::test_logger())
                .expect("building aggregator http client should not fail");
            assert_eq!(expected, client.aggregator_endpoint.as_str());
        }
    }

    #[test]
    fn deduce_routes_from_request() {
        assert_eq!(
            "certificate/abc".to_string(),
            AggregatorRequest::GetCertificate {
                hash: "abc".to_string()
            }
            .route()
        );
        // Previously a duplicate of the `GetMithrilStakeDistribution` assertion below, which
        // left `ListCertificates` untested.
        assert_eq!(
            "certificates".to_string(),
            AggregatorRequest::ListCertificates.route()
        );
        assert_eq!(
            "artifact/mithril-stake-distribution/abc".to_string(),
            AggregatorRequest::GetMithrilStakeDistribution {
                hash: "abc".to_string()
            }
            .route()
        );
        assert_eq!(
            "artifact/mithril-stake-distributions".to_string(),
            AggregatorRequest::ListMithrilStakeDistributions.route()
        );
        assert_eq!(
            "artifact/snapshot/abc".to_string(),
            AggregatorRequest::GetSnapshot {
                digest: "abc".to_string()
            }
            .route()
        );
        assert_eq!(
            "artifact/snapshots".to_string(),
            AggregatorRequest::ListSnapshots.route()
        );
        assert_eq!(
            "statistics/snapshot".to_string(),
            AggregatorRequest::IncrementSnapshotStatistic {
                snapshot: "abc".to_string()
            }
            .route()
        );
        #[cfg(feature = "unstable")]
        {
            assert_eq!(
                "proof/cardano-transaction?transaction_hashes=abc,def,ghi,jkl".to_string(),
                AggregatorRequest::GetTransactionsProofs {
                    transactions_hashes: vec![
                        "abc".to_string(),
                        "def".to_string(),
                        "ghi".to_string(),
                        "jkl".to_string()
                    ]
                }
                .route()
            );
            assert_eq!(
                "artifact/cardano-transaction/abc".to_string(),
                AggregatorRequest::GetCardanoTransactionSnapshot {
                    hash: "abc".to_string()
                }
                .route()
            );
            assert_eq!(
                "artifact/cardano-transactions".to_string(),
                AggregatorRequest::ListCardanoTransactionSnapshots.route()
            );
        }
    }

    /// Only `IncrementSnapshotStatistic` carries a body: its snapshot payload.
    #[test]
    fn deduce_body_from_request() {
        assert_eq!(
            Some("abc".to_string()),
            AggregatorRequest::IncrementSnapshotStatistic {
                snapshot: "abc".to_string()
            }
            .get_body()
        );
        assert_eq!(None, AggregatorRequest::ListSnapshots.get_body());
        assert_eq!(
            None,
            AggregatorRequest::GetCertificate {
                hash: "abc".to_string()
            }
            .get_body()
        );
    }
}