Skip to content

Commit 33b59d6

Browse files
committed
write query directives from ENF to AST
1 parent a54bb43 commit 33b59d6

2 files changed

Lines changed: 93 additions & 7 deletions

File tree

src/main/java/graphql/normalized/ExecutableNormalizedOperationToAstCompiler.java

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import com.google.common.collect.ImmutableMap;
55
import graphql.Assert;
66
import graphql.PublicApi;
7+
import graphql.execution.directives.QueryDirectives;
78
import graphql.introspection.Introspection;
89
import graphql.language.Argument;
910
import graphql.language.ArrayValue;
11+
import graphql.language.Directive;
1012
import graphql.language.Document;
1113
import graphql.language.Field;
1214
import graphql.language.InlineFragment;
@@ -30,6 +32,7 @@
3032
import java.util.LinkedHashMap;
3133
import java.util.List;
3234
import java.util.Map;
35+
import java.util.stream.Collectors;
3336

3437
import static graphql.collect.ImmutableKit.emptyList;
3538
import static graphql.collect.ImmutableKit.map;
@@ -96,10 +99,20 @@ public static CompilerResult compileToDocument(@NotNull GraphQLSchema schema,
9699
@Nullable String operationName,
97100
@NotNull List<ExecutableNormalizedField> topLevelFields,
98101
@Nullable VariablePredicate variablePredicate) {
102+
return compileToDocument(schema,operationKind,operationName,topLevelFields,Map.of(),variablePredicate);
103+
}
104+
105+
106+
public static CompilerResult compileToDocument(@NotNull GraphQLSchema schema,
107+
@NotNull OperationDefinition.Operation operationKind,
108+
@Nullable String operationName,
109+
@NotNull List<ExecutableNormalizedField> topLevelFields,
110+
@NotNull Map<ExecutableNormalizedField, QueryDirectives> fieldToQueryDirectives,
111+
@Nullable VariablePredicate variablePredicate) {
99112
GraphQLObjectType operationType = getOperationType(schema, operationKind);
100113

101114
VariableAccumulator variableAccumulator = new VariableAccumulator(variablePredicate);
102-
List<Selection<?>> selections = subselectionsForNormalizedField(schema, operationType.getName(), topLevelFields, variableAccumulator);
115+
List<Selection<?>> selections = subselectionsForNormalizedField(schema, operationType.getName(), topLevelFields, fieldToQueryDirectives, variableAccumulator);
103116
SelectionSet selectionSet = new SelectionSet(selections);
104117

105118
OperationDefinition.Builder definitionBuilder = OperationDefinition.newOperationDefinition()
@@ -120,6 +133,7 @@ public static CompilerResult compileToDocument(@NotNull GraphQLSchema schema,
120133
private static List<Selection<?>> subselectionsForNormalizedField(GraphQLSchema schema,
121134
@NotNull String parentOutputType,
122135
List<ExecutableNormalizedField> executableNormalizedFields,
136+
@NotNull Map<ExecutableNormalizedField, QueryDirectives> normalizedFieldToQueryDirectives,
123137
VariableAccumulator variableAccumulator) {
124138
ImmutableList.Builder<Selection<?>> selections = ImmutableList.builder();
125139

@@ -129,13 +143,13 @@ private static List<Selection<?>> subselectionsForNormalizedField(GraphQLSchema
129143

130144
for (ExecutableNormalizedField nf : executableNormalizedFields) {
131145
if (nf.isConditional(schema)) {
132-
selectionForNormalizedField(schema, nf, variableAccumulator)
146+
selectionForNormalizedField(schema, nf, normalizedFieldToQueryDirectives, variableAccumulator)
133147
.forEach((objectTypeName, field) ->
134148
fieldsByTypeCondition
135149
.computeIfAbsent(objectTypeName, ignored -> new ArrayList<>())
136150
.add(field));
137151
} else {
138-
selections.add(selectionForNormalizedField(schema, parentOutputType, nf, variableAccumulator));
152+
selections.add(selectionForNormalizedField(schema, parentOutputType, nf, normalizedFieldToQueryDirectives,variableAccumulator));
139153
}
140154
}
141155

@@ -156,11 +170,12 @@ private static List<Selection<?>> subselectionsForNormalizedField(GraphQLSchema
156170
*/
157171
private static Map<String, Field> selectionForNormalizedField(GraphQLSchema schema,
158172
ExecutableNormalizedField executableNormalizedField,
173+
@NotNull Map<ExecutableNormalizedField, QueryDirectives> normalizedFieldToQueryDirectives,
159174
VariableAccumulator variableAccumulator) {
160175
Map<String, Field> groupedFields = new LinkedHashMap<>();
161176

162177
for (String objectTypeName : executableNormalizedField.getObjectTypeNames()) {
163-
groupedFields.put(objectTypeName, selectionForNormalizedField(schema, objectTypeName, executableNormalizedField, variableAccumulator));
178+
groupedFields.put(objectTypeName, selectionForNormalizedField(schema, objectTypeName, executableNormalizedField,normalizedFieldToQueryDirectives, variableAccumulator));
164179
}
165180

166181
return groupedFields;
@@ -172,6 +187,7 @@ private static Map<String, Field> selectionForNormalizedField(GraphQLSchema sche
172187
private static Field selectionForNormalizedField(GraphQLSchema schema,
173188
String objectTypeName,
174189
ExecutableNormalizedField executableNormalizedField,
190+
@NotNull Map<ExecutableNormalizedField, QueryDirectives> normalizedFieldToQueryDirectives,
175191
VariableAccumulator variableAccumulator) {
176192
final List<Selection<?>> subSelections;
177193
if (executableNormalizedField.getChildren().isEmpty()) {
@@ -184,19 +200,30 @@ private static Field selectionForNormalizedField(GraphQLSchema schema,
184200
schema,
185201
fieldOutputType.getName(),
186202
executableNormalizedField.getChildren(),
203+
normalizedFieldToQueryDirectives,
187204
variableAccumulator
188205
);
189206
}
190207

191208
SelectionSet selectionSet = selectionSetOrNullIfEmpty(subSelections);
192209
List<Argument> arguments = createArguments(executableNormalizedField, variableAccumulator);
193210

194-
return newField()
211+
QueryDirectives queryDirectives = normalizedFieldToQueryDirectives.get(executableNormalizedField);
212+
213+
214+
Field.Builder builder = newField()
195215
.name(executableNormalizedField.getFieldName())
196216
.alias(executableNormalizedField.getAlias())
197217
.selectionSet(selectionSet)
198-
.arguments(arguments)
199-
.build();
218+
.arguments(arguments);
219+
if(queryDirectives == null || queryDirectives.getImmediateAppliedDirectivesByField().isEmpty() ){
220+
return builder.build();
221+
}else {
222+
List<Directive> directives = queryDirectives.getImmediateAppliedDirectivesByField().keySet().stream().flatMap(field -> field.getDirectives().stream()).collect(Collectors.toList());
223+
return builder
224+
.directives(directives)
225+
.build();
226+
}
200227
}
201228

202229
@Nullable

src/test/groovy/graphql/normalized/ExecutableNormalizedOperationToAstCompilerTest.groovy

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@ import graphql.GraphQL
44
import graphql.TestUtil
55
import graphql.execution.RawVariables
66
import graphql.language.AstPrinter
7+
import graphql.language.Field
8+
import graphql.language.OperationDefinition
79
import graphql.language.AstSorter
810
import graphql.language.Document
911
import graphql.language.IntValue
1012
import graphql.language.StringValue
13+
import graphql.parser.Parser
1114
import graphql.schema.GraphQLSchema
1215
import graphql.schema.idl.RuntimeWiring
1316
import graphql.schema.idl.TestLiveMockedWiringFactory
@@ -1238,6 +1241,62 @@ class ExecutableNormalizedOperationToAstCompilerTest extends Specification {
12381241
'''
12391242
}
12401243

1244+
1245+
1246+
def "test query directive"() {
1247+
def sdl = '''
1248+
type Query {
1249+
foo1(arg: I): String
1250+
1251+
}
1252+
type Subscription {
1253+
foo1(arg: I): DevOps
1254+
1255+
}
1256+
input I {
1257+
arg1: String
1258+
}
1259+
1260+
type DevOps{
1261+
name: String
1262+
}
1263+
1264+
directive @optIn(to : [String!]!) repeatable on FIELD
1265+
'''
1266+
def query = '''subscription {
1267+
foo1 (arg: {
1268+
arg1: "Subscription"
1269+
}) @optIn(to: "foo") {
1270+
name @optIn(to: "devOps")
1271+
}
1272+
1273+
1274+
}
1275+
'''
1276+
GraphQLSchema schema = mkSchema(sdl)
1277+
Document document = new Parser().parse(query)
1278+
ExecutableNormalizedOperation eno = ExecutableNormalizedOperationFactory.createExecutableNormalizedOperationWithRawVariables(schema,document, null,RawVariables.emptyVariables())
1279+
1280+
1281+
when:
1282+
def result = compileToDocument(schema, SUBSCRIPTION, null, eno.topLevelFields, eno.normalizedFieldToQueryDirectives, noVariables)
1283+
OperationDefinition operationDefinition = result.document.getDefinitionsOfType(OperationDefinition.class)[0]
1284+
def fooField = (Field)operationDefinition.selectionSet.children[0]
1285+
def nameField = (Field)fooField.selectionSet.children[0]
1286+
def documentPrinted = AstPrinter.printAst(new AstSorter().sort(result.document))
1287+
1288+
then:
1289+
1290+
fooField.directives.size() == 1
1291+
nameField.directives.size() == 1
1292+
documentPrinted == '''subscription {
1293+
foo1(arg: {arg1 : "Subscription"}) @optIn(to: "foo") {
1294+
name @optIn(to: "devOps")
1295+
}
1296+
}
1297+
'''
1298+
}
1299+
12411300
def "test redundant inline fragments specified in original query"() {
12421301
def sdl = '''
12431302
type Query {

0 commit comments

Comments
 (0)