mithril_aggregator_discovery/
capabilities_discoverer.rs1use std::sync::Arc;
2
3use mithril_common::{
4 AggregateSignatureType, StdResult,
5 entities::{MithrilNetwork, SignedEntityTypeDiscriminants},
6 messages::AggregatorCapabilities,
7};
8
9use crate::{AggregatorDiscoverer, AggregatorEndpoint};
10
11#[derive(Clone, PartialEq, Eq, Debug)]
13pub enum RequiredAggregatorCapabilities {
14 All,
16 SignedEntityType(SignedEntityTypeDiscriminants),
18 AggregateSignatureType(AggregateSignatureType),
20 Or(Vec<RequiredAggregatorCapabilities>),
22 And(Vec<RequiredAggregatorCapabilities>),
24}
25
26impl RequiredAggregatorCapabilities {
27 fn matches(&self, available: &AggregatorCapabilities) -> bool {
29 match self {
30 RequiredAggregatorCapabilities::All => true,
31 RequiredAggregatorCapabilities::SignedEntityType(required_signed_entity_type) => {
32 available
33 .signed_entity_types
34 .iter()
35 .any(|req| req == required_signed_entity_type)
36 }
37 RequiredAggregatorCapabilities::AggregateSignatureType(
38 required_aggregate_signature_types,
39 ) => *required_aggregate_signature_types == available.aggregate_signature_type,
40 RequiredAggregatorCapabilities::Or(requirements) => {
41 requirements.iter().any(|req| req.matches(available))
42 }
43 RequiredAggregatorCapabilities::And(requirements) => {
44 requirements.iter().all(|req| req.matches(available))
45 }
46 }
47 }
48}
49
50pub struct CapableAggregatorDiscoverer {
52 required_capabilities: RequiredAggregatorCapabilities,
53 inner_discoverer: Arc<dyn AggregatorDiscoverer>,
54}
55
56impl CapableAggregatorDiscoverer {
57 pub fn new(
59 capabilities: RequiredAggregatorCapabilities,
60 inner_discoverer: Arc<dyn AggregatorDiscoverer>,
61 ) -> Self {
62 Self {
63 required_capabilities: capabilities,
64 inner_discoverer,
65 }
66 }
67}
68
69#[async_trait::async_trait]
70impl AggregatorDiscoverer for CapableAggregatorDiscoverer {
71 async fn get_available_aggregators(
72 &self,
73 network: MithrilNetwork,
74 ) -> StdResult<Box<dyn Iterator<Item = AggregatorEndpoint>>> {
75 let aggregator_endpoints = self.inner_discoverer.get_available_aggregators(network).await?;
76
77 Ok(Box::new(CapableAggregatorDiscovererIterator {
78 required_capabilities: self.required_capabilities.clone(),
79 inner_iterator: aggregator_endpoints,
80 }))
81 }
82}
83
84struct CapableAggregatorDiscovererIterator {
86 required_capabilities: RequiredAggregatorCapabilities,
87 inner_iterator: Box<dyn Iterator<Item = AggregatorEndpoint>>,
88}
89
90impl Iterator for CapableAggregatorDiscovererIterator {
91 type Item = AggregatorEndpoint;
92
93 fn next(&mut self) -> Option<Self::Item> {
94 for aggregator_endpoint in self.inner_iterator.by_ref() {
95 let aggregator_endpoint_clone = aggregator_endpoint.clone();
96 let aggregator_capabilities = tokio::task::block_in_place(move || {
97 tokio::runtime::Handle::current().block_on(async move {
98 aggregator_endpoint_clone.retrieve_capabilities().await
99 })
100 });
101 if let Ok(aggregator_capabilities) = aggregator_capabilities
102 && self.required_capabilities.matches(&aggregator_capabilities)
103 {
104 return Some(aggregator_endpoint);
105 }
106 }
107
108 None
109 }
110}
111
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use httpmock::MockServer;
    use serde_json::json;

    use mithril_common::{
        AggregateSignatureType::Concatenation,
        entities::SignedEntityTypeDiscriminants::{
            CardanoDatabase, CardanoStakeDistribution, CardanoTransactions,
            MithrilStakeDistribution,
        },
        messages::AggregatorFeaturesMessage,
    };

    use super::*;

    /// Builds `Concatenation` capabilities advertising the given signed entity types and no
    /// Cardano transactions prover.
    fn concatenation_capabilities<const N: usize>(
        signed_entity_types: [SignedEntityTypeDiscriminants; N],
    ) -> AggregatorCapabilities {
        AggregatorCapabilities {
            aggregate_signature_type: Concatenation,
            signed_entity_types: BTreeSet::from(signed_entity_types),
            cardano_transactions_prover: None,
        }
    }

    mod required_capabilities {
        use super::*;

        #[test]
        fn required_capabilities_match_all_success() {
            let required = RequiredAggregatorCapabilities::All;
            let available = concatenation_capabilities([]);

            assert!(required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_signed_entity_types_success() {
            let required =
                RequiredAggregatorCapabilities::SignedEntityType(CardanoStakeDistribution);
            let available = concatenation_capabilities([
                CardanoTransactions,
                CardanoStakeDistribution,
                CardanoDatabase,
            ]);

            assert!(required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_signed_entity_types_failure() {
            let required =
                RequiredAggregatorCapabilities::SignedEntityType(MithrilStakeDistribution);
            let available = concatenation_capabilities([
                CardanoTransactions,
                CardanoStakeDistribution,
                CardanoDatabase,
            ]);

            assert!(!required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_aggregate_signature_type_success() {
            let required = RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation);
            let available = concatenation_capabilities([
                CardanoTransactions,
                CardanoStakeDistribution,
                CardanoDatabase,
            ]);

            assert!(required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_or_success() {
            let required = RequiredAggregatorCapabilities::Or(vec![
                RequiredAggregatorCapabilities::SignedEntityType(MithrilStakeDistribution),
                RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation),
            ]);
            let available = concatenation_capabilities([
                CardanoTransactions,
                CardanoStakeDistribution,
                CardanoDatabase,
            ]);

            assert!(required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_or_failure() {
            let required = RequiredAggregatorCapabilities::Or(vec![
                RequiredAggregatorCapabilities::SignedEntityType(MithrilStakeDistribution),
                RequiredAggregatorCapabilities::SignedEntityType(CardanoDatabase),
            ]);
            let available = concatenation_capabilities([CardanoTransactions]);

            assert!(!required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_and_success() {
            let required = RequiredAggregatorCapabilities::And(vec![
                RequiredAggregatorCapabilities::SignedEntityType(CardanoTransactions),
                RequiredAggregatorCapabilities::SignedEntityType(CardanoStakeDistribution),
                RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation),
            ]);
            let available = concatenation_capabilities([
                CardanoTransactions,
                CardanoStakeDistribution,
                CardanoDatabase,
            ]);

            assert!(required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_and_failure() {
            let required = RequiredAggregatorCapabilities::And(vec![
                RequiredAggregatorCapabilities::SignedEntityType(CardanoTransactions),
                RequiredAggregatorCapabilities::SignedEntityType(CardanoStakeDistribution),
                RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation),
            ]);
            let available = concatenation_capabilities([CardanoTransactions]);

            assert!(!required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_empty_and_always_succeeds() {
            let required = RequiredAggregatorCapabilities::And(vec![]);
            let available = concatenation_capabilities([]);

            assert!(required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_empty_or_always_fails() {
            let required = RequiredAggregatorCapabilities::Or(vec![]);
            let available = concatenation_capabilities([CardanoTransactions]);

            assert!(!required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_nested_success() {
            let required = RequiredAggregatorCapabilities::And(vec![
                RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation),
                RequiredAggregatorCapabilities::Or(vec![
                    RequiredAggregatorCapabilities::SignedEntityType(MithrilStakeDistribution),
                    RequiredAggregatorCapabilities::SignedEntityType(CardanoDatabase),
                ]),
            ]);
            let available = concatenation_capabilities([CardanoTransactions, CardanoDatabase]);

            assert!(required.matches(&available));
        }
    }

    mod capable_discoverer {
        use super::*;

        fn create_aggregator_features_message(
            capabilities: AggregatorCapabilities,
        ) -> AggregatorFeaturesMessage {
            AggregatorFeaturesMessage {
                open_api_version: "1.0.0".to_string(),
                documentation_url: "https://docs".to_string(),
                capabilities,
            }
        }

        /// Registers a mock on `server` answering requests to `/` with the given capabilities.
        fn mock_capabilities_route(
            server: &MockServer,
            capabilities: AggregatorCapabilities,
        ) -> httpmock::Mock<'_> {
            server.mock(|when, then| {
                when.path("/");
                then.status(200)
                    .body(json!(create_aggregator_features_message(capabilities)).to_string());
            })
        }

        /// Registers a mock on `server` answering requests to `/` with an internal server error.
        fn mock_failing_route(server: &MockServer) -> httpmock::Mock<'_> {
            server.mock(|when, then| {
                when.path("/");
                then.status(500);
            })
        }

        /// Requirement shared by most tests: `CardanoDatabase` signed with `Concatenation`.
        fn require_cardano_database_with_concatenation() -> RequiredAggregatorCapabilities {
            RequiredAggregatorCapabilities::And(vec![
                RequiredAggregatorCapabilities::SignedEntityType(CardanoDatabase),
                RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation),
            ])
        }

        /// Builds a discoverer filtering the given endpoints with `required_capabilities`.
        fn create_discoverer(
            required_capabilities: RequiredAggregatorCapabilities,
            endpoints: Vec<AggregatorEndpoint>,
        ) -> CapableAggregatorDiscoverer {
            CapableAggregatorDiscoverer::new(
                required_capabilities,
                Arc::new(crate::test::double::AggregatorDiscovererFake::new(vec![
                    Ok(endpoints),
                ])),
            )
        }

        /// Runs discovery on a fixed test network, panicking on error.
        async fn discover_aggregators(
            discoverer: &CapableAggregatorDiscoverer,
        ) -> Box<dyn Iterator<Item = AggregatorEndpoint>> {
            discoverer
                .get_available_aggregators(MithrilNetwork::new("release-devnet".into()))
                .await
                .unwrap()
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn get_available_aggregators_success() {
            let aggregator_server = MockServer::start();
            let aggregator_server_mock = mock_capabilities_route(
                &aggregator_server,
                concatenation_capabilities([CardanoStakeDistribution, CardanoTransactions]),
            );
            let discoverer = create_discoverer(
                RequiredAggregatorCapabilities::And(vec![
                    RequiredAggregatorCapabilities::SignedEntityType(CardanoTransactions),
                    RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation),
                ]),
                vec![AggregatorEndpoint::new(aggregator_server.url("/"))],
            );

            let mut aggregators = discover_aggregators(&discoverer).await;

            let next_aggregator = aggregators.next();
            aggregator_server_mock.assert();
            assert_eq!(
                Some(AggregatorEndpoint::new(aggregator_server.url("/"))),
                next_aggregator
            );
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn get_available_aggregators_succeeds_when_aggregator_capabilities_do_not_match() {
            let aggregator_server = MockServer::start();
            let aggregator_server_mock = mock_capabilities_route(
                &aggregator_server,
                concatenation_capabilities([CardanoTransactions]),
            );
            let discoverer = create_discoverer(
                require_cardano_database_with_concatenation(),
                vec![AggregatorEndpoint::new(aggregator_server.url("/"))],
            );

            let mut aggregators = discover_aggregators(&discoverer).await;

            let next_aggregator = aggregators.next();
            aggregator_server_mock.assert();
            assert!(next_aggregator.is_none());
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn get_available_aggregators_succeeds_when_one_aggregator_returns_an_error() {
            let aggregator_server_1 = MockServer::start();
            let aggregator_server_mock_1 = mock_failing_route(&aggregator_server_1);
            let aggregator_server_2 = MockServer::start();
            let aggregator_server_mock_2 = mock_capabilities_route(
                &aggregator_server_2,
                concatenation_capabilities([CardanoStakeDistribution, CardanoDatabase]),
            );
            let discoverer = create_discoverer(
                require_cardano_database_with_concatenation(),
                vec![
                    AggregatorEndpoint::new(aggregator_server_1.url("/")),
                    AggregatorEndpoint::new(aggregator_server_2.url("/")),
                ],
            );

            let mut aggregators = discover_aggregators(&discoverer).await;

            let next_aggregator = aggregators.next();
            aggregator_server_mock_1.assert();
            aggregator_server_mock_2.assert();
            assert_eq!(
                Some(AggregatorEndpoint::new(aggregator_server_2.url("/"))),
                next_aggregator
            );
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn get_available_aggregators_succeeds_and_makes_minimum_calls_to_aggregators() {
            let aggregator_server_1 = MockServer::start();
            let aggregator_server_mock_1 = mock_failing_route(&aggregator_server_1);
            let aggregator_server_2 = MockServer::start();
            let aggregator_server_mock_2 = mock_capabilities_route(
                &aggregator_server_2,
                concatenation_capabilities([CardanoStakeDistribution]),
            );
            let aggregator_server_3 = MockServer::start();
            let aggregator_server_mock_3 = mock_capabilities_route(
                &aggregator_server_3,
                concatenation_capabilities([CardanoDatabase]),
            );
            let aggregator_server_4 = MockServer::start();
            let aggregator_server_mock_4 = mock_capabilities_route(
                &aggregator_server_4,
                concatenation_capabilities([CardanoDatabase]),
            );
            let discoverer = create_discoverer(
                require_cardano_database_with_concatenation(),
                vec![
                    AggregatorEndpoint::new(aggregator_server_1.url("/")),
                    AggregatorEndpoint::new(aggregator_server_2.url("/")),
                    AggregatorEndpoint::new(aggregator_server_3.url("/")),
                    AggregatorEndpoint::new(aggregator_server_4.url("/")),
                ],
            );

            let mut aggregators = discover_aggregators(&discoverer).await;

            let next_aggregator = aggregators.next();
            aggregator_server_mock_1.assert();
            aggregator_server_mock_2.assert();
            aggregator_server_mock_3.assert();
            assert_eq!(0, aggregator_server_mock_4.calls());
            assert_eq!(
                Some(AggregatorEndpoint::new(aggregator_server_3.url("/"))),
                next_aggregator
            );
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn get_available_aggregators_yields_every_matching_aggregator_in_order() {
            let aggregator_server_1 = MockServer::start();
            let aggregator_server_mock_1 = mock_capabilities_route(
                &aggregator_server_1,
                concatenation_capabilities([CardanoDatabase]),
            );
            let aggregator_server_2 = MockServer::start();
            let aggregator_server_mock_2 = mock_capabilities_route(
                &aggregator_server_2,
                concatenation_capabilities([CardanoTransactions]),
            );
            let aggregator_server_3 = MockServer::start();
            let aggregator_server_mock_3 = mock_capabilities_route(
                &aggregator_server_3,
                concatenation_capabilities([CardanoDatabase, CardanoTransactions]),
            );
            let discoverer = create_discoverer(
                require_cardano_database_with_concatenation(),
                vec![
                    AggregatorEndpoint::new(aggregator_server_1.url("/")),
                    AggregatorEndpoint::new(aggregator_server_2.url("/")),
                    AggregatorEndpoint::new(aggregator_server_3.url("/")),
                ],
            );

            let aggregators: Vec<AggregatorEndpoint> =
                discover_aggregators(&discoverer).await.collect();

            aggregator_server_mock_1.assert();
            aggregator_server_mock_2.assert();
            aggregator_server_mock_3.assert();
            assert_eq!(
                vec![
                    AggregatorEndpoint::new(aggregator_server_1.url("/")),
                    AggregatorEndpoint::new(aggregator_server_3.url("/")),
                ],
                aggregators
            );
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn get_available_aggregators_returns_nothing_when_no_aggregator_is_discovered() {
            let discoverer =
                create_discoverer(require_cardano_database_with_concatenation(), vec![]);

            let mut aggregators = discover_aggregators(&discoverer).await;

            assert!(aggregators.next().is_none());
        }
    }
}