33import graphql .PublicApi ;
44import graphql .schema .GraphQLEnumType ;
55import graphql .schema .GraphQLFieldDefinition ;
6+ import graphql .schema .GraphQLImplementingType ;
67import graphql .schema .GraphQLInputObjectField ;
78import graphql .schema .GraphQLInterfaceType ;
89import graphql .schema .GraphQLNamedSchemaElement ;
1415import graphql .schema .GraphQLTypeVisitorStub ;
1516import graphql .schema .GraphQLUnionType ;
1617import graphql .schema .SchemaTraverser ;
18+ import graphql .schema .impl .SchemaUtil ;
1719import graphql .schema .transform .VisibleFieldPredicateEnvironment .VisibleFieldPredicateEnvironmentImpl ;
1820import graphql .util .TraversalControl ;
1921import graphql .util .TraverserContext ;
2022
21- import java .util .Collections ;
23+ import java .util .ArrayList ;
2224import java .util .HashSet ;
2325import java .util .List ;
26+ import java .util .Map ;
2427import java .util .Objects ;
2528import java .util .Set ;
29+ import java .util .function .Function ;
2630import java .util .stream .Collectors ;
2731import java .util .stream .Stream ;
2832
2933import static graphql .schema .SchemaTransformer .transformSchema ;
30- import static graphql .util .TreeTransformerUtil .deleteNode ;
3134
3235/**
3336 * Transforms a schema by applying a visibility predicate to every field.
@@ -63,13 +66,13 @@ public final GraphQLSchema apply(GraphQLSchema schema) {
6366
6467 beforeTransformationHook .run ();
6568
66- new SchemaTraverser () .depthFirst (new TypeObservingVisitor (observedBeforeTransform , schema ), getRootTypes (schema ));
69+ new SchemaTraverser (getChildrenFn ( schema )) .depthFirst (new TypeObservingVisitor (observedBeforeTransform ), getRootTypes (schema ));
6770
6871 // remove fields
6972 GraphQLSchema interimSchema = transformSchema (schema ,
7073 new FieldRemovalVisitor (visibleFieldPredicate , markedForRemovalTypes ));
7174
72- new SchemaTraverser () .depthFirst (new TypeObservingVisitor (observedAfterTransform , interimSchema ), getRootTypes (interimSchema ));
75+ new SchemaTraverser (getChildrenFn ( interimSchema )) .depthFirst (new TypeObservingVisitor (observedAfterTransform ), getRootTypes (interimSchema ));
7376
7477 // remove types that are not used after removing fields - (connected schema only)
7578 GraphQLSchema connectedSchema = transformSchema (interimSchema ,
@@ -84,6 +87,23 @@ public final GraphQLSchema apply(GraphQLSchema schema) {
8487 return finalSchema ;
8588 }
8689
90+ // Creates a getChildrenFn that includes interface
91+ private Function <GraphQLSchemaElement , List <GraphQLSchemaElement >> getChildrenFn (GraphQLSchema schema ) {
92+ Map <String , List <GraphQLImplementingType >> interfaceImplementations = new SchemaUtil ().groupImplementationsForInterfacesAndObjects (schema );
93+
94+ return graphQLSchemaElement -> {
95+ if (!(graphQLSchemaElement instanceof GraphQLInterfaceType )) {
96+ return graphQLSchemaElement .getChildren ();
97+ }
98+ ArrayList <GraphQLSchemaElement > children = new ArrayList <>(graphQLSchemaElement .getChildren ());
99+ List <GraphQLImplementingType > implementations = interfaceImplementations .get (((GraphQLInterfaceType ) graphQLSchemaElement ).getName ());
100+ if (implementations != null ) {
101+ children .addAll (implementations );
102+ }
103+ return children ;
104+ };
105+ }
106+
87107 private GraphQLSchema removeUnreferencedTypes (Set <GraphQLType > markedForRemovalTypes , GraphQLSchema connectedSchema ) {
88108 GraphQLSchema withoutAdditionalTypes = connectedSchema .transform (builder -> {
89109 Set <GraphQLType > additionalTypes = new HashSet <>(connectedSchema .getAdditionalTypes ());
@@ -110,12 +130,10 @@ protected TraversalControl visitGraphQLType(GraphQLSchemaElement node, Traverser
110130 private static class TypeObservingVisitor extends GraphQLTypeVisitorStub {
111131
112132 private final Set <GraphQLType > observedTypes ;
113- private GraphQLSchema graphQLSchema ;
114133
115134
116- private TypeObservingVisitor (Set <GraphQLType > observedTypes , GraphQLSchema graphQLSchema ) {
135+ private TypeObservingVisitor (Set <GraphQLType > observedTypes ) {
117136 this .observedTypes = observedTypes ;
118- this .graphQLSchema = graphQLSchema ;
119137 }
120138
121139 @ Override
@@ -124,21 +142,6 @@ protected TraversalControl visitGraphQLType(GraphQLSchemaElement node,
124142 if (node instanceof GraphQLType ) {
125143 observedTypes .add ((GraphQLType ) node );
126144 }
127- if (node instanceof GraphQLInterfaceType ) {
128- final GraphQLSchemaElement parentNode = context .getParentNode ();
129- if (parentNode instanceof GraphQLObjectType && observedTypes .contains (parentNode )) {
130- // This means the traversal reached this interface via a type that implements it. If that type
131- // has already been observed, we don't need to continue traversing so we quit early.
132- // This is just to avoid the streaming/filtering bellow
133- return TraversalControl .QUIT ;
134- }
135-
136- final List <GraphQLObjectType > implementations = graphQLSchema .getImplementations ((GraphQLInterfaceType ) node );
137-
138- implementations .stream ()
139- .filter (impl -> !observedTypes .contains (impl ))
140- .forEach (impl -> new SchemaTraverser ().depthFirst (this , Collections .singleton (impl )));
141- }
142145
143146 return TraversalControl .CONTINUE ;
144147 }
0 commit comments