99import graphql .schema .GraphQLInputObjectField ;
1010import graphql .schema .GraphQLInputObjectType ;
1111import graphql .schema .GraphQLInterfaceType ;
12- import graphql .schema .GraphQLNamedSchemaElement ;
1312import graphql .schema .GraphQLNamedType ;
1413import graphql .schema .GraphQLObjectType ;
1514import graphql .schema .GraphQLSchema ;
@@ -160,6 +159,7 @@ private static class FieldRemovalVisitor extends GraphQLTypeVisitorStub {
160159 private final Set <GraphQLType > removedTypes ;
161160
162161 private final Set <GraphQLFieldDefinition > fieldDefinitionsToActuallyRemove = new LinkedHashSet <>();
162+ private final Set <GraphQLInputObjectField > inputObjectFieldsToDelete = new LinkedHashSet <>();
163163
164164 private FieldRemovalVisitor (VisibleFieldPredicate visibilityPredicate ,
165165 Set <GraphQLType > removedTypes ) {
@@ -197,6 +197,27 @@ private TraversalControl visitFieldsContainer(GraphQLFieldsContainer fieldsConta
197197 }
198198 }
199199
200+ @ Override
201+ public TraversalControl visitGraphQLInputObjectType (GraphQLInputObjectType inputObjectType , TraverserContext <GraphQLSchemaElement > context ) {
202+ boolean allFieldsDeleted = true ;
203+ for (GraphQLInputObjectField inputField : inputObjectType .getFieldDefinitions ()) {
204+ VisibleFieldPredicateEnvironment environment = new VisibleFieldPredicateEnvironmentImpl (
205+ inputField , inputObjectType );
206+ if (!visibilityPredicate .isVisible (environment )) {
207+ inputObjectFieldsToDelete .add (inputField );
208+ removedTypes .add (inputField .getType ());
209+ } else {
210+ allFieldsDeleted = false ;
211+ }
212+ }
213+ if (allFieldsDeleted ) {
214+ // we are deleting the whole input object type because all fields are supposed to be deleted
215+ return deleteNode (context );
216+ } else {
217+ return TraversalControl .CONTINUE ;
218+ }
219+
220+ }
200221
201222 @ Override
202223 public TraversalControl visitGraphQLFieldDefinition (GraphQLFieldDefinition definition ,
@@ -211,25 +232,11 @@ public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition defin
211232 @ Override
212233 public TraversalControl visitGraphQLInputObjectField (GraphQLInputObjectField definition ,
213234 TraverserContext <GraphQLSchemaElement > context ) {
214- return visitField (definition , context );
215- }
216-
217- private TraversalControl visitField (GraphQLNamedSchemaElement element ,
218- TraverserContext <GraphQLSchemaElement > context ) {
219-
220- VisibleFieldPredicateEnvironment environment = new VisibleFieldPredicateEnvironmentImpl (
221- element , context .getParentNode ());
222- if (!visibilityPredicate .isVisible (environment )) {
223- deleteNode (context );
224-
225- if (element instanceof GraphQLFieldDefinition ) {
226- removedTypes .add (((GraphQLFieldDefinition ) element ).getType ());
227- } else if (element instanceof GraphQLInputObjectField ) {
228- removedTypes .add (((GraphQLInputObjectField ) element ).getType ());
229- }
235+ if (inputObjectFieldsToDelete .contains (definition )) {
236+ return deleteNode (context );
237+ } else {
238+ return TraversalControl .CONTINUE ;
230239 }
231-
232- return TraversalControl .CONTINUE ;
233240 }
234241 }
235242
0 commit comments