Skip to content

Commit eddbd26

Browse files
authored
Merge pull request graphql-java#2874 from felipe-gdr/field-visibility-schema-transform-supports-types-reachable-via-interfaces-only
Field visibility schema transform supports types reachable via interfaces only
2 parents 79a837b + 00bc317 commit eddbd26

2 files changed

Lines changed: 110 additions & 9 deletions

File tree

src/main/java/graphql/schema/transform/FieldVisibilitySchemaTransformation.java

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import graphql.PublicApi;
44
import graphql.schema.GraphQLEnumType;
55
import graphql.schema.GraphQLFieldDefinition;
6+
import graphql.schema.GraphQLImplementingType;
67
import graphql.schema.GraphQLInputObjectField;
78
import graphql.schema.GraphQLInterfaceType;
89
import graphql.schema.GraphQLNamedSchemaElement;
@@ -14,19 +15,22 @@
1415
import graphql.schema.GraphQLTypeVisitorStub;
1516
import graphql.schema.GraphQLUnionType;
1617
import graphql.schema.SchemaTraverser;
18+
import graphql.schema.impl.SchemaUtil;
1719
import graphql.schema.transform.VisibleFieldPredicateEnvironment.VisibleFieldPredicateEnvironmentImpl;
1820
import graphql.util.TraversalControl;
1921
import graphql.util.TraverserContext;
2022

23+
import java.util.ArrayList;
2124
import java.util.HashSet;
2225
import java.util.List;
26+
import java.util.Map;
2327
import java.util.Objects;
2428
import java.util.Set;
29+
import java.util.function.Function;
2530
import java.util.stream.Collectors;
2631
import java.util.stream.Stream;
2732

2833
import static graphql.schema.SchemaTransformer.transformSchema;
29-
import static graphql.util.TreeTransformerUtil.deleteNode;
3034

3135
/**
3236
* Transforms a schema by applying a visibility predicate to every field.
@@ -62,13 +66,13 @@ public final GraphQLSchema apply(GraphQLSchema schema) {
6266

6367
beforeTransformationHook.run();
6468

65-
new SchemaTraverser().depthFirst(new TypeObservingVisitor(observedBeforeTransform, schema), getRootTypes(schema));
69+
new SchemaTraverser(getChildrenFn(schema)).depthFirst(new TypeObservingVisitor(observedBeforeTransform), getRootTypes(schema));
6670

6771
// remove fields
6872
GraphQLSchema interimSchema = transformSchema(schema,
6973
new FieldRemovalVisitor(visibleFieldPredicate, markedForRemovalTypes));
7074

71-
new SchemaTraverser().depthFirst(new TypeObservingVisitor(observedAfterTransform, interimSchema), getRootTypes(interimSchema));
75+
new SchemaTraverser(getChildrenFn(interimSchema)).depthFirst(new TypeObservingVisitor(observedAfterTransform), getRootTypes(interimSchema));
7276

7377
// remove types that are not used after removing fields - (connected schema only)
7478
GraphQLSchema connectedSchema = transformSchema(interimSchema,
@@ -83,6 +87,23 @@ public final GraphQLSchema apply(GraphQLSchema schema) {
8387
return finalSchema;
8488
}
8589

90+
// Creates a getChildrenFn that includes interface
91+
private Function<GraphQLSchemaElement, List<GraphQLSchemaElement>> getChildrenFn(GraphQLSchema schema) {
92+
Map<String, List<GraphQLImplementingType>> interfaceImplementations = new SchemaUtil().groupImplementationsForInterfacesAndObjects(schema);
93+
94+
return graphQLSchemaElement -> {
95+
if (!(graphQLSchemaElement instanceof GraphQLInterfaceType)) {
96+
return graphQLSchemaElement.getChildren();
97+
}
98+
ArrayList<GraphQLSchemaElement> children = new ArrayList<>(graphQLSchemaElement.getChildren());
99+
List<GraphQLImplementingType> implementations = interfaceImplementations.get(((GraphQLInterfaceType) graphQLSchemaElement).getName());
100+
if (implementations != null) {
101+
children.addAll(implementations);
102+
}
103+
return children;
104+
};
105+
}
106+
86107
private GraphQLSchema removeUnreferencedTypes(Set<GraphQLType> markedForRemovalTypes, GraphQLSchema connectedSchema) {
87108
GraphQLSchema withoutAdditionalTypes = connectedSchema.transform(builder -> {
88109
Set<GraphQLType> additionalTypes = new HashSet<>(connectedSchema.getAdditionalTypes());
@@ -109,12 +130,10 @@ protected TraversalControl visitGraphQLType(GraphQLSchemaElement node, Traverser
109130
private static class TypeObservingVisitor extends GraphQLTypeVisitorStub {
110131

111132
private final Set<GraphQLType> observedTypes;
112-
private GraphQLSchema graphQLSchema;
113133

114134

115-
private TypeObservingVisitor(Set<GraphQLType> observedTypes, GraphQLSchema graphQLSchema) {
135+
private TypeObservingVisitor(Set<GraphQLType> observedTypes) {
116136
this.observedTypes = observedTypes;
117-
this.graphQLSchema = graphQLSchema;
118137
}
119138

120139
@Override
@@ -123,9 +142,6 @@ protected TraversalControl visitGraphQLType(GraphQLSchemaElement node,
123142
if (node instanceof GraphQLType) {
124143
observedTypes.add((GraphQLType) node);
125144
}
126-
if (node instanceof GraphQLInterfaceType) {
127-
observedTypes.addAll(graphQLSchema.getImplementations((GraphQLInterfaceType) node));
128-
}
129145

130146
return TraversalControl.CONTINUE;
131147
}

src/test/groovy/graphql/schema/transform/FieldVisibilitySchemaTransformationTest.groovy

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class FieldVisibilitySchemaTransformationTest extends Specification {
4949

5050
then:
5151
(restrictedSchema.getType("Account") as GraphQLObjectType).getFieldDefinition("billingStatus") == null
52+
restrictedSchema.getType("BillingStatus") == null
5253
}
5354

5455
def "can remove a type associated with a private field"() {
@@ -1093,4 +1094,88 @@ class FieldVisibilitySchemaTransformationTest extends Specification {
10931094
then:
10941095
callbacks.containsAll(["before", "after"])
10951096
}
1097+
1098+
def "handles types that become visible via types reachable by interface only"() {
1099+
given:
1100+
GraphQLSchema schema = TestUtil.schema("""
1101+
1102+
directive @private on FIELD_DEFINITION
1103+
1104+
type Query {
1105+
account: Account
1106+
node: Node
1107+
}
1108+
1109+
type Account {
1110+
name: String
1111+
billingStatus: BillingStatus @private
1112+
}
1113+
1114+
type BillingStatus {
1115+
accountNumber: String
1116+
}
1117+
1118+
interface Node {
1119+
id: ID!
1120+
}
1121+
1122+
type Billing implements Node {
1123+
id: ID!
1124+
status: BillingStatus
1125+
}
1126+
1127+
""")
1128+
1129+
when:
1130+
GraphQLSchema restrictedSchema = visibilitySchemaTransformation.apply(schema)
1131+
1132+
then:
1133+
(restrictedSchema.getType("Account") as GraphQLObjectType).getFieldDefinition("billingStatus") == null
1134+
restrictedSchema.getType("BillingStatus") != null
1135+
}
1136+
1137+
def "handles types that become visible via types reachable by interface that implements interface"() {
1138+
given:
1139+
GraphQLSchema schema = TestUtil.schema("""
1140+
1141+
directive @private on FIELD_DEFINITION
1142+
1143+
type Query {
1144+
account: Account
1145+
node: Node
1146+
}
1147+
1148+
type Account {
1149+
name: String
1150+
billingStatus: BillingStatus @private
1151+
}
1152+
1153+
type BillingStatus {
1154+
accountNumber: String
1155+
}
1156+
1157+
interface Node {
1158+
id: ID!
1159+
}
1160+
1161+
interface NamedNode implements Node {
1162+
id: ID!
1163+
name: String
1164+
}
1165+
1166+
type Billing implements Node & NamedNode {
1167+
id: ID!
1168+
name: String
1169+
status: BillingStatus
1170+
}
1171+
1172+
""")
1173+
1174+
when:
1175+
GraphQLSchema restrictedSchema = visibilitySchemaTransformation.apply(schema)
1176+
1177+
then:
1178+
(restrictedSchema.getType("Account") as GraphQLObjectType).getFieldDefinition("billingStatus") == null
1179+
restrictedSchema.getType("BillingStatus") != null
1180+
}
10961181
}

0 commit comments

Comments
 (0)