Skip to content

Commit 8407043

Browse files
committed
handle case where all fields are removed via Field visibility transformer
1 parent 021fe1b commit 8407043

2 files changed

Lines changed: 83 additions & 9 deletions

File tree

src/main/java/graphql/schema/transform/FieldVisibilitySchemaTransformation.java

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import graphql.PublicApi;
55
import graphql.schema.GraphQLEnumType;
66
import graphql.schema.GraphQLFieldDefinition;
7+
import graphql.schema.GraphQLFieldsContainer;
78
import graphql.schema.GraphQLImplementingType;
89
import graphql.schema.GraphQLInputObjectField;
910
import graphql.schema.GraphQLInputObjectType;
@@ -24,6 +25,7 @@
2425

2526
import java.util.ArrayList;
2627
import java.util.HashSet;
28+
import java.util.LinkedHashSet;
2729
import java.util.List;
2830
import java.util.Map;
2931
import java.util.Objects;
@@ -45,7 +47,9 @@ public class FieldVisibilitySchemaTransformation {
4547
private final Runnable afterTransformationHook;
4648

4749
public FieldVisibilitySchemaTransformation(VisibleFieldPredicate visibleFieldPredicate) {
48-
this(visibleFieldPredicate, () -> {}, () -> {});
50+
this(visibleFieldPredicate, () -> {
51+
}, () -> {
52+
});
4953
}
5054

5155
public FieldVisibilitySchemaTransformation(VisibleFieldPredicate visibleFieldPredicate,
@@ -155,16 +159,53 @@ private static class FieldRemovalVisitor extends GraphQLTypeVisitorStub {
155159
private final VisibleFieldPredicate visibilityPredicate;
156160
private final Set<GraphQLType> removedTypes;
157161

162+
private final Set<GraphQLFieldDefinition> fieldDefinitionsToActuallyRemove = new LinkedHashSet<>();
163+
158164
private FieldRemovalVisitor(VisibleFieldPredicate visibilityPredicate,
159165
Set<GraphQLType> removedTypes) {
160166
this.visibilityPredicate = visibilityPredicate;
161167
this.removedTypes = removedTypes;
162168
}
163169

170+
@Override
171+
public TraversalControl visitGraphQLObjectType(GraphQLObjectType objectType, TraverserContext<GraphQLSchemaElement> context) {
172+
return visitFieldsContainer(objectType, context);
173+
}
174+
175+
@Override
176+
public TraversalControl visitGraphQLInterfaceType(GraphQLInterfaceType objectType, TraverserContext<GraphQLSchemaElement> context) {
177+
return visitFieldsContainer(objectType, context);
178+
}
179+
180+
private TraversalControl visitFieldsContainer(GraphQLFieldsContainer fieldsContainer, TraverserContext<GraphQLSchemaElement> context) {
181+
boolean allFieldsDeleted = true;
182+
for (GraphQLFieldDefinition fieldDefinition : fieldsContainer.getFieldDefinitions()) {
183+
VisibleFieldPredicateEnvironment environment = new VisibleFieldPredicateEnvironmentImpl(
184+
fieldDefinition, fieldsContainer);
185+
if (!visibilityPredicate.isVisible(environment)) {
186+
fieldDefinitionsToActuallyRemove.add(fieldDefinition);
187+
removedTypes.add(fieldDefinition.getType());
188+
} else {
189+
allFieldsDeleted = false;
190+
}
191+
}
192+
if (allFieldsDeleted) {
193+
// we are deleting the whole interface type because all fields are supposed to be deleted
194+
return deleteNode(context);
195+
} else {
196+
return TraversalControl.CONTINUE;
197+
}
198+
}
199+
200+
164201
@Override
165202
public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition definition,
166203
TraverserContext<GraphQLSchemaElement> context) {
167-
return visitField(definition, context);
204+
if (fieldDefinitionsToActuallyRemove.contains(definition)) {
205+
return deleteNode(context);
206+
} else {
207+
return TraversalControl.CONTINUE;
208+
}
168209
}
169210

170211
@Override
@@ -216,12 +257,12 @@ public TraversalControl visitGraphQLInterfaceType(GraphQLInterfaceType node,
216257
public TraversalControl visitGraphQLType(GraphQLSchemaElement node,
217258
TraverserContext<GraphQLSchemaElement> context) {
218259
if (observedBeforeTransform.contains(node) &&
219-
!observedAfterTransform.contains(node) &&
220-
(node instanceof GraphQLObjectType ||
221-
node instanceof GraphQLEnumType ||
222-
node instanceof GraphQLInputObjectType ||
223-
node instanceof GraphQLInterfaceType ||
224-
node instanceof GraphQLUnionType)) {
260+
!observedAfterTransform.contains(node) &&
261+
(node instanceof GraphQLObjectType ||
262+
node instanceof GraphQLEnumType ||
263+
node instanceof GraphQLInputObjectType ||
264+
node instanceof GraphQLInterfaceType ||
265+
node instanceof GraphQLUnionType)) {
225266

226267
return deleteNode(context);
227268
}

src/test/groovy/graphql/schema/transform/FieldVisibilitySchemaTransformationTest.groovy

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1073,7 +1073,7 @@ class FieldVisibilitySchemaTransformationTest extends Specification {
10731073
def visibilitySchemaTransformation = new FieldVisibilitySchemaTransformation({ environment ->
10741074
def directives = (environment.schemaElement as GraphQLDirectiveContainer).appliedDirectives
10751075
return directives.find({ directive -> directive.name == "private" }) == null
1076-
}, { -> callbacks << "before" }, { -> callbacks << "after"} )
1076+
}, { -> callbacks << "before" }, { -> callbacks << "after" })
10771077

10781078
GraphQLSchema schema = TestUtil.schema("""
10791079
@@ -1245,5 +1245,38 @@ class FieldVisibilitySchemaTransformationTest extends Specification {
12451245
then:
12461246
(restrictedSchema.getType("Account") as GraphQLObjectType).getFieldDefinition("billingStatus") == null
12471247
restrictedSchema.getType("BillingStatus") == null
1248+
1249+
}
1250+
1251+
def "remove all fields from a type which is referenced via additional types"() {
1252+
given:
1253+
GraphQLSchema schema = TestUtil.schema("""
1254+
directive @private on FIELD_DEFINITION
1255+
type Query {
1256+
foo: Foo
1257+
}
1258+
type Foo {
1259+
foo: String
1260+
toDelete: ToDelete @private
1261+
}
1262+
type ToDelete {
1263+
toDelete:String @private
1264+
}
1265+
""")
1266+
1267+
when:
1268+
schema.typeMap
1269+
def patchedSchema = schema.transform { builder ->
1270+
schema.typeMap.each { entry ->
1271+
def type = entry.value
1272+
if (type != schema.queryType && type != schema.mutationType && type != schema.subscriptionType) {
1273+
builder.additionalType(type)
1274+
}
1275+
}
1276+
}
1277+
GraphQLSchema restrictedSchema = visibilitySchemaTransformation.apply(patchedSchema)
1278+
then:
1279+
(restrictedSchema.getType("Foo") as GraphQLObjectType).getFieldDefinition("toDelete") == null
12481280
}
1281+
12491282
}

0 commit comments

Comments
 (0)