mithril_aggregator_discovery/
rand_discoverer.rs1use std::sync::Arc;
2
3use rand::{Rng, seq::SliceRandom};
4use tokio::sync::Mutex;
5
6use mithril_common::{StdResult, entities::MithrilNetwork};
7
8use crate::{AggregatorDiscoverer, AggregatorEndpoint};
9
10pub struct ShuffleAggregatorDiscoverer<R: Rng + Send + Sized> {
12 random_generator: Arc<Mutex<Box<R>>>,
13 inner_discoverer: Arc<dyn AggregatorDiscoverer>,
14}
15
16impl<R: Rng + Send + Sized> ShuffleAggregatorDiscoverer<R> {
17 pub fn new(inner_discoverer: Arc<dyn AggregatorDiscoverer>, random_generator: R) -> Self {
19 Self {
20 inner_discoverer,
21 random_generator: Arc::new(Mutex::new(Box::new(random_generator))),
22 }
23 }
24}
25
26#[async_trait::async_trait]
27impl<R: Rng + Send + Sized> AggregatorDiscoverer for ShuffleAggregatorDiscoverer<R> {
28 async fn get_available_aggregators(
29 &self,
30 network: MithrilNetwork,
31 ) -> StdResult<Box<dyn Iterator<Item = AggregatorEndpoint>>> {
32 let mut aggregators: Vec<AggregatorEndpoint> = self
33 .inner_discoverer
34 .get_available_aggregators(network)
35 .await?
36 .collect();
37 let mut rng = self.random_generator.lock().await;
38 aggregators.shuffle(&mut *rng);
39
40 Ok(Box::new(aggregators.into_iter()))
41 }
42}
43
#[cfg(test)]
mod tests {
    use rand::{SeedableRng, rngs::StdRng};

    use crate::test::double::AggregatorDiscovererFake;

    use super::*;

    /// Builds the endpoint of the `release-devnet` aggregator with the given index.
    fn devnet_aggregator(index: u8) -> AggregatorEndpoint {
        AggregatorEndpoint::new(format!("https://release-devnet-aggregator{index}"))
    }

    /// With a fixed seed, the shuffle must always produce the same known permutation.
    #[tokio::test]
    async fn shuffle_aggregator_discoverer() {
        let fake_discoverer = AggregatorDiscovererFake::new(vec![Ok((1..=3)
            .map(devnet_aggregator)
            .collect::<Vec<_>>())]);
        let seeded_rng = StdRng::from_seed([0u8; 32]);
        let discoverer = ShuffleAggregatorDiscoverer::new(Arc::new(fake_discoverer), seeded_rng);

        let shuffled = discoverer
            .get_available_aggregators(MithrilNetwork::new("release-devnet".into()))
            .await
            .unwrap()
            .collect::<Vec<_>>();

        let expected = (1..=3).rev().map(devnet_aggregator).collect::<Vec<_>>();
        assert_eq!(expected, shuffled);
    }
}