1212import graphql .schema .GraphQLSchema ;
1313import graphql .schema .GraphQLSchemaElement ;
1414import graphql .schema .GraphQLType ;
15- import graphql .schema .GraphQLTypeReference ;
1615import graphql .schema .GraphQLTypeVisitorStub ;
17- import graphql .schema .idl . ScalarInfo ;
16+ import graphql .schema .SchemaTraverser ;
1817import graphql .schema .transform .VisibleFieldPredicateEnvironment .VisibleFieldPredicateEnvironmentImpl ;
1918import graphql .util .TraversalControl ;
2019import graphql .util .TraverserContext ;
21- import java .util .Arrays ;
2220import java .util .HashSet ;
21+ import java .util .List ;
2322import java .util .Objects ;
2423import java .util .Set ;
2524import java .util .stream .Collectors ;
25+ import java .util .stream .Stream ;
2626
2727/**
2828 * Transforms a schema by applying a visibility predicate to every field.
@@ -39,68 +39,71 @@ public FieldVisibilitySchemaTransformation(VisibleFieldPredicate visibleFieldPre
3939 /**
4040 * Before and after callbacks useful for side effects (logs, stopwatches etc).
4141 */
42- protected void beforeTransformation () {}
42+ protected void beforeTransformation () {
43+ }
4344
44- protected void afterTransformation () {}
45+ protected void afterTransformation () {
46+ }
4547
4648 public final GraphQLSchema apply (GraphQLSchema schema ) {
47- Set <GraphQLType > observedTypes = new HashSet <>();
49+ Set <GraphQLType > observedBeforeTransform = new HashSet <>();
50+ Set <GraphQLType > observedAfterTransform = new HashSet <>();
4851 Set <GraphQLType > removedTypes = new HashSet <>();
4952
5053 // query, mutation, and subscription types should not be removed
51- final Set <String > protectedTypeNames = new HashSet <>(Arrays .asList (
52- schema .getQueryType (),
53- schema .getSubscriptionType (),
54- schema .getMutationType ()
55- )).stream ()
56- .filter (Objects ::nonNull )
54+ final Set <String > protectedTypeNames = getRootTypes (schema ).stream ()
5755 .map (GraphQLObjectType ::getName )
5856 .collect (Collectors .toSet ());
5957
6058 beforeTransformation ();
6159
60+ new SchemaTraverser ().depthFirst (new TypeObservingVisitor (observedBeforeTransform ), getRootTypes (schema ));
61+
6262 // remove fields
63- GraphQLSchema interimSchema = transformSchema (schema , new FieldVisibilityVisitor (visibleFieldPredicate ,
64- removedTypes , observedTypes ));
63+ GraphQLSchema interimSchema = transformSchema (schema ,
64+ new FieldRemovalVisitor (visibleFieldPredicate , removedTypes ));
65+
66+ new SchemaTraverser ().depthFirst (new TypeObservingVisitor (observedAfterTransform ), getRootTypes (interimSchema ));
6567
6668 // remove types that are not used
6769 GraphQLSchema finalSchema = transformSchema (interimSchema ,
68- new TypeVisibilityVisitor (protectedTypeNames , observedTypes , removedTypes ));
70+ new TypeVisibilityVisitor (protectedTypeNames , observedBeforeTransform , observedAfterTransform ,
71+ removedTypes ));
6972
7073 afterTransformation ();
7174
7275 return finalSchema ;
7376 }
7477
75- private static class FieldVisibilityVisitor extends GraphQLTypeVisitorStub {
78+ private static class TypeObservingVisitor extends GraphQLTypeVisitorStub {
7679
77- private final VisibleFieldPredicate visibilityPredicate ;
78- private final Set <GraphQLType > removedTypes ;
7980 private final Set <GraphQLType > observedTypes ;
8081
81- private FieldVisibilityVisitor (VisibleFieldPredicate visibilityPredicate ,
82- Set <GraphQLType > removedTypes , Set <GraphQLType > observedTypes ) {
83- this .visibilityPredicate = visibilityPredicate ;
84- this .removedTypes = removedTypes ;
82+
83+ private TypeObservingVisitor (Set <GraphQLType > observedTypes ) {
8584 this .observedTypes = observedTypes ;
8685 }
8786
8887 @ Override
89- public TraversalControl visitGraphQLInterfaceType (GraphQLInterfaceType node ,
90- TraverserContext <GraphQLSchemaElement > context ) {
91- return visitType (node , context );
92- }
88+ protected TraversalControl visitGraphQLType (GraphQLSchemaElement node ,
89+ TraverserContext <GraphQLSchemaElement > context ) {
90+ if (node instanceof GraphQLType ) {
91+ observedTypes .add ((GraphQLType ) node );
92+ }
9393
94- @ Override
95- public TraversalControl visitGraphQLObjectType (GraphQLObjectType node ,
96- TraverserContext <GraphQLSchemaElement > context ) {
97- return visitType (node , context );
94+ return TraversalControl .CONTINUE ;
9895 }
96+ }
9997
100- @ Override
101- public TraversalControl visitGraphQLTypeReference (GraphQLTypeReference node ,
102- TraverserContext <GraphQLSchemaElement > context ) {
103- return TraversalControl .CONTINUE ;
98+ private static class FieldRemovalVisitor extends GraphQLTypeVisitorStub {
99+
100+ private final VisibleFieldPredicate visibilityPredicate ;
101+ private final Set <GraphQLType > removedTypes ;
102+
103+ private FieldRemovalVisitor (VisibleFieldPredicate visibilityPredicate ,
104+ Set <GraphQLType > removedTypes ) {
105+ this .visibilityPredicate = visibilityPredicate ;
106+ this .removedTypes = removedTypes ;
104107 }
105108
106109 @ Override
@@ -115,25 +118,6 @@ public TraversalControl visitGraphQLInputObjectField(GraphQLInputObjectField def
115118 return visitField (definition , context );
116119 }
117120
118- @ Override
119- public TraversalControl visitBackRef (TraverserContext <GraphQLSchemaElement > context ) {
120- if (context .thisNode () instanceof GraphQLInterfaceType || context .thisNode () instanceof GraphQLObjectType ) {
121- return visitType ((GraphQLType ) context .thisNode (), context );
122- }
123-
124- return TraversalControl .CONTINUE ;
125- }
126-
127- private TraversalControl visitType (GraphQLType type ,
128- TraverserContext <GraphQLSchemaElement > context ) {
129- if (context .getBreadcrumbs ().stream ()
130- .noneMatch (crumb -> crumb .getLocation ().getName ().equalsIgnoreCase ("addTypes" ))) {
131- observedTypes .add (type );
132- }
133-
134- return TraversalControl .CONTINUE ;
135- }
136-
137121 private TraversalControl visitField (GraphQLNamedSchemaElement element ,
138122 TraverserContext <GraphQLSchemaElement > context ) {
139123
@@ -156,14 +140,17 @@ private TraversalControl visitField(GraphQLNamedSchemaElement element,
156140 private static class TypeVisibilityVisitor extends GraphQLTypeVisitorStub {
157141
158142 private final Set <String > protectedTypeNames ;
159- private final Set <GraphQLType > observedTypes ;
143+ private final Set <GraphQLType > observedBeforeTransform ;
144+ private final Set <GraphQLType > observedAfterTransform ;
160145 private final Set <GraphQLType > removedTypes ;
161146
162147 private TypeVisibilityVisitor (Set <String > protectedTypeNames ,
163148 Set <GraphQLType > observedTypes ,
149+ Set <GraphQLType > observedAfterTransform ,
164150 Set <GraphQLType > removedTypes ) {
165151 this .protectedTypeNames = protectedTypeNames ;
166- this .observedTypes = observedTypes ;
152+ this .observedBeforeTransform = observedTypes ;
153+ this .observedAfterTransform = observedAfterTransform ;
167154 this .removedTypes = removedTypes ;
168155 }
169156
@@ -174,19 +161,34 @@ public TraversalControl visitGraphQLInterfaceType(GraphQLInterfaceType node,
174161 }
175162
176163 @ Override
177- public TraversalControl visitGraphQLObjectType (GraphQLObjectType node ,
178- TraverserContext <GraphQLSchemaElement > context ) {
179- if (!observedTypes .contains (node ) &&
180- node .getInterfaces ().stream ().noneMatch (observedTypes ::contains ) &&
181- node .getInterfaces ().stream ().anyMatch (removedTypes ::contains ) &&
182- !ScalarInfo .isStandardScalar (node .getName ()) &&
183- !protectedTypeNames .contains (node .getName ())) {
164+ public TraversalControl visitGraphQLType (GraphQLSchemaElement node ,
165+ TraverserContext <GraphQLSchemaElement > context ) {
166+
167+ if (observedBeforeTransform .contains (node ) &&
168+ !observedAfterTransform .contains (node )) {
184169 return deleteNode (context );
185170 }
186171
187172 return TraversalControl .CONTINUE ;
173+
174+ // if (!observedBeforeTransform.contains(node) &&
175+ // node.getInterfaces().stream().noneMatch(observedBeforeTransform::contains) &&
176+ // node.getInterfaces().stream().anyMatch(removedTypes::contains) &&
177+ // !ScalarInfo.isStandardScalar(node.getName()) &&
178+ // !protectedTypeNames.contains(node.getName())) {
179+ // return deleteNode(context);
180+ // }
181+ //
182+ // return TraversalControl.CONTINUE;
188183 }
189184 }
190185
186+ private List <GraphQLObjectType > getRootTypes (GraphQLSchema schema ) {
187+ return Stream .of (
188+ schema .getQueryType (),
189+ schema .getSubscriptionType (),
190+ schema .getMutationType ()
191+ ).filter (Objects ::nonNull ).collect (Collectors .toList ());
192+ }
191193
192194}
0 commit comments