Skip to content

Commit 52a7b6e

Browse files
authored
Merge pull request graphql-java#1445 from tsroka/visit-fragments
Add support for visiting fragment definitions in QueryVisitor.
2 parents 93b7b58 + 8b58e3e commit 52a7b6e

6 files changed

Lines changed: 138 additions & 3 deletions

File tree

src/main/java/graphql/analysis/NodeVisitorWithTypeTracking.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,13 @@ public TraversalControl visitFragmentDefinition(FragmentDefinition node, Travers
8989
return TraversalControl.ABORT;
9090
}
9191

92+
QueryVisitorFragmentDefinitionEnvironment fragmentEnvironment = new QueryVisitorFragmentDefinitionEnvironmentImpl(node, context);
93+
9294
if (context.getVar(NodeTraverser.LeaveOrEnter.class) == LEAVE) {
95+
postOrderCallback.visitFragmentDefinition(fragmentEnvironment);
9396
return TraversalControl.CONTINUE;
9497
}
95-
98+
preOrderCallback.visitFragmentDefinition(fragmentEnvironment);
9699

97100
QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class);
98101
GraphQLCompositeType typeCondition = (GraphQLCompositeType) schema.getType(node.getTypeCondition().getName());

src/main/java/graphql/analysis/QueryVisitor.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,8 @@ public interface QueryVisitor {
1616

1717
void visitFragmentSpread(QueryVisitorFragmentSpreadEnvironment queryVisitorFragmentSpreadEnvironment);
1818

19+
default void visitFragmentDefinition(QueryVisitorFragmentDefinitionEnvironment queryVisitorFragmentDefinitionEnvironment) {
20+
21+
}
22+
1923
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package graphql.analysis;
2+
3+
import graphql.PublicApi;
4+
import graphql.language.FragmentDefinition;
5+
import graphql.language.Node;
6+
import graphql.util.TraverserContext;
7+
8+
@PublicApi
9+
public interface QueryVisitorFragmentDefinitionEnvironment {
10+
FragmentDefinition getFragmentDefinition();
11+
12+
TraverserContext<Node> getTraverserContext();
13+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package graphql.analysis;
2+
3+
import graphql.Internal;
4+
import graphql.language.FragmentDefinition;
5+
import graphql.language.Node;
6+
import graphql.util.TraverserContext;
7+
8+
import java.util.Objects;
9+
10+
@Internal
11+
public class QueryVisitorFragmentDefinitionEnvironmentImpl implements QueryVisitorFragmentDefinitionEnvironment {
12+
13+
private final FragmentDefinition fragmentDefinition;
14+
private final TraverserContext<Node> traverserContext;
15+
16+
17+
public QueryVisitorFragmentDefinitionEnvironmentImpl(FragmentDefinition fragmentDefinition, TraverserContext<Node> traverserContext) {
18+
this.fragmentDefinition = fragmentDefinition;
19+
this.traverserContext = traverserContext;
20+
}
21+
22+
@Override
23+
public FragmentDefinition getFragmentDefinition() {
24+
return fragmentDefinition;
25+
}
26+
27+
@Override
28+
public TraverserContext<Node> getTraverserContext() {
29+
return traverserContext;
30+
}
31+
32+
@Override
33+
public boolean equals(Object o) {
34+
if (this == o) {
35+
return true;
36+
}
37+
if (o == null || getClass() != o.getClass()) {
38+
return false;
39+
}
40+
QueryVisitorFragmentDefinitionEnvironmentImpl that = (QueryVisitorFragmentDefinitionEnvironmentImpl) o;
41+
return Objects.equals(fragmentDefinition, that.fragmentDefinition);
42+
}
43+
44+
@Override
45+
public int hashCode() {
46+
return Objects.hash(fragmentDefinition);
47+
}
48+
49+
@Override
50+
public String toString() {
51+
return "QueryVisitorFragmentDefinitionEnvironmentImpl{" +
52+
"fragmentDefinition=" + fragmentDefinition +
53+
'}';
54+
}
55+
}
56+

src/test/groovy/graphql/analysis/QueryTransformerTest.groovy

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import graphql.language.Document
55
import graphql.language.Field
66
import graphql.language.NodeUtil
77
import graphql.language.SelectionSet
8+
import graphql.language.TypeName
89
import graphql.parser.Parser
910
import graphql.schema.GraphQLSchema
1011
import spock.lang.Specification
@@ -207,7 +208,7 @@ class QueryTransformerTest extends Specification {
207208
0 * _
208209
}
209210

210-
def "named fragment is traversed if it is a root and can be transformed"() {
211+
def "fragment definition is traversed if it is a root and can be transformed"() {
211212
def query = TestUtil.parseQuery('''
212213
{
213214
root {
@@ -241,13 +242,22 @@ class QueryTransformerTest extends Specification {
241242
})
242243
}
243244
}
245+
246+
@Override
247+
void visitFragmentDefinition(QueryVisitorFragmentDefinitionEnvironment env) {
248+
def changed = env.fragmentDefinition.transform({ builder ->
249+
builder.typeCondition(TypeName.newTypeName("newTypeName").build())
250+
.name("newFragName")
251+
})
252+
changeNode(env.traverserContext, changed)
253+
}
244254
}
245255

246256

247257
when:
248258
def newFragment = queryTransformer.transform(visitor)
249259
then:
250260
printAstCompact(newFragment) ==
251-
"fragment frag on Root {fooA {midA {newChild1 newChild2}}}"
261+
"fragment newFragName on newTypeName {fooA {midA {newChild1 newChild2}}}"
252262
}
253263
}

src/test/groovy/graphql/analysis/QueryTraversalTest.groovy

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import graphql.language.FragmentDefinition
77
import graphql.language.FragmentSpread
88
import graphql.language.InlineFragment
99
import graphql.language.NodeTraverser
10+
import graphql.language.NodeUtil
1011
import graphql.parser.Parser
1112
import graphql.schema.GraphQLNonNull
1213
import graphql.schema.GraphQLObjectType
@@ -302,6 +303,54 @@ class QueryTraversalTest extends Specification {
302303

303304
}
304305

306+
307+
def "test preOrder and postOrder order for fragment definitions"() {
308+
given:
309+
def schema = TestUtil.schema("""
310+
type Query{
311+
foo: Foo
312+
bar: String
313+
}
314+
type Foo {
315+
subFoo: String
316+
}
317+
""")
318+
def visitor = Mock(QueryVisitor)
319+
def query = createQuery("""
320+
{
321+
...F1
322+
}
323+
324+
fragment F1 on Query {
325+
foo {
326+
subFoo
327+
}
328+
}
329+
""")
330+
331+
def fragments = NodeUtil.getFragmentsByName(query)
332+
333+
QueryTraversal queryTraversal = QueryTraversal.newQueryTraversal()
334+
.schema(schema)
335+
.root(fragments["F1"])
336+
.rootParentType(schema.getQueryType())
337+
.fragmentsByName(fragments)
338+
.variables([:])
339+
.build()
340+
341+
when:
342+
queryTraversal.visitPreOrder(visitor)
343+
344+
then:
345+
1 * visitor.visitFragmentDefinition({ QueryVisitorFragmentDefinitionEnvironment env -> env.fragmentDefinition == fragments["F1"] })
346+
347+
when:
348+
queryTraversal.visitPostOrder(visitor)
349+
350+
then:
351+
1 * visitor.visitFragmentDefinition({ QueryVisitorFragmentDefinitionEnvironment env -> env.fragmentDefinition == fragments["F1"] })
352+
}
353+
305354
def "works for mutations()"() {
306355
given:
307356
def schema = TestUtil.schema("""

0 commit comments

Comments
 (0)