mithril_aggregator/tools/url_sanitizer.rs

1use anyhow::{anyhow, Context};
2use reqwest::Url;
3use std::fmt::{Display, Formatter};
4use std::ops::Deref;
5
6use mithril_common::StdResult;
7
8/// A sanitized URL, guaranteed to have a trailing slash and no empty segments
9///
10/// This type is meant to be used as a base path to produce resources path, for example:
11/// `https://example.xy/download/` can be joined with `artifact/file.zip` to produce a download link
12#[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Ord, Hash)]
13pub struct SanitizedUrlWithTrailingSlash {
14    internal_url: Url,
15}
16
17impl SanitizedUrlWithTrailingSlash {
18    /// Join this URL with the given path, the resulting URL is guaranteed to have a trailing slash
19    ///
20    /// See [Url::join] for more details
21    pub fn sanitize_join(&self, input: &str) -> StdResult<SanitizedUrlWithTrailingSlash> {
22        let url = self.internal_url.join(input).with_context(|| {
23            format!(
24                "Could not join `{}` to URL `{input}`",
25                self.internal_url.as_str()
26            )
27        })?;
28        sanitize_url_path(&url)
29    }
30
31    /// Parse an absolute URL from a string.
32    ///
33    /// See [Url::parse] for more details
34    pub fn parse(input: &str) -> StdResult<SanitizedUrlWithTrailingSlash> {
35        let url = Url::parse(input).with_context(|| format!("Could not parse URL `{input}`"))?;
36        sanitize_url_path(&url)
37    }
38}
39
40impl PartialEq<Url> for SanitizedUrlWithTrailingSlash {
41    fn eq(&self, other: &Url) -> bool {
42        self.internal_url.eq(other)
43    }
44}
45
46impl Deref for SanitizedUrlWithTrailingSlash {
47    type Target = Url;
48
49    fn deref(&self) -> &Self::Target {
50        &self.internal_url
51    }
52}
53
54impl From<SanitizedUrlWithTrailingSlash> for Url {
55    fn from(value: SanitizedUrlWithTrailingSlash) -> Self {
56        value.internal_url
57    }
58}
59
60impl Display for SanitizedUrlWithTrailingSlash {
61    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
62        self.internal_url.fmt(f)
63    }
64}
65
66/// Sanitize URL path by removing empty segments and adding trailing slash
67pub fn sanitize_url_path(url: &Url) -> StdResult<SanitizedUrlWithTrailingSlash> {
68    let segments_non_empty = url
69        .path_segments()
70        .map(|s| s.into_iter().filter(|s| !s.is_empty()).collect::<Vec<_>>())
71        .unwrap_or_default();
72    let mut url = url.clone();
73    {
74        let mut url_path_segments = url
75            .path_segments_mut()
76            .map_err(|e| anyhow!("error parsing URL: {e:?}"))
77            .with_context(|| "while sanitizing URL path: {url}")?;
78        let url_path_segments_cleared = url_path_segments.clear();
79        for segment in segments_non_empty {
80            url_path_segments_cleared.push(segment);
81        }
82        url_path_segments_cleared.push("");
83    }
84
85    Ok(SanitizedUrlWithTrailingSlash { internal_url: url })
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `input`, sanitize its path and return the resulting serialization.
    fn sanitized(input: &str) -> String {
        let url = Url::parse(input).unwrap();
        sanitize_url_path(&url).unwrap().as_str().to_string()
    }

    #[test]
    fn test_sanitize_url_path() {
        let cases = [
            (
                "http://example.com/a//b/c.ext?test=123",
                "http://example.com/a/b/c.ext/?test=123",
            ),
            (
                "http://example.com/a//b/c.ext",
                "http://example.com/a/b/c.ext/",
            ),
            ("http://example.com/a//b/c", "http://example.com/a/b/c/"),
            ("http://example.com/", "http://example.com/"),
            ("http://example.com", "http://example.com/"),
        ];

        for (input, expected) in cases {
            assert_eq!(expected, sanitized(input), "sanitizing `{input}`");
        }
    }

    #[test]
    fn test_sanitize_url_join() {
        let base = Url::parse("http://example.com/").unwrap();
        let joined = sanitize_url_path(&base)
            .unwrap()
            .sanitize_join("a//b/c_ext")
            .unwrap();

        assert_eq!("http://example.com/a/b/c_ext/", joined.as_str());
    }

    #[test]
    fn test_sanitize_url_parse() {
        let parsed =
            SanitizedUrlWithTrailingSlash::parse("http://example.com/a//b/c.ext?test=123").unwrap();

        assert_eq!("http://example.com/a/b/c.ext/?test=123", parsed.as_str());
    }
}