@@ -60,8 +60,8 @@ public FieldVisibilitySchemaTransformation(VisibleFieldPredicate visibleFieldPre
6060 }
6161
6262 public final GraphQLSchema apply (GraphQLSchema schema ) {
63- Set <GraphQLType > observedBeforeTransform = new HashSet <>();
64- Set <GraphQLType > observedAfterTransform = new HashSet <>();
63+ Set <String > observedBeforeTransform = new LinkedHashSet <>();
64+ Set <String > observedAfterTransform = new LinkedHashSet <>();
6565 Set <GraphQLType > markedForRemovalTypes = new HashSet <>();
6666
6767 // query, mutation, and subscription types should not be removed
@@ -135,18 +135,22 @@ protected TraversalControl visitGraphQLType(GraphQLSchemaElement node, Traverser
135135
136136 private static class TypeObservingVisitor extends GraphQLTypeVisitorStub {
137137
138- private final Set <GraphQLType > observedTypes ;
138+ private final Set <String > observedTypes ;
139139
140140
141- private TypeObservingVisitor (Set <GraphQLType > observedTypes ) {
141+ private TypeObservingVisitor (Set <String > observedTypes ) {
142142 this .observedTypes = observedTypes ;
143143 }
144144
145145 @ Override
146146 protected TraversalControl visitGraphQLType (GraphQLSchemaElement node ,
147147 TraverserContext <GraphQLSchemaElement > context ) {
148- if (node instanceof GraphQLType ) {
149- observedTypes .add ((GraphQLType ) node );
148+ if (node instanceof GraphQLObjectType ||
149+ node instanceof GraphQLEnumType ||
150+ node instanceof GraphQLInputObjectType ||
151+ node instanceof GraphQLInterfaceType ||
152+ node instanceof GraphQLUnionType ) {
153+ observedTypes .add (((GraphQLNamedType ) node ).getName ());
150154 }
151155
152156 return TraversalControl .CONTINUE ;
@@ -243,12 +247,12 @@ public TraversalControl visitGraphQLInputObjectField(GraphQLInputObjectField def
243247 private static class TypeVisibilityVisitor extends GraphQLTypeVisitorStub {
244248
245249 private final Set <String > protectedTypeNames ;
246- private final Set <GraphQLType > observedBeforeTransform ;
247- private final Set <GraphQLType > observedAfterTransform ;
250+ private final Set <String > observedBeforeTransform ;
251+ private final Set <String > observedAfterTransform ;
248252
249253 private TypeVisibilityVisitor (Set <String > protectedTypeNames ,
250- Set <GraphQLType > observedTypes ,
251- Set <GraphQLType > observedAfterTransform ) {
254+ Set <String > observedTypes ,
255+ Set <String > observedAfterTransform ) {
252256 this .protectedTypeNames = protectedTypeNames ;
253257 this .observedBeforeTransform = observedTypes ;
254258 this .observedAfterTransform = observedAfterTransform ;
@@ -263,17 +267,19 @@ public TraversalControl visitGraphQLInterfaceType(GraphQLInterfaceType node,
263267 @ Override
264268 public TraversalControl visitGraphQLType (GraphQLSchemaElement node ,
265269 TraverserContext <GraphQLSchemaElement > context ) {
266- if (observedBeforeTransform .contains (node ) &&
267- !observedAfterTransform .contains (node ) &&
268- (node instanceof GraphQLObjectType ||
269- node instanceof GraphQLEnumType ||
270- node instanceof GraphQLInputObjectType ||
271- node instanceof GraphQLInterfaceType ||
272- node instanceof GraphQLUnionType )) {
273-
274- return deleteNode (context );
270+ if (node instanceof GraphQLObjectType ||
271+ node instanceof GraphQLEnumType ||
272+ node instanceof GraphQLInputObjectType ||
273+ node instanceof GraphQLInterfaceType ||
274+ node instanceof GraphQLUnionType ) {
275+ String name = ((GraphQLNamedType ) node ).getName ();
276+ if (observedBeforeTransform .contains (name ) &&
277+ !observedAfterTransform .contains (name )
278+ && !protectedTypeNames .contains (name )
279+ ) {
280+ return deleteNode (context );
281+ }
275282 }
276-
277283 return TraversalControl .CONTINUE ;
278284 }
279285 }
0 commit comments