mithril_aggregator_discovery/capabilities_discoverer.rs

1use std::sync::Arc;
2
3use mithril_common::{
4    AggregateSignatureType, StdResult,
5    entities::{MithrilNetwork, SignedEntityTypeDiscriminants},
6    messages::AggregatorCapabilities,
7};
8
9use crate::{AggregatorDiscoverer, AggregatorEndpoint};
10
11/// Required capabilities for an aggregator.
12#[derive(Clone, PartialEq, Eq, Debug)]
13pub enum RequiredAggregatorCapabilities {
14    /// All
15    All,
16    /// Signed entity type.
17    SignedEntityType(SignedEntityTypeDiscriminants),
18    /// Aggregate signature type.
19    AggregateSignatureType(AggregateSignatureType),
20    /// Logical OR of required capabilities.
21    Or(Vec<RequiredAggregatorCapabilities>),
22    /// Logical AND of required capabilities.
23    And(Vec<RequiredAggregatorCapabilities>),
24}
25
26impl RequiredAggregatorCapabilities {
27    /// Check if the available capabilities match the required capabilities.
28    fn matches(&self, available: &AggregatorCapabilities) -> bool {
29        match self {
30            RequiredAggregatorCapabilities::All => true,
31            RequiredAggregatorCapabilities::SignedEntityType(required_signed_entity_type) => {
32                available
33                    .signed_entity_types
34                    .iter()
35                    .any(|req| req == required_signed_entity_type)
36            }
37            RequiredAggregatorCapabilities::AggregateSignatureType(
38                required_aggregate_signature_types,
39            ) => *required_aggregate_signature_types == available.aggregate_signature_type,
40            RequiredAggregatorCapabilities::Or(requirements) => {
41                requirements.iter().any(|req| req.matches(available))
42            }
43            RequiredAggregatorCapabilities::And(requirements) => {
44                requirements.iter().all(|req| req.matches(available))
45            }
46        }
47    }
48}
49
50/// An aggregator discoverer for specific capabilities.
51pub struct CapableAggregatorDiscoverer {
52    required_capabilities: RequiredAggregatorCapabilities,
53    inner_discoverer: Arc<dyn AggregatorDiscoverer>,
54}
55
56impl CapableAggregatorDiscoverer {
57    /// Creates a new `CapableAggregatorDiscoverer` instance with the provided capabilities.
58    pub fn new(
59        capabilities: RequiredAggregatorCapabilities,
60        inner_discoverer: Arc<dyn AggregatorDiscoverer>,
61    ) -> Self {
62        Self {
63            required_capabilities: capabilities,
64            inner_discoverer,
65        }
66    }
67}
68
69#[async_trait::async_trait]
70impl AggregatorDiscoverer for CapableAggregatorDiscoverer {
71    async fn get_available_aggregators(
72        &self,
73        network: MithrilNetwork,
74    ) -> StdResult<Box<dyn Iterator<Item = AggregatorEndpoint>>> {
75        let aggregator_endpoints = self.inner_discoverer.get_available_aggregators(network).await?;
76
77        Ok(Box::new(CapableAggregatorDiscovererIterator {
78            required_capabilities: self.required_capabilities.clone(),
79            inner_iterator: aggregator_endpoints,
80        }))
81    }
82}
83
84/// An iterator over aggregator endpoints filtered by capabilities.
85struct CapableAggregatorDiscovererIterator {
86    required_capabilities: RequiredAggregatorCapabilities,
87    inner_iterator: Box<dyn Iterator<Item = AggregatorEndpoint>>,
88}
89
90impl Iterator for CapableAggregatorDiscovererIterator {
91    type Item = AggregatorEndpoint;
92
93    fn next(&mut self) -> Option<Self::Item> {
94        for aggregator_endpoint in self.inner_iterator.by_ref() {
95            let aggregator_endpoint_clone = aggregator_endpoint.clone();
96            let aggregator_capabilities = tokio::task::block_in_place(move || {
97                tokio::runtime::Handle::current().block_on(async move {
98                    aggregator_endpoint_clone.retrieve_capabilities().await
99                })
100            });
101            if let Ok(aggregator_capabilities) = aggregator_capabilities
102                && self.required_capabilities.matches(&aggregator_capabilities)
103            {
104                return Some(aggregator_endpoint);
105            }
106        }
107
108        None
109    }
110}
111
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use httpmock::MockServer;
    use serde_json::json;

    use mithril_common::{
        AggregateSignatureType::Concatenation,
        entities::SignedEntityTypeDiscriminants::{
            CardanoDatabase, CardanoStakeDistribution, CardanoTransactions,
            MithrilStakeDistribution,
        },
        messages::AggregatorFeaturesMessage,
    };

    use super::*;

    mod required_capabilities {
        use super::*;

        #[test]
        fn required_capabilities_match_all_success() {
            let required = RequiredAggregatorCapabilities::All;
            let available = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([]),
                cardano_transactions_prover: None,
            };

            assert!(required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_signed_entity_types_success() {
            let required =
                RequiredAggregatorCapabilities::SignedEntityType(CardanoStakeDistribution);
            let available = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([
                    CardanoTransactions,
                    CardanoStakeDistribution,
                    CardanoDatabase,
                ]),
                cardano_transactions_prover: None,
            };

            assert!(required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_signed_entity_types_failure() {
            let required =
                RequiredAggregatorCapabilities::SignedEntityType(MithrilStakeDistribution);
            let available = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([
                    CardanoTransactions,
                    CardanoStakeDistribution,
                    CardanoDatabase,
                ]),
                cardano_transactions_prover: None,
            };

            assert!(!required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_aggregate_signature_type_success() {
            let required = RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation);
            let available = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([
                    CardanoTransactions,
                    CardanoStakeDistribution,
                    CardanoDatabase,
                ]),
                cardano_transactions_prover: None,
            };

            assert!(required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_or_success() {
            let required = RequiredAggregatorCapabilities::Or(vec![
                RequiredAggregatorCapabilities::SignedEntityType(MithrilStakeDistribution),
                RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation),
            ]);
            let available = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([
                    CardanoTransactions,
                    CardanoStakeDistribution,
                    CardanoDatabase,
                ]),
                cardano_transactions_prover: None,
            };

            assert!(required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_or_failure() {
            let required = RequiredAggregatorCapabilities::Or(vec![
                RequiredAggregatorCapabilities::SignedEntityType(MithrilStakeDistribution),
                RequiredAggregatorCapabilities::SignedEntityType(CardanoDatabase),
            ]);
            let available = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([
                    CardanoTransactions,
                    CardanoStakeDistribution,
                ]),
                cardano_transactions_prover: None,
            };

            assert!(!required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_empty_or_never_matches() {
            let required = RequiredAggregatorCapabilities::Or(vec![]);
            let available = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([CardanoTransactions]),
                cardano_transactions_prover: None,
            };

            assert!(!required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_and_success() {
            let required = RequiredAggregatorCapabilities::And(vec![
                RequiredAggregatorCapabilities::SignedEntityType(CardanoTransactions),
                RequiredAggregatorCapabilities::SignedEntityType(CardanoStakeDistribution),
                RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation),
            ]);
            let available = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([
                    CardanoTransactions,
                    CardanoStakeDistribution,
                    CardanoDatabase,
                ]),
                cardano_transactions_prover: None,
            };

            assert!(required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_and_failure() {
            let required = RequiredAggregatorCapabilities::And(vec![
                RequiredAggregatorCapabilities::SignedEntityType(CardanoTransactions),
                RequiredAggregatorCapabilities::SignedEntityType(CardanoStakeDistribution),
                RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation),
            ]);
            let available = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([CardanoTransactions]),
                cardano_transactions_prover: None,
            };

            assert!(!required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_empty_and_always_matches() {
            let required = RequiredAggregatorCapabilities::And(vec![]);
            let available = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([]),
                cardano_transactions_prover: None,
            };

            assert!(required.matches(&available));
        }

        #[test]
        fn required_capabilities_match_nested_combinators() {
            let required = RequiredAggregatorCapabilities::And(vec![
                RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation),
                RequiredAggregatorCapabilities::Or(vec![
                    RequiredAggregatorCapabilities::SignedEntityType(MithrilStakeDistribution),
                    RequiredAggregatorCapabilities::SignedEntityType(CardanoDatabase),
                ]),
            ]);
            let matching = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([CardanoDatabase]),
                cardano_transactions_prover: None,
            };
            let not_matching = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([CardanoTransactions]),
                cardano_transactions_prover: None,
            };

            assert!(required.matches(&matching));
            assert!(!required.matches(&not_matching));
        }
    }

    mod capable_discoverer {
        use super::*;

        fn create_aggregator_features_message(
            capabilities: AggregatorCapabilities,
        ) -> AggregatorFeaturesMessage {
            AggregatorFeaturesMessage {
                open_api_version: "1.0.0".to_string(),
                documentation_url: "https://docs".to_string(),
                capabilities,
            }
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn get_available_aggregators_success() {
            let capabilities = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([
                    CardanoStakeDistribution,
                    CardanoTransactions,
                ]),
                cardano_transactions_prover: None,
            };
            let aggregator_server = MockServer::start();
            let aggregator_server_mock = aggregator_server.mock(|when, then| {
                when.path("/");
                then.status(200)
                    .body(json!(create_aggregator_features_message(capabilities)).to_string());
            });
            let discoverer = CapableAggregatorDiscoverer::new(
                RequiredAggregatorCapabilities::And(vec![
                    RequiredAggregatorCapabilities::SignedEntityType(CardanoTransactions),
                    RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation),
                ]),
                Arc::new(crate::test::double::AggregatorDiscovererFake::new(vec![
                    Ok(vec![AggregatorEndpoint::new(aggregator_server.url("/"))]),
                ])),
            );

            let mut aggregators = discoverer
                .get_available_aggregators(MithrilNetwork::new("release-devnet".into()))
                .await
                .unwrap();

            let next_aggregator = aggregators.next();
            aggregator_server_mock.assert();
            assert_eq!(
                Some(AggregatorEndpoint::new(aggregator_server.url("/"))),
                next_aggregator
            );
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn get_available_aggregators_succeeds_when_aggregator_capabilities_do_not_match() {
            let capabilities = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([CardanoTransactions]),
                cardano_transactions_prover: None,
            };
            let aggregator_server = MockServer::start();
            let aggregator_server_mock = aggregator_server.mock(|when, then| {
                when.path("/");
                then.status(200)
                    .body(json!(create_aggregator_features_message(capabilities)).to_string());
            });
            let discoverer = CapableAggregatorDiscoverer::new(
                RequiredAggregatorCapabilities::And(vec![
                    RequiredAggregatorCapabilities::SignedEntityType(CardanoDatabase),
                    RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation),
                ]),
                Arc::new(crate::test::double::AggregatorDiscovererFake::new(vec![
                    Ok(vec![AggregatorEndpoint::new(aggregator_server.url("/"))]),
                ])),
            );

            let mut aggregators = discoverer
                .get_available_aggregators(MithrilNetwork::new("release-devnet".into()))
                .await
                .unwrap();

            let next_aggregator = aggregators.next();
            aggregator_server_mock.assert();
            assert!(next_aggregator.is_none());
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn get_available_aggregators_succeeds_when_one_aggregator_returns_an_error() {
            let aggregator_server_1 = MockServer::start();
            let aggregator_server_mock_1 = aggregator_server_1.mock(|when, then| {
                when.path("/");
                then.status(500);
            });
            let capabilities_2 = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([CardanoStakeDistribution, CardanoDatabase]),
                cardano_transactions_prover: None,
            };
            let aggregator_server_2 = MockServer::start();
            let aggregator_server_mock_2 = aggregator_server_2.mock(|when, then| {
                when.path("/");
                then.status(200)
                    .body(json!(create_aggregator_features_message(capabilities_2)).to_string());
            });
            let discoverer = CapableAggregatorDiscoverer::new(
                RequiredAggregatorCapabilities::And(vec![
                    RequiredAggregatorCapabilities::SignedEntityType(CardanoDatabase),
                    RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation),
                ]),
                Arc::new(crate::test::double::AggregatorDiscovererFake::new(vec![
                    Ok(vec![
                        AggregatorEndpoint::new(aggregator_server_1.url("/")),
                        AggregatorEndpoint::new(aggregator_server_2.url("/")),
                    ]),
                ])),
            );

            let mut aggregators = discoverer
                .get_available_aggregators(MithrilNetwork::new("release-devnet".into()))
                .await
                .unwrap();

            let next_aggregator = aggregators.next();
            aggregator_server_mock_1.assert();
            aggregator_server_mock_2.assert();
            assert_eq!(
                Some(AggregatorEndpoint::new(aggregator_server_2.url("/"))),
                next_aggregator
            );
        }

        #[tokio::test(flavor = "multi_thread")]
        async fn get_available_aggregators_succeeds_and_makes_minimum_calls_to_aggregators() {
            let aggregator_server_1 = MockServer::start();
            let aggregator_server_mock_1 = aggregator_server_1.mock(|when, then| {
                when.path("/");
                then.status(500);
            });
            let capabilities_2 = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([CardanoStakeDistribution]),
                cardano_transactions_prover: None,
            };
            let aggregator_server_2 = MockServer::start();
            let aggregator_server_mock_2 = aggregator_server_2.mock(|when, then| {
                when.path("/");
                then.status(200)
                    .body(json!(create_aggregator_features_message(capabilities_2)).to_string());
            });
            let capabilities_3 = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([CardanoDatabase]),
                cardano_transactions_prover: None,
            };
            let aggregator_server_3 = MockServer::start();
            let aggregator_server_mock_3 = aggregator_server_3.mock(|when, then| {
                when.path("/");
                then.status(200)
                    .body(json!(create_aggregator_features_message(capabilities_3)).to_string());
            });
            let capabilities_4 = AggregatorCapabilities {
                aggregate_signature_type: Concatenation,
                signed_entity_types: BTreeSet::from([CardanoDatabase]),
                cardano_transactions_prover: None,
            };
            let aggregator_server_4 = MockServer::start();
            let aggregator_server_mock_4 = aggregator_server_4.mock(|when, then| {
                when.path("/");
                then.status(200)
                    .body(json!(create_aggregator_features_message(capabilities_4)).to_string());
            });
            let discoverer = CapableAggregatorDiscoverer::new(
                RequiredAggregatorCapabilities::And(vec![
                    RequiredAggregatorCapabilities::SignedEntityType(CardanoDatabase),
                    RequiredAggregatorCapabilities::AggregateSignatureType(Concatenation),
                ]),
                Arc::new(crate::test::double::AggregatorDiscovererFake::new(vec![
                    Ok(vec![
                        AggregatorEndpoint::new(aggregator_server_1.url("/")),
                        AggregatorEndpoint::new(aggregator_server_2.url("/")),
                        AggregatorEndpoint::new(aggregator_server_3.url("/")),
                        AggregatorEndpoint::new(aggregator_server_4.url("/")),
                    ]),
                ])),
            );

            let mut aggregators = discoverer
                .get_available_aggregators(MithrilNetwork::new("release-devnet".into()))
                .await
                .unwrap();

            let next_aggregator = aggregators.next();
            aggregator_server_mock_1.assert();
            aggregator_server_mock_2.assert();
            aggregator_server_mock_3.assert();
            // The iterator is lazy: the fourth aggregator must not be probed once the third
            // one has matched.
            assert_eq!(0, aggregator_server_mock_4.calls());
            assert_eq!(
                Some(AggregatorEndpoint::new(aggregator_server_3.url("/"))),
                next_aggregator
            );
        }
    }
}