Skip to content

Commit ef26e8a

Browse files
author
Daniel Treacy
committed
Clean up types based on observation before and after traversal
1 parent 70cf9fb commit ef26e8a

2 files changed

Lines changed: 182 additions & 65 deletions

File tree

src/main/java/graphql/schema/transform/FieldVisibilitySchemaTransformation.java

Lines changed: 64 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212
import graphql.schema.GraphQLSchema;
1313
import graphql.schema.GraphQLSchemaElement;
1414
import graphql.schema.GraphQLType;
15-
import graphql.schema.GraphQLTypeReference;
1615
import graphql.schema.GraphQLTypeVisitorStub;
17-
import graphql.schema.idl.ScalarInfo;
16+
import graphql.schema.SchemaTraverser;
1817
import graphql.schema.transform.VisibleFieldPredicateEnvironment.VisibleFieldPredicateEnvironmentImpl;
1918
import graphql.util.TraversalControl;
2019
import graphql.util.TraverserContext;
21-
import java.util.Arrays;
2220
import java.util.HashSet;
21+
import java.util.List;
2322
import java.util.Objects;
2423
import java.util.Set;
2524
import java.util.stream.Collectors;
25+
import java.util.stream.Stream;
2626

2727
/**
2828
* Transforms a schema by applying a visibility predicate to every field.
@@ -39,68 +39,71 @@ public FieldVisibilitySchemaTransformation(VisibleFieldPredicate visibleFieldPre
3939
/**
4040
* Before and after callbacks useful for side effects (logs, stopwatches etc).
4141
*/
42-
protected void beforeTransformation() {}
42+
protected void beforeTransformation() {
43+
}
4344

44-
protected void afterTransformation() {}
45+
protected void afterTransformation() {
46+
}
4547

4648
public final GraphQLSchema apply(GraphQLSchema schema) {
47-
Set<GraphQLType> observedTypes = new HashSet<>();
49+
Set<GraphQLType> observedBeforeTransform = new HashSet<>();
50+
Set<GraphQLType> observedAfterTransform = new HashSet<>();
4851
Set<GraphQLType> removedTypes = new HashSet<>();
4952

5053
// query, mutation, and subscription types should not be removed
51-
final Set<String> protectedTypeNames = new HashSet<>(Arrays.asList(
52-
schema.getQueryType(),
53-
schema.getSubscriptionType(),
54-
schema.getMutationType()
55-
)).stream()
56-
.filter(Objects::nonNull)
54+
final Set<String> protectedTypeNames = getRootTypes(schema).stream()
5755
.map(GraphQLObjectType::getName)
5856
.collect(Collectors.toSet());
5957

6058
beforeTransformation();
6159

60+
new SchemaTraverser().depthFirst(new TypeObservingVisitor(observedBeforeTransform), getRootTypes(schema));
61+
6262
// remove fields
63-
GraphQLSchema interimSchema = transformSchema(schema, new FieldVisibilityVisitor(visibleFieldPredicate,
64-
removedTypes, observedTypes));
63+
GraphQLSchema interimSchema = transformSchema(schema,
64+
new FieldRemovalVisitor(visibleFieldPredicate, removedTypes));
65+
66+
new SchemaTraverser().depthFirst(new TypeObservingVisitor(observedAfterTransform), getRootTypes(interimSchema));
6567

6668
// remove types that are not used
6769
GraphQLSchema finalSchema = transformSchema(interimSchema,
68-
new TypeVisibilityVisitor(protectedTypeNames, observedTypes, removedTypes));
70+
new TypeVisibilityVisitor(protectedTypeNames, observedBeforeTransform, observedAfterTransform,
71+
removedTypes));
6972

7073
afterTransformation();
7174

7275
return finalSchema;
7376
}
7477

75-
private static class FieldVisibilityVisitor extends GraphQLTypeVisitorStub {
78+
private static class TypeObservingVisitor extends GraphQLTypeVisitorStub {
7679

77-
private final VisibleFieldPredicate visibilityPredicate;
78-
private final Set<GraphQLType> removedTypes;
7980
private final Set<GraphQLType> observedTypes;
8081

81-
private FieldVisibilityVisitor(VisibleFieldPredicate visibilityPredicate,
82-
Set<GraphQLType> removedTypes, Set<GraphQLType> observedTypes) {
83-
this.visibilityPredicate = visibilityPredicate;
84-
this.removedTypes = removedTypes;
82+
83+
private TypeObservingVisitor(Set<GraphQLType> observedTypes) {
8584
this.observedTypes = observedTypes;
8685
}
8786

8887
@Override
89-
public TraversalControl visitGraphQLInterfaceType(GraphQLInterfaceType node,
90-
TraverserContext<GraphQLSchemaElement> context) {
91-
return visitType(node, context);
92-
}
88+
protected TraversalControl visitGraphQLType(GraphQLSchemaElement node,
89+
TraverserContext<GraphQLSchemaElement> context) {
90+
if (node instanceof GraphQLType) {
91+
observedTypes.add((GraphQLType) node);
92+
}
9393

94-
@Override
95-
public TraversalControl visitGraphQLObjectType(GraphQLObjectType node,
96-
TraverserContext<GraphQLSchemaElement> context) {
97-
return visitType(node, context);
94+
return TraversalControl.CONTINUE;
9895
}
96+
}
9997

100-
@Override
101-
public TraversalControl visitGraphQLTypeReference(GraphQLTypeReference node,
102-
TraverserContext<GraphQLSchemaElement> context) {
103-
return TraversalControl.CONTINUE;
98+
private static class FieldRemovalVisitor extends GraphQLTypeVisitorStub {
99+
100+
private final VisibleFieldPredicate visibilityPredicate;
101+
private final Set<GraphQLType> removedTypes;
102+
103+
private FieldRemovalVisitor(VisibleFieldPredicate visibilityPredicate,
104+
Set<GraphQLType> removedTypes) {
105+
this.visibilityPredicate = visibilityPredicate;
106+
this.removedTypes = removedTypes;
104107
}
105108

106109
@Override
@@ -115,25 +118,6 @@ public TraversalControl visitGraphQLInputObjectField(GraphQLInputObjectField def
115118
return visitField(definition, context);
116119
}
117120

118-
@Override
119-
public TraversalControl visitBackRef(TraverserContext<GraphQLSchemaElement> context) {
120-
if (context.thisNode() instanceof GraphQLInterfaceType || context.thisNode() instanceof GraphQLObjectType) {
121-
return visitType((GraphQLType) context.thisNode(), context);
122-
}
123-
124-
return TraversalControl.CONTINUE;
125-
}
126-
127-
private TraversalControl visitType(GraphQLType type,
128-
TraverserContext<GraphQLSchemaElement> context) {
129-
if (context.getBreadcrumbs().stream()
130-
.noneMatch(crumb -> crumb.getLocation().getName().equalsIgnoreCase("addTypes"))) {
131-
observedTypes.add(type);
132-
}
133-
134-
return TraversalControl.CONTINUE;
135-
}
136-
137121
private TraversalControl visitField(GraphQLNamedSchemaElement element,
138122
TraverserContext<GraphQLSchemaElement> context) {
139123

@@ -156,14 +140,17 @@ private TraversalControl visitField(GraphQLNamedSchemaElement element,
156140
private static class TypeVisibilityVisitor extends GraphQLTypeVisitorStub {
157141

158142
private final Set<String> protectedTypeNames;
159-
private final Set<GraphQLType> observedTypes;
143+
private final Set<GraphQLType> observedBeforeTransform;
144+
private final Set<GraphQLType> observedAfterTransform;
160145
private final Set<GraphQLType> removedTypes;
161146

162147
private TypeVisibilityVisitor(Set<String> protectedTypeNames,
163148
Set<GraphQLType> observedTypes,
149+
Set<GraphQLType> observedAfterTransform,
164150
Set<GraphQLType> removedTypes) {
165151
this.protectedTypeNames = protectedTypeNames;
166-
this.observedTypes = observedTypes;
152+
this.observedBeforeTransform = observedTypes;
153+
this.observedAfterTransform = observedAfterTransform;
167154
this.removedTypes = removedTypes;
168155
}
169156

@@ -174,19 +161,34 @@ public TraversalControl visitGraphQLInterfaceType(GraphQLInterfaceType node,
174161
}
175162

176163
@Override
177-
public TraversalControl visitGraphQLObjectType(GraphQLObjectType node,
178-
TraverserContext<GraphQLSchemaElement> context) {
179-
if (!observedTypes.contains(node) &&
180-
node.getInterfaces().stream().noneMatch(observedTypes::contains) &&
181-
node.getInterfaces().stream().anyMatch(removedTypes::contains) &&
182-
!ScalarInfo.isStandardScalar(node.getName()) &&
183-
!protectedTypeNames.contains(node.getName())) {
164+
public TraversalControl visitGraphQLType(GraphQLSchemaElement node,
165+
TraverserContext<GraphQLSchemaElement> context) {
166+
167+
if (observedBeforeTransform.contains(node) &&
168+
!observedAfterTransform.contains(node)) {
184169
return deleteNode(context);
185170
}
186171

187172
return TraversalControl.CONTINUE;
173+
174+
// if (!observedBeforeTransform.contains(node) &&
175+
// node.getInterfaces().stream().noneMatch(observedBeforeTransform::contains) &&
176+
// node.getInterfaces().stream().anyMatch(removedTypes::contains) &&
177+
// !ScalarInfo.isStandardScalar(node.getName()) &&
178+
// !protectedTypeNames.contains(node.getName())) {
179+
// return deleteNode(context);
180+
// }
181+
//
182+
// return TraversalControl.CONTINUE;
188183
}
189184
}
190185

186+
private List<GraphQLObjectType> getRootTypes(GraphQLSchema schema) {
187+
return Stream.of(
188+
schema.getQueryType(),
189+
schema.getSubscriptionType(),
190+
schema.getMutationType()
191+
).filter(Objects::nonNull).collect(Collectors.toList());
192+
}
191193

192194
}

src/test/groovy/graphql/schema/transform/FieldVisibilitySchemaTransformationTest.groovy

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package graphql.schema.transform
22

33
import graphql.Scalars
44
import graphql.TestUtil
5+
import graphql.schema.GraphQLDirectiveContainer
56
import graphql.schema.GraphQLInputObjectType
67
import graphql.schema.GraphQLObjectType
78
import graphql.schema.GraphQLSchema
@@ -18,7 +19,8 @@ import static graphql.schema.GraphQLTypeReference.typeRef
1819
class FieldVisibilitySchemaTransformationTest extends Specification {
1920

2021
def visibilitySchemaTransformation = new FieldVisibilitySchemaTransformation({ environment ->
21-
return environment.schemaElement.directives.find({ directive -> directive.name == "private" }) == null
22+
def directives = (environment.schemaElement as GraphQLDirectiveContainer).directives
23+
return directives.find({ directive -> directive.name == "private" }) == null
2224
})
2325

2426
def "can remove a private field"() {
@@ -355,6 +357,42 @@ class FieldVisibilitySchemaTransformationTest extends Specification {
355357
restrictedSchema.getType("SuperSecretCustomerData") != null
356358
}
357359

360+
361+
def "removes interface types implemented by types used in a private field"() {
362+
given:
363+
GraphQLSchema schema = TestUtil.schema("""
364+
365+
directive @private on FIELD_DEFINITION
366+
367+
type Query {
368+
account: Account
369+
}
370+
371+
type Account {
372+
name: String
373+
billingStatus: BillingStatus @private
374+
}
375+
376+
type BillingStatus implements SuperSecretCustomerData {
377+
accountNumber: String
378+
cardLast4: Int
379+
}
380+
381+
interface SuperSecretCustomerData {
382+
cardLast4: Int
383+
}
384+
385+
""")
386+
387+
388+
when:
389+
GraphQLSchema restrictedSchema = visibilitySchemaTransformation.apply(schema)
390+
391+
then:
392+
restrictedSchema.getType("BillingStatus") == null
393+
restrictedSchema.getType("SuperSecretCustomerData") == null
394+
}
395+
358396
def "leaves interface type if has private and public reference"() {
359397

360398
given:
@@ -811,7 +849,7 @@ class FieldVisibilitySchemaTransformationTest extends Specification {
811849
restrictedSchema.getType("Bing") == null
812850
}
813851

814-
def "use type references"() {
852+
def "use type references - private field declared with interface type removes both concrete and interface"() {
815853
given:
816854
def query = newObject()
817855
.name("Query")
@@ -846,14 +884,91 @@ class FieldVisibilitySchemaTransformationTest extends Specification {
846884
.build()
847885
when:
848886

849-
(new SchemaPrinter()).print(schema)
887+
System.out.println((new SchemaPrinter()).print(schema))
850888
GraphQLSchema restrictedSchema = visibilitySchemaTransformation.apply(schema)
851889

852890
then:
853891
(restrictedSchema.getType("Account") as GraphQLObjectType).getFieldDefinition("billingStatus") == null
854892
restrictedSchema.getType("BillingStatus") == null
855893
restrictedSchema.getType("SuperSecretCustomerData") == null
894+
}
895+
896+
897+
def "use type references - private field declared with concrete type removes both concrete and interface"() {
898+
given:
899+
def query = newObject()
900+
.name("Query")
901+
.field(newFieldDefinition().name("account").type(typeRef("Account")).build())
902+
.build()
903+
904+
def privateDirective = newDirective().name("private").build()
905+
def account = newObject()
906+
.name("Account")
907+
.field(newFieldDefinition().name("name").type(Scalars.GraphQLString).build())
908+
.field(newFieldDefinition().name("billingStatus").type(typeRef("BillingStatus")).withDirective(privateDirective).build())
909+
.build()
910+
911+
def billingStatus = newObject()
912+
.name("BillingStatus")
913+
.field(newFieldDefinition().name("id").type(Scalars.GraphQLString).build())
914+
.withInterface(typeRef("SuperSecretCustomerData"))
915+
.build()
916+
917+
def secretData = newInterface()
918+
.name("SuperSecretCustomerData")
919+
.field(newFieldDefinition().name("id").type(Scalars.GraphQLString).build())
920+
.typeResolver(Mock(TypeResolver))
921+
.build()
856922

923+
def schema = GraphQLSchema.newSchema()
924+
.query(query)
925+
.additionalType(billingStatus)
926+
.additionalType(account)
927+
.additionalType(billingStatus)
928+
.additionalType(secretData)
929+
.build()
930+
when:
931+
932+
System.out.println((new SchemaPrinter()).print(schema))
933+
GraphQLSchema restrictedSchema = visibilitySchemaTransformation.apply(schema)
934+
935+
then:
936+
(restrictedSchema.getType("Account") as GraphQLObjectType).getFieldDefinition("billingStatus") == null
937+
restrictedSchema.getType("BillingStatus") == null
938+
restrictedSchema.getType("SuperSecretCustomerData") == null
857939
}
858940

941+
def "use type references - unreferenced types are removed"() {
942+
given:
943+
def query = newObject()
944+
.name("Query")
945+
.field(newFieldDefinition().name("account").type(typeRef("Account")).build())
946+
.build()
947+
948+
def privateDirective = newDirective().name("private").build()
949+
def account = newObject()
950+
.name("Account")
951+
.field(newFieldDefinition().name("name").type(Scalars.GraphQLString).build())
952+
.field(newFieldDefinition().name("billingStatus").type(typeRef("BillingStatus")).withDirective(privateDirective).build())
953+
.build()
954+
955+
def billingStatus = newObject()
956+
.name("BillingStatus")
957+
.field(newFieldDefinition().name("id").type(Scalars.GraphQLString).build())
958+
.build()
959+
960+
def schema = GraphQLSchema.newSchema()
961+
.query(query)
962+
.additionalType(billingStatus)
963+
.additionalType(account)
964+
.build()
965+
when:
966+
967+
System.out.println((new SchemaPrinter()).print(schema))
968+
GraphQLSchema restrictedSchema = visibilitySchemaTransformation.apply(schema)
969+
970+
then:
971+
(restrictedSchema.getType("Account") as GraphQLObjectType).getFieldDefinition("billingStatus") == null
972+
restrictedSchema.getType("BillingStatus") == null
973+
}
859974
}

0 commit comments

Comments
 (0)