Skip to content

Commit 00bc317

Browse files
committed
improve implementation to avoid starting nested traversals
1 parent 3b20618 commit 00bc317

2 files changed

Lines changed: 69 additions & 22 deletions

File tree

src/main/java/graphql/schema/transform/FieldVisibilitySchemaTransformation.java

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import graphql.PublicApi;
44
import graphql.schema.GraphQLEnumType;
55
import graphql.schema.GraphQLFieldDefinition;
6+
import graphql.schema.GraphQLImplementingType;
67
import graphql.schema.GraphQLInputObjectField;
78
import graphql.schema.GraphQLInterfaceType;
89
import graphql.schema.GraphQLNamedSchemaElement;
@@ -14,20 +15,22 @@
1415
import graphql.schema.GraphQLTypeVisitorStub;
1516
import graphql.schema.GraphQLUnionType;
1617
import graphql.schema.SchemaTraverser;
18+
import graphql.schema.impl.SchemaUtil;
1719
import graphql.schema.transform.VisibleFieldPredicateEnvironment.VisibleFieldPredicateEnvironmentImpl;
1820
import graphql.util.TraversalControl;
1921
import graphql.util.TraverserContext;
2022

21-
import java.util.Collections;
23+
import java.util.ArrayList;
2224
import java.util.HashSet;
2325
import java.util.List;
26+
import java.util.Map;
2427
import java.util.Objects;
2528
import java.util.Set;
29+
import java.util.function.Function;
2630
import java.util.stream.Collectors;
2731
import java.util.stream.Stream;
2832

2933
import static graphql.schema.SchemaTransformer.transformSchema;
30-
import static graphql.util.TreeTransformerUtil.deleteNode;
3134

3235
/**
3336
* Transforms a schema by applying a visibility predicate to every field.
@@ -63,13 +66,13 @@ public final GraphQLSchema apply(GraphQLSchema schema) {
6366

6467
beforeTransformationHook.run();
6568

66-
new SchemaTraverser().depthFirst(new TypeObservingVisitor(observedBeforeTransform, schema), getRootTypes(schema));
69+
new SchemaTraverser(getChildrenFn(schema)).depthFirst(new TypeObservingVisitor(observedBeforeTransform), getRootTypes(schema));
6770

6871
// remove fields
6972
GraphQLSchema interimSchema = transformSchema(schema,
7073
new FieldRemovalVisitor(visibleFieldPredicate, markedForRemovalTypes));
7174

72-
new SchemaTraverser().depthFirst(new TypeObservingVisitor(observedAfterTransform, interimSchema), getRootTypes(interimSchema));
75+
new SchemaTraverser(getChildrenFn(interimSchema)).depthFirst(new TypeObservingVisitor(observedAfterTransform), getRootTypes(interimSchema));
7376

7477
// remove types that are not used after removing fields - (connected schema only)
7578
GraphQLSchema connectedSchema = transformSchema(interimSchema,
@@ -84,6 +87,23 @@ public final GraphQLSchema apply(GraphQLSchema schema) {
8487
return finalSchema;
8588
}
8689

90+
// Creates a getChildrenFn that includes interface
91+
private Function<GraphQLSchemaElement, List<GraphQLSchemaElement>> getChildrenFn(GraphQLSchema schema) {
92+
Map<String, List<GraphQLImplementingType>> interfaceImplementations = new SchemaUtil().groupImplementationsForInterfacesAndObjects(schema);
93+
94+
return graphQLSchemaElement -> {
95+
if (!(graphQLSchemaElement instanceof GraphQLInterfaceType)) {
96+
return graphQLSchemaElement.getChildren();
97+
}
98+
ArrayList<GraphQLSchemaElement> children = new ArrayList<>(graphQLSchemaElement.getChildren());
99+
List<GraphQLImplementingType> implementations = interfaceImplementations.get(((GraphQLInterfaceType) graphQLSchemaElement).getName());
100+
if (implementations != null) {
101+
children.addAll(implementations);
102+
}
103+
return children;
104+
};
105+
}
106+
87107
private GraphQLSchema removeUnreferencedTypes(Set<GraphQLType> markedForRemovalTypes, GraphQLSchema connectedSchema) {
88108
GraphQLSchema withoutAdditionalTypes = connectedSchema.transform(builder -> {
89109
Set<GraphQLType> additionalTypes = new HashSet<>(connectedSchema.getAdditionalTypes());
@@ -110,12 +130,10 @@ protected TraversalControl visitGraphQLType(GraphQLSchemaElement node, Traverser
110130
private static class TypeObservingVisitor extends GraphQLTypeVisitorStub {
111131

112132
private final Set<GraphQLType> observedTypes;
113-
private GraphQLSchema graphQLSchema;
114133

115134

116-
private TypeObservingVisitor(Set<GraphQLType> observedTypes, GraphQLSchema graphQLSchema) {
135+
private TypeObservingVisitor(Set<GraphQLType> observedTypes) {
117136
this.observedTypes = observedTypes;
118-
this.graphQLSchema = graphQLSchema;
119137
}
120138

121139
@Override
@@ -124,21 +142,6 @@ protected TraversalControl visitGraphQLType(GraphQLSchemaElement node,
124142
if (node instanceof GraphQLType) {
125143
observedTypes.add((GraphQLType) node);
126144
}
127-
if (node instanceof GraphQLInterfaceType) {
128-
final GraphQLSchemaElement parentNode = context.getParentNode();
129-
if(parentNode instanceof GraphQLObjectType && observedTypes.contains(parentNode)) {
130-
// This means the traversal reached this interface via a type that implements it. If that type
131-
// has already been observed, we don't need to continue traversing so we quit early.
132-
// This is just to avoid the streaming/filtering bellow
133-
return TraversalControl.QUIT;
134-
}
135-
136-
final List<GraphQLObjectType> implementations = graphQLSchema.getImplementations((GraphQLInterfaceType) node);
137-
138-
implementations.stream()
139-
.filter(impl -> !observedTypes.contains(impl))
140-
.forEach(impl -> new SchemaTraverser().depthFirst(this, Collections.singleton(impl)));
141-
}
142145

143146
return TraversalControl.CONTINUE;
144147
}

src/test/groovy/graphql/schema/transform/FieldVisibilitySchemaTransformationTest.groovy

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,4 +1134,48 @@ class FieldVisibilitySchemaTransformationTest extends Specification {
11341134
restrictedSchema.getType("BillingStatus") != null
11351135
}
11361136

1137+
def "handles types that become visible via types reachable by interface that implements interface"() {
1138+
given:
1139+
GraphQLSchema schema = TestUtil.schema("""
1140+
1141+
directive @private on FIELD_DEFINITION
1142+
1143+
type Query {
1144+
account: Account
1145+
node: Node
1146+
}
1147+
1148+
type Account {
1149+
name: String
1150+
billingStatus: BillingStatus @private
1151+
}
1152+
1153+
type BillingStatus {
1154+
accountNumber: String
1155+
}
1156+
1157+
interface Node {
1158+
id: ID!
1159+
}
1160+
1161+
interface NamedNode implements Node {
1162+
id: ID!
1163+
name: String
1164+
}
1165+
1166+
type Billing implements Node & NamedNode {
1167+
id: ID!
1168+
name: String
1169+
status: BillingStatus
1170+
}
1171+
1172+
""")
1173+
1174+
when:
1175+
GraphQLSchema restrictedSchema = visibilitySchemaTransformation.apply(schema)
1176+
1177+
then:
1178+
(restrictedSchema.getType("Account") as GraphQLObjectType).getFieldDefinition("billingStatus") == null
1179+
restrictedSchema.getType("BillingStatus") != null
1180+
}
11371181
}

0 commit comments

Comments
 (0)