Skip to content

Commit dba84db

Browse files
committed
Breaking out query complexity into its own class
1 parent 5b19375 commit dba84db

3 files changed

Lines changed: 198 additions & 49 deletions

File tree

src/main/java/graphql/analysis/MaxQueryComplexityInstrumentation.java

Lines changed: 12 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@
1515
import org.slf4j.Logger;
1616
import org.slf4j.LoggerFactory;
1717

18-
import java.util.LinkedHashMap;
1918
import java.util.List;
20-
import java.util.Map;
2119
import java.util.concurrent.atomic.AtomicReference;
2220
import java.util.function.Function;
2321

2422
import static graphql.Assert.assertNotNull;
2523
import static graphql.execution.instrumentation.InstrumentationState.ofState;
2624
import static graphql.execution.instrumentation.SimpleInstrumentationContext.noOp;
27-
import static java.util.Optional.ofNullable;
2825

2926
/**
3027
* Prevents execution if the query complexity is greater than the specified maxComplexity.
@@ -101,21 +98,8 @@ public InstrumentationState createState(InstrumentationCreateStateParameters par
10198
@Override
10299
public @Nullable InstrumentationContext<ExecutionResult> beginExecuteOperation(InstrumentationExecuteOperationParameters instrumentationExecuteOperationParameters, InstrumentationState rawState) {
103100
State state = ofState(rawState);
104-
QueryTraverser queryTraverser = newQueryTraverser(instrumentationExecuteOperationParameters.getExecutionContext());
105-
106-
Map<QueryVisitorFieldEnvironment, Integer> valuesByParent = new LinkedHashMap<>();
107-
queryTraverser.visitPostOrder(new QueryVisitorStub() {
108-
@Override
109-
public void visitField(QueryVisitorFieldEnvironment env) {
110-
int childComplexity = valuesByParent.getOrDefault(env, 0);
111-
int value = calculateComplexity(env, childComplexity);
112-
113-
valuesByParent.compute(env.getParentEnvironment(), (key, oldValue) ->
114-
ofNullable(oldValue).orElse(0) + value
115-
);
116-
}
117-
});
118-
int totalComplexity = valuesByParent.getOrDefault(null, 0);
101+
QueryComplexityCalculator queryComplexityCalculator = newQueryComplexityCalculator(instrumentationExecuteOperationParameters.getExecutionContext());
102+
int totalComplexity = queryComplexityCalculator.calculate();
119103
if (log.isDebugEnabled()) {
120104
log.debug("Query complexity: {}", totalComplexity);
121105
}
@@ -133,6 +117,16 @@ public void visitField(QueryVisitorFieldEnvironment env) {
133117
return noOp();
134118
}
135119

120+
private QueryComplexityCalculator newQueryComplexityCalculator(ExecutionContext executionContext) {
121+
return QueryComplexityCalculator.newCalculator()
122+
.fieldComplexityCalculator(fieldComplexityCalculator)
123+
.schema(executionContext.getGraphQLSchema())
124+
.document(executionContext.getDocument())
125+
.operationName(executionContext.getExecutionInput().getOperationName())
126+
.variables(executionContext.getCoercedVariables())
127+
.build();
128+
}
129+
136130
/**
137131
* Called to generate your own error message or custom exception class
138132
*
@@ -145,37 +139,6 @@ protected AbortExecutionException mkAbortException(int totalComplexity, int maxC
145139
return new AbortExecutionException("maximum query complexity exceeded " + totalComplexity + " > " + maxComplexity);
146140
}
147141

148-
QueryTraverser newQueryTraverser(ExecutionContext executionContext) {
149-
return QueryTraverser.newQueryTraverser()
150-
.schema(executionContext.getGraphQLSchema())
151-
.document(executionContext.getDocument())
152-
.operationName(executionContext.getExecutionInput().getOperationName())
153-
.coercedVariables(executionContext.getCoercedVariables())
154-
.build();
155-
}
156-
157-
private int calculateComplexity(QueryVisitorFieldEnvironment queryVisitorFieldEnvironment, int childComplexity) {
158-
if (queryVisitorFieldEnvironment.isTypeNameIntrospectionField()) {
159-
return 0;
160-
}
161-
FieldComplexityEnvironment fieldComplexityEnvironment = convertEnv(queryVisitorFieldEnvironment);
162-
return fieldComplexityCalculator.calculate(fieldComplexityEnvironment, childComplexity);
163-
}
164-
165-
private FieldComplexityEnvironment convertEnv(QueryVisitorFieldEnvironment queryVisitorFieldEnvironment) {
166-
FieldComplexityEnvironment parentEnv = null;
167-
if (queryVisitorFieldEnvironment.getParentEnvironment() != null) {
168-
parentEnv = convertEnv(queryVisitorFieldEnvironment.getParentEnvironment());
169-
}
170-
return new FieldComplexityEnvironment(
171-
queryVisitorFieldEnvironment.getField(),
172-
queryVisitorFieldEnvironment.getFieldDefinition(),
173-
queryVisitorFieldEnvironment.getFieldsContainer(),
174-
queryVisitorFieldEnvironment.getArguments(),
175-
parentEnv
176-
);
177-
}
178-
179142
private static class State implements InstrumentationState {
180143
AtomicReference<InstrumentationValidationParameters> instrumentationValidationParameters = new AtomicReference<>();
181144
}
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package graphql.analysis;
2+
3+
import graphql.PublicApi;
4+
import graphql.execution.CoercedVariables;
5+
import graphql.language.Document;
6+
import graphql.schema.GraphQLSchema;
7+
8+
import java.util.LinkedHashMap;
9+
import java.util.Map;
10+
11+
import static graphql.Assert.assertNotNull;
12+
import static java.util.Optional.ofNullable;
13+
14+
/**
15+
* This can calculate the complexity of an operation using the specified {@link FieldComplexityCalculator} you pass
16+
* into it.
17+
*/
18+
@PublicApi
19+
public class QueryComplexityCalculator {
20+
21+
private final FieldComplexityCalculator fieldComplexityCalculator;
22+
private final GraphQLSchema schema;
23+
private final Document document;
24+
private final String operationName;
25+
private final CoercedVariables variables;
26+
27+
public QueryComplexityCalculator(Builder builder) {
28+
this.fieldComplexityCalculator = assertNotNull(builder.fieldComplexityCalculator, () -> "fieldComplexityCalculator can't be null");
29+
this.schema = assertNotNull(builder.schema, () -> "schema can't be null");
30+
this.document = assertNotNull(builder.document, () -> "document can't be null");
31+
this.variables = assertNotNull(builder.variables, () -> "variables can't be null");
32+
this.operationName = builder.operationName;
33+
}
34+
35+
36+
public int calculate() {
37+
Map<QueryVisitorFieldEnvironment, Integer> valuesByParent = calculateByParents();
38+
return valuesByParent.getOrDefault(null, 0);
39+
}
40+
41+
/**
42+
* @return a map that shows the field complexity for each field level in the operation
43+
*/
44+
public Map<QueryVisitorFieldEnvironment, Integer> calculateByParents() {
45+
QueryTraverser queryTraverser = QueryTraverser.newQueryTraverser()
46+
.schema(this.schema)
47+
.document(this.document)
48+
.operationName(this.operationName)
49+
.coercedVariables(this.variables)
50+
.build();
51+
52+
53+
Map<QueryVisitorFieldEnvironment, Integer> valuesByParent = new LinkedHashMap<>();
54+
queryTraverser.visitPostOrder(new QueryVisitorStub() {
55+
@Override
56+
public void visitField(QueryVisitorFieldEnvironment env) {
57+
int childComplexity = valuesByParent.getOrDefault(env, 0);
58+
int value = calculateComplexity(env, childComplexity);
59+
60+
QueryVisitorFieldEnvironment parentEnvironment = env.getParentEnvironment();
61+
valuesByParent.compute(parentEnvironment, (key, oldValue) -> {
62+
Integer currentValue = ofNullable(oldValue).orElse(0);
63+
return currentValue + value;
64+
}
65+
);
66+
}
67+
});
68+
69+
return valuesByParent;
70+
}
71+
72+
private int calculateComplexity(QueryVisitorFieldEnvironment queryVisitorFieldEnvironment, int childComplexity) {
73+
if (queryVisitorFieldEnvironment.isTypeNameIntrospectionField()) {
74+
return 0;
75+
}
76+
FieldComplexityEnvironment fieldComplexityEnvironment = convertEnv(queryVisitorFieldEnvironment);
77+
return fieldComplexityCalculator.calculate(fieldComplexityEnvironment, childComplexity);
78+
}
79+
80+
private FieldComplexityEnvironment convertEnv(QueryVisitorFieldEnvironment queryVisitorFieldEnvironment) {
81+
FieldComplexityEnvironment parentEnv = null;
82+
if (queryVisitorFieldEnvironment.getParentEnvironment() != null) {
83+
parentEnv = convertEnv(queryVisitorFieldEnvironment.getParentEnvironment());
84+
}
85+
return new FieldComplexityEnvironment(
86+
queryVisitorFieldEnvironment.getField(),
87+
queryVisitorFieldEnvironment.getFieldDefinition(),
88+
queryVisitorFieldEnvironment.getFieldsContainer(),
89+
queryVisitorFieldEnvironment.getArguments(),
90+
parentEnv
91+
);
92+
}
93+
94+
public static Builder newCalculator() {
95+
return new Builder();
96+
}
97+
98+
public static class Builder {
99+
private FieldComplexityCalculator fieldComplexityCalculator;
100+
private GraphQLSchema schema;
101+
private Document document;
102+
private String operationName;
103+
private CoercedVariables variables = CoercedVariables.emptyVariables();
104+
105+
public Builder schema(GraphQLSchema graphQLSchema) {
106+
this.schema = graphQLSchema;
107+
return this;
108+
}
109+
110+
public Builder fieldComplexityCalculator(FieldComplexityCalculator complexityCalculator) {
111+
this.fieldComplexityCalculator = complexityCalculator;
112+
return this;
113+
}
114+
115+
public Builder document(Document document) {
116+
this.document = document;
117+
return this;
118+
}
119+
120+
public Builder operationName(String operationName) {
121+
this.operationName = operationName;
122+
return this;
123+
}
124+
125+
public Builder variables(CoercedVariables variables) {
126+
this.variables = variables;
127+
return this;
128+
}
129+
130+
public QueryComplexityCalculator build() {
131+
return new QueryComplexityCalculator(this);
132+
}
133+
}
134+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package graphql.analysis
2+
3+
4+
import graphql.TestUtil
5+
import graphql.execution.CoercedVariables
6+
import graphql.language.Document
7+
import graphql.parser.Parser
8+
import spock.lang.Specification
9+
10+
class QueryComplexityCalculatorTest extends Specification {
11+
12+
Document createQuery(String query) {
13+
Parser parser = new Parser()
14+
parser.parseDocument(query)
15+
}
16+
17+
def "can calculator complexity"() {
18+
given:
19+
def schema = TestUtil.schema("""
20+
type Query{
21+
foo: Foo
22+
bar: String
23+
}
24+
type Foo {
25+
scalar: String
26+
foo: Foo
27+
}
28+
""")
29+
def query = createQuery("""
30+
query q {
31+
f2: foo {scalar foo{scalar}}
32+
f1: foo { foo {foo {foo {foo{foo{scalar}}}}}} }
33+
""")
34+
35+
36+
when:
37+
FieldComplexityCalculator fieldComplexityCalculator = new FieldComplexityCalculator() {
38+
@Override
39+
int calculate(FieldComplexityEnvironment environment, int childComplexity) {
40+
return environment.getField().name.startsWith("foo") ? 10 : 1
41+
}
42+
}
43+
QueryComplexityCalculator calculator = QueryComplexityCalculator.newCalculator()
44+
.fieldComplexityCalculator(fieldComplexityCalculator).schema(schema).document(query).variables(CoercedVariables.emptyVariables())
45+
.build()
46+
def complexityScore = calculator.calculate()
47+
then:
48+
complexityScore == 20
49+
50+
51+
}
52+
}

0 commit comments

Comments
 (0)