Skip to content

Commit

Permalink
Add cancellation checks to potentially heavy loops in query planning
Browse files Browse the repository at this point in the history
  • Loading branch information
sachindshinde committed Feb 28, 2025
1 parent 90d2b21 commit ba9877f
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 12 deletions.
23 changes: 17 additions & 6 deletions apollo-federation/src/operation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1541,9 +1541,12 @@ impl SelectionSet {
Ok(())
}

pub(crate) fn expand_all_fragments(&self) -> Result<SelectionSet, FederationError> {
pub(crate) fn expand_all_fragments(
&self,
check_cancellation: &impl Fn() -> Result<(), SingleFederationError>,
) -> Result<SelectionSet, FederationError> {
let mut expanded_selections = vec![];
SelectionSet::expand_selection_set(&mut expanded_selections, self)?;
SelectionSet::expand_selection_set(&mut expanded_selections, self, check_cancellation)?;

let mut expanded = SelectionSet {
schema: self.schema.clone(),
Expand All @@ -1557,12 +1560,14 @@ impl SelectionSet {
fn expand_selection_set(
destination: &mut Vec<Selection>,
selection_set: &SelectionSet,
check_cancellation: &impl Fn() -> Result<(), SingleFederationError>,
) -> Result<(), FederationError> {
for value in selection_set.selections.values() {
check_cancellation()?;
match value {
Selection::Field(field_selection) => {
let selections = match &field_selection.selection_set {
Some(s) => Some(s.expand_all_fragments()?),
Some(s) => Some(s.expand_all_fragments(check_cancellation)?),
None => None,
};
destination.push(Selection::from_field(
Expand All @@ -1580,12 +1585,14 @@ impl SelectionSet {
SelectionSet::expand_selection_set(
destination,
&spread_selection.selection_set,
check_cancellation,
)?;
} else {
// convert to inline fragment
let expanded = InlineFragmentSelection::from_fragment_spread_selection(
selection_set.type_position.clone(), // the parent type of this inline selection
spread_selection,
check_cancellation,
)?;
destination.push(Selection::InlineFragment(Arc::new(expanded)));
}
Expand All @@ -1594,7 +1601,9 @@ impl SelectionSet {
destination.push(
InlineFragmentSelection::new(
inline_selection.inline_fragment.clone(),
inline_selection.selection_set.expand_all_fragments()?,
inline_selection
.selection_set
.expand_all_fragments(check_cancellation)?,
)
.into(),
);
Expand Down Expand Up @@ -2716,6 +2725,7 @@ impl InlineFragmentSelection {
pub(crate) fn from_fragment_spread_selection(
parent_type_position: CompositeTypeDefinitionPosition,
fragment_spread_selection: &Arc<FragmentSpreadSelection>,
check_cancellation: &impl Fn() -> Result<(), SingleFederationError>,
) -> Result<InlineFragmentSelection, FederationError> {
let schema = fragment_spread_selection.spread.schema.schema();
for directive in fragment_spread_selection.spread.directives.iter() {
Expand Down Expand Up @@ -2753,7 +2763,7 @@ impl InlineFragmentSelection {
},
fragment_spread_selection
.selection_set
.expand_all_fragments()?,
.expand_all_fragments(check_cancellation)?,
))
}

Expand Down Expand Up @@ -3749,10 +3759,11 @@ pub(crate) fn normalize_operation(
named_fragments: NamedFragments,
schema: &ValidFederationSchema,
interface_types_with_interface_objects: &IndexSet<InterfaceTypeDefinitionPosition>,
check_cancellation: &impl Fn() -> Result<(), SingleFederationError>,
) -> Result<Operation, FederationError> {
let mut normalized_selection_set =
SelectionSet::from_selection_set(&operation.selection_set, &named_fragments, schema)?;
normalized_selection_set = normalized_selection_set.expand_all_fragments()?;
normalized_selection_set = normalized_selection_set.expand_all_fragments(check_cancellation)?;
// We clear up the fragments since we've expanded all.
// Also note that expanding fragment usually generate unnecessary fragments/inefficient
// selections, so it basically always make sense to flatten afterwards. Besides, fragment
Expand Down
32 changes: 31 additions & 1 deletion apollo-federation/src/operation/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,13 @@ pub(super) fn parse_and_expand(
.expect("must have anonymous operation");
let fragments = NamedFragments::new(&doc.fragments, schema);

normalize_operation(operation, fragments, schema, &Default::default())
normalize_operation(
operation,
fragments,
schema,
&Default::default(),
&|| Ok(()),
)
}

#[test]
Expand Down Expand Up @@ -100,6 +106,7 @@ type Foo {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
normalized_operation.named_fragments = Default::default();
Expand Down Expand Up @@ -154,6 +161,7 @@ type Foo {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
normalized_operation.named_fragments = Default::default();
Expand Down Expand Up @@ -196,6 +204,7 @@ type Query {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();

Expand Down Expand Up @@ -231,6 +240,7 @@ type T {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query Test {
Expand Down Expand Up @@ -274,6 +284,7 @@ type T {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query Test($skipIf: Boolean!) {
Expand Down Expand Up @@ -320,6 +331,7 @@ type T {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query Test($skipIf: Boolean!) {
Expand Down Expand Up @@ -364,6 +376,7 @@ type T {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query Test($skipIf: Boolean!) {
Expand Down Expand Up @@ -410,6 +423,7 @@ type T {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query Test($skip1: Boolean!, $skip2: Boolean!) {
Expand Down Expand Up @@ -461,6 +475,7 @@ type T {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query Test {
Expand Down Expand Up @@ -527,6 +542,7 @@ type V {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query Test {
Expand Down Expand Up @@ -586,6 +602,7 @@ type T {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query Test {
Expand Down Expand Up @@ -632,6 +649,7 @@ type T {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query Test($skipIf: Boolean!) {
Expand Down Expand Up @@ -682,6 +700,7 @@ type T {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query Test($skipIf: Boolean!) {
Expand Down Expand Up @@ -730,6 +749,7 @@ type T {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query Test($skipIf: Boolean!) {
Expand Down Expand Up @@ -778,6 +798,7 @@ type T {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query Test($skip1: Boolean!, $skip2: Boolean!) {
Expand Down Expand Up @@ -830,6 +851,7 @@ type T {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query Test {
Expand Down Expand Up @@ -898,6 +920,7 @@ type V {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query Test {
Expand Down Expand Up @@ -944,6 +967,7 @@ type Foo {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query TestQuery {
Expand Down Expand Up @@ -983,6 +1007,7 @@ type Foo {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
let expected = r#"query TestQuery {
Expand Down Expand Up @@ -1033,6 +1058,7 @@ scalar FieldSet
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&interface_objects,
&|| Ok(()),
)
.unwrap();
let expected = r#"query TestQuery {
Expand Down Expand Up @@ -1172,6 +1198,7 @@ mod make_selection_tests {
Default::default(),
&schema,
&Default::default(),
&|| Ok(()),
)
.unwrap();

Expand Down Expand Up @@ -1271,6 +1298,7 @@ mod lazy_map_tests {
Default::default(),
&schema,
&Default::default(),
&|| Ok(()),
)
.unwrap();

Expand Down Expand Up @@ -1329,6 +1357,7 @@ mod lazy_map_tests {
Default::default(),
&schema,
&Default::default(),
&|| Ok(()),
)
.unwrap();

Expand Down Expand Up @@ -1534,6 +1563,7 @@ fn test_expand_all_fragments1() {
NamedFragments::new(&executable_document.fragments, &schema),
&schema,
&IndexSet::default(),
&|| Ok(()),
)
.unwrap();
normalized_operation.named_fragments = Default::default();
Expand Down
Loading

0 comments on commit ba9877f

Please sign in to comment.