33import graphql .PublicApi ;
44import graphql .schema .GraphQLEnumType ;
55import graphql .schema .GraphQLFieldDefinition ;
6+ import graphql .schema .GraphQLImplementingType ;
67import graphql .schema .GraphQLInputObjectField ;
78import graphql .schema .GraphQLInterfaceType ;
89import graphql .schema .GraphQLNamedSchemaElement ;
1415import graphql .schema .GraphQLTypeVisitorStub ;
1516import graphql .schema .GraphQLUnionType ;
1617import graphql .schema .SchemaTraverser ;
18+ import graphql .schema .impl .SchemaUtil ;
1719import graphql .schema .transform .VisibleFieldPredicateEnvironment .VisibleFieldPredicateEnvironmentImpl ;
1820import graphql .util .TraversalControl ;
1921import graphql .util .TraverserContext ;
2022
23+ import java .util .ArrayList ;
2124import java .util .HashSet ;
2225import java .util .List ;
26+ import java .util .Map ;
2327import java .util .Objects ;
2428import java .util .Set ;
29+ import java .util .function .Function ;
2530import java .util .stream .Collectors ;
2631import java .util .stream .Stream ;
2732
2833import static graphql .schema .SchemaTransformer .transformSchema ;
29- import static graphql .util .TreeTransformerUtil .deleteNode ;
3034
3135/**
3236 * Transforms a schema by applying a visibility predicate to every field.
@@ -62,13 +66,13 @@ public final GraphQLSchema apply(GraphQLSchema schema) {
6266
6367 beforeTransformationHook .run ();
6468
65- new SchemaTraverser () .depthFirst (new TypeObservingVisitor (observedBeforeTransform , schema ), getRootTypes (schema ));
69+ new SchemaTraverser (getChildrenFn ( schema )) .depthFirst (new TypeObservingVisitor (observedBeforeTransform ), getRootTypes (schema ));
6670
6771 // remove fields
6872 GraphQLSchema interimSchema = transformSchema (schema ,
6973 new FieldRemovalVisitor (visibleFieldPredicate , markedForRemovalTypes ));
7074
71- new SchemaTraverser () .depthFirst (new TypeObservingVisitor (observedAfterTransform , interimSchema ), getRootTypes (interimSchema ));
75+ new SchemaTraverser (getChildrenFn ( interimSchema )) .depthFirst (new TypeObservingVisitor (observedAfterTransform ), getRootTypes (interimSchema ));
7276
7377 // remove types that are not used after removing fields - (connected schema only)
7478 GraphQLSchema connectedSchema = transformSchema (interimSchema ,
@@ -83,6 +87,23 @@ public final GraphQLSchema apply(GraphQLSchema schema) {
8387 return finalSchema ;
8488 }
8589
90+ // Creates a getChildrenFn that includes interface
91+ private Function <GraphQLSchemaElement , List <GraphQLSchemaElement >> getChildrenFn (GraphQLSchema schema ) {
92+ Map <String , List <GraphQLImplementingType >> interfaceImplementations = new SchemaUtil ().groupImplementationsForInterfacesAndObjects (schema );
93+
94+ return graphQLSchemaElement -> {
95+ if (!(graphQLSchemaElement instanceof GraphQLInterfaceType )) {
96+ return graphQLSchemaElement .getChildren ();
97+ }
98+ ArrayList <GraphQLSchemaElement > children = new ArrayList <>(graphQLSchemaElement .getChildren ());
99+ List <GraphQLImplementingType > implementations = interfaceImplementations .get (((GraphQLInterfaceType ) graphQLSchemaElement ).getName ());
100+ if (implementations != null ) {
101+ children .addAll (implementations );
102+ }
103+ return children ;
104+ };
105+ }
106+
86107 private GraphQLSchema removeUnreferencedTypes (Set <GraphQLType > markedForRemovalTypes , GraphQLSchema connectedSchema ) {
87108 GraphQLSchema withoutAdditionalTypes = connectedSchema .transform (builder -> {
88109 Set <GraphQLType > additionalTypes = new HashSet <>(connectedSchema .getAdditionalTypes ());
@@ -109,12 +130,10 @@ protected TraversalControl visitGraphQLType(GraphQLSchemaElement node, Traverser
109130 private static class TypeObservingVisitor extends GraphQLTypeVisitorStub {
110131
111132 private final Set <GraphQLType > observedTypes ;
112- private GraphQLSchema graphQLSchema ;
113133
114134
115- private TypeObservingVisitor (Set <GraphQLType > observedTypes , GraphQLSchema graphQLSchema ) {
135+ private TypeObservingVisitor (Set <GraphQLType > observedTypes ) {
116136 this .observedTypes = observedTypes ;
117- this .graphQLSchema = graphQLSchema ;
118137 }
119138
120139 @ Override
@@ -123,9 +142,6 @@ protected TraversalControl visitGraphQLType(GraphQLSchemaElement node,
123142 if (node instanceof GraphQLType ) {
124143 observedTypes .add ((GraphQLType ) node );
125144 }
126- if (node instanceof GraphQLInterfaceType ) {
127- observedTypes .addAll (graphQLSchema .getImplementations ((GraphQLInterfaceType ) node ));
128- }
129145
130146 return TraversalControl .CONTINUE ;
131147 }
0 commit comments