Skip to content

Commit 0e6edd2

Browse files
committed
bugfix: querytraversal should also work with mutations and subscriptions
1 parent e2411b6 commit 0e6edd2

File tree

2 files changed

+95
-2
lines changed

2 files changed

+95
-2
lines changed

src/main/java/graphql/analysis/QueryTraversal.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,15 @@
1616
import graphql.schema.GraphQLCompositeType;
1717
import graphql.schema.GraphQLFieldDefinition;
1818
import graphql.schema.GraphQLFieldsContainer;
19+
import graphql.schema.GraphQLObjectType;
1920
import graphql.schema.GraphQLSchema;
2021

2122
import java.util.LinkedHashMap;
2223
import java.util.Map;
2324

25+
import static graphql.Assert.assertNotNull;
26+
import static graphql.Assert.assertShouldNeverHappen;
27+
2428
@Internal
2529
public class QueryTraversal {
2630

@@ -47,11 +51,23 @@ public QueryTraversal(GraphQLSchema schema,
4751
}
4852

4953
public void visitPostOrder(QueryVisitor visitor) {
50-
visitImpl(visitor, operationDefinition.getSelectionSet(), schema.getQueryType(), null, false);
54+
visitImpl(visitor, operationDefinition.getSelectionSet(), getRootType(), null, false);
5155
}
5256

5357
public void visitPreOrder(QueryVisitor visitor) {
54-
visitImpl(visitor, operationDefinition.getSelectionSet(), schema.getQueryType(), null, true);
58+
visitImpl(visitor, operationDefinition.getSelectionSet(), getRootType(), null, true);
59+
}
60+
61+
private GraphQLObjectType getRootType() {
62+
if (operationDefinition.getOperation() == OperationDefinition.Operation.MUTATION) {
63+
return assertNotNull(schema.getMutationType());
64+
} else if (operationDefinition.getOperation() == OperationDefinition.Operation.QUERY) {
65+
return assertNotNull(schema.getQueryType());
66+
} else if (operationDefinition.getOperation() == OperationDefinition.Operation.SUBSCRIPTION) {
67+
return assertNotNull(schema.getSubscriptionType());
68+
} else {
69+
return assertShouldNeverHappen();
70+
}
5571
}
5672

5773
public <T> T reducePostOrder(QueryReducer<T> queryReducer, T initialValue) {

src/test/groovy/graphql/analysis/QueryTraversalTest.groovy

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,83 @@ class QueryTraversalTest extends Specification {
8989

9090
}
9191

92+
def "works for mutations()"() {
93+
given:
94+
def schema = TestUtil.schema("""
95+
type Query {
96+
a: String
97+
}
98+
type Mutation{
99+
foo: Foo
100+
bar: String
101+
}
102+
type Foo {
103+
subFoo: String
104+
}
105+
schema {mutation: Mutation, query: Query}
106+
""")
107+
def visitor = Mock(QueryVisitor)
108+
def query = createQuery("""
109+
mutation M{bar foo { subFoo} }
110+
""")
111+
QueryTraversal queryTraversal = createQueryTraversal(query, schema)
112+
when:
113+
queryTraversal."$visitFn"(visitor)
114+
115+
then:
116+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo" && it.parentType.name == "Mutation" })
117+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parentType.name == "Mutation" })
118+
1 * visitor.visitField({ QueryVisitorEnvironment it ->
119+
it.field.name == "subFoo" && it.fieldDefinition.type.name == "String" &&
120+
it.parentType.name == "Foo" &&
121+
it.parentEnvironment.field.name == "foo" && it.parentEnvironment.fieldDefinition.type.name == "Foo"
122+
})
123+
124+
where:
125+
order | visitFn
126+
'postOrder' | 'visitPostOrder'
127+
'preOrder' | 'visitPreOrder'
128+
129+
}
130+
131+
def "works for subscriptions()"() {
132+
given:
133+
def schema = TestUtil.schema("""
134+
type Query {
135+
a: String
136+
}
137+
type Subscription{
138+
foo: Foo
139+
bar: String
140+
}
141+
type Foo {
142+
subFoo: String
143+
}
144+
schema {subscription: Subscription, query: Query}
145+
""")
146+
def visitor = Mock(QueryVisitor)
147+
def query = createQuery("""
148+
subscription S{bar foo { subFoo} }
149+
""")
150+
QueryTraversal queryTraversal = createQueryTraversal(query, schema)
151+
when:
152+
queryTraversal."$visitFn"(visitor)
153+
154+
then:
155+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "foo" && it.fieldDefinition.type.name == "Foo" && it.parentType.name == "Subscription" })
156+
1 * visitor.visitField({ QueryVisitorEnvironment it -> it.field.name == "bar" && it.fieldDefinition.type.name == "String" && it.parentType.name == "Subscription" })
157+
1 * visitor.visitField({ QueryVisitorEnvironment it ->
158+
it.field.name == "subFoo" && it.fieldDefinition.type.name == "String" &&
159+
it.parentType.name == "Foo" &&
160+
it.parentEnvironment.field.name == "foo" && it.parentEnvironment.fieldDefinition.type.name == "Foo"
161+
})
162+
163+
where:
164+
order | visitFn
165+
'postOrder' | 'visitPostOrder'
166+
'preOrder' | 'visitPreOrder'
167+
168+
}
92169
93170
@Unroll
94171
def "field with arguments: (#order)"() {

0 commit comments

Comments
 (0)