mithril_aggregator_discovery/
rand_discoverer.rs

1use std::sync::Arc;
2
3use rand::{Rng, seq::SliceRandom};
4use tokio::sync::Mutex;
5
6use mithril_common::{StdResult, entities::MithrilNetwork};
7
8use crate::{AggregatorDiscoverer, AggregatorEndpoint};
9
10/// A discoverer that returns a random set of aggregators
11pub struct ShuffleAggregatorDiscoverer<R: Rng + Send + Sized> {
12    random_generator: Arc<Mutex<Box<R>>>,
13    inner_discoverer: Arc<dyn AggregatorDiscoverer>,
14}
15
16impl<R: Rng + Send + Sized> ShuffleAggregatorDiscoverer<R> {
17    /// Creates a new `ShuffleAggregatorDiscoverer` instance with the provided inner discoverer.
18    pub fn new(inner_discoverer: Arc<dyn AggregatorDiscoverer>, random_generator: R) -> Self {
19        Self {
20            inner_discoverer,
21            random_generator: Arc::new(Mutex::new(Box::new(random_generator))),
22        }
23    }
24}
25
26#[async_trait::async_trait]
27impl<R: Rng + Send + Sized> AggregatorDiscoverer for ShuffleAggregatorDiscoverer<R> {
28    async fn get_available_aggregators(
29        &self,
30        network: MithrilNetwork,
31    ) -> StdResult<Box<dyn Iterator<Item = AggregatorEndpoint>>> {
32        let mut aggregators: Vec<AggregatorEndpoint> = self
33            .inner_discoverer
34            .get_available_aggregators(network)
35            .await?
36            .collect();
37        let mut rng = self.random_generator.lock().await;
38        aggregators.shuffle(&mut *rng);
39
40        Ok(Box::new(aggregators.into_iter()))
41    }
42}
43
#[cfg(test)]
mod tests {
    use rand::{SeedableRng, rngs::StdRng};

    use crate::test::double::AggregatorDiscovererFake;

    use super::*;

    /// Builds the release-devnet endpoint whose URL ends with the given aggregator number.
    fn devnet_endpoint(number: u8) -> AggregatorEndpoint {
        AggregatorEndpoint::new(format!("https://release-devnet-aggregator{number}"))
    }

    #[tokio::test]
    async fn shuffle_aggregator_discoverer() {
        let endpoints: Vec<AggregatorEndpoint> = (1..=3).map(devnet_endpoint).collect();
        let fake_discoverer = AggregatorDiscovererFake::new(vec![Ok(endpoints)]);
        // An all-zero seed makes the shuffle deterministic, so the exact order can be pinned.
        let seeded_rng = StdRng::from_seed([0u8; 32]);
        let discoverer = ShuffleAggregatorDiscoverer::new(Arc::new(fake_discoverer), seeded_rng);

        let shuffled: Vec<AggregatorEndpoint> = discoverer
            .get_available_aggregators(MithrilNetwork::new("release-devnet".into()))
            .await
            .unwrap()
            .collect();

        let expected: Vec<AggregatorEndpoint> =
            [3, 2, 1].into_iter().map(devnet_endpoint).collect();
        assert_eq!(expected, shuffled);
    }
}