Skip to content

Commit 1e12da9

Browse files
authored
1099 mutation batching (graphql-java#1102)
* Added @OverRide as part of errorprone code health check * Revert "Added @OverRide as part of errorprone code health check" This reverts commit 38dfab1 * Made mutations not use data loader so they don't lock up
1 parent 639aff3 commit 1e12da9

3 files changed

Lines changed: 232 additions & 2 deletions

File tree

src/main/java/graphql/execution/instrumentation/dataloader/DataLoaderDispatcherInstrumentation.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import graphql.ExecutionResult;
44
import graphql.ExecutionResultImpl;
55
import graphql.execution.AsyncExecutionStrategy;
6+
import graphql.execution.ExecutionContext;
67
import graphql.execution.ExecutionStrategy;
78
import graphql.execution.instrumentation.DeferredFieldInstrumentationContext;
89
import graphql.execution.instrumentation.ExecutionStrategyInstrumentationContext;
@@ -15,6 +16,7 @@
1516
import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters;
1617
import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters;
1718
import graphql.execution.instrumentation.parameters.InstrumentationFieldFetchParameters;
19+
import graphql.language.OperationDefinition;
1820
import graphql.schema.DataFetcher;
1921
import org.dataloader.DataLoader;
2022
import org.dataloader.DataLoaderRegistry;
@@ -99,14 +101,27 @@ private void immediatelyDispatch() {
99101

100102
@Override
101103
public InstrumentationContext<ExecutionResult> beginExecuteOperation(InstrumentationExecuteOperationParameters parameters) {
102-
ExecutionStrategy queryStrategy = parameters.getExecutionContext().getQueryStrategy();
103-
if (!(queryStrategy instanceof AsyncExecutionStrategy)) {
104+
if (!isDataLoaderCompatibleExecution(parameters.getExecutionContext())) {
104105
DataLoaderDispatcherInstrumentationState state = parameters.getInstrumentationState();
105106
state.setAggressivelyBatching(false);
106107
}
107108
return new SimpleInstrumentationContext<>();
108109
}
109110

111+
private boolean isDataLoaderCompatibleExecution(ExecutionContext executionContext) {
112+
//
113+
// currently we only support Query operations and ONLY with AsyncExecutionStrategy as the query ES
114+
// This may change in the future but this is the fix for now
115+
//
116+
if (executionContext.getOperationDefinition().getOperation() == OperationDefinition.Operation.QUERY) {
117+
ExecutionStrategy queryStrategy = executionContext.getQueryStrategy();
118+
if (queryStrategy instanceof AsyncExecutionStrategy) {
119+
return true;
120+
}
121+
}
122+
return false;
123+
}
124+
110125
@Override
111126
public ExecutionStrategyInstrumentationContext beginExecutionStrategy(InstrumentationExecutionStrategyParameters parameters) {
112127
return fieldLevelTrackingApproach.beginExecutionStrategy(parameters);
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package graphql.execution.instrumentation.dataloader;
2+
3+
4+
import org.dataloader.DataLoader;
5+
6+
import java.util.ArrayList;
7+
import java.util.Collections;
8+
import java.util.List;
9+
import java.util.UUID;
10+
import java.util.concurrent.CompletableFuture;
11+
import java.util.concurrent.ConcurrentHashMap;
12+
import java.util.concurrent.ConcurrentMap;
13+
import java.util.stream.Collectors;
14+
15+
import static java.util.Objects.requireNonNull;
16+
17+
public class DataLoaderCompanyProductBackend {
18+
19+
private final ConcurrentMap<UUID, Company> companies = new ConcurrentHashMap<>();
20+
private final ConcurrentMap<UUID, Project> projects = new ConcurrentHashMap<>();
21+
22+
private final DataLoader<UUID, List<Project>> projectsLoader;
23+
24+
public DataLoaderCompanyProductBackend(int companyCount, int projectCount) {
25+
for (int i = 0; i < companyCount; i++) {
26+
mkCompany(projectCount);
27+
}
28+
29+
projectsLoader = new DataLoader<>(keys -> getProjectsForCompanies(keys).thenApply(projects -> keys
30+
.stream()
31+
.map(companyId -> projects.stream()
32+
.filter(project -> project.getCompanyId().equals(companyId))
33+
.collect(Collectors.toList()))
34+
.collect(Collectors.toList())));
35+
36+
}
37+
38+
private Company mkCompany(int projectCount) {
39+
Company company = new Company();
40+
companies.put(company.getId(), company);
41+
for (int j = 0; j < projectCount; j++) {
42+
Project project = new Project(company.getId());
43+
projects.put(project.getId(), project);
44+
}
45+
return company;
46+
}
47+
48+
public DataLoader<UUID, List<Project>> getProjectsLoader() {
49+
return projectsLoader;
50+
}
51+
52+
public CompletableFuture<List<Company>> getCompanies() {
53+
return CompletableFuture.supplyAsync(this::companiesList);
54+
}
55+
56+
private List<Company> companiesList() {
57+
return Collections.unmodifiableList(new ArrayList<>(companies.values()));
58+
}
59+
60+
public CompletableFuture<List<Project>> getProjectsForCompanies(List<UUID> companyIds) {
61+
return CompletableFuture.supplyAsync(() -> projects.values().stream()
62+
.filter(project -> companyIds.contains(project.getCompanyId()))
63+
.collect(Collectors.collectingAndThen(Collectors.toList(), Collections::unmodifiableList)));
64+
}
65+
66+
public CompletableFuture<Company> addCompany() {
67+
return CompletableFuture.supplyAsync(() -> mkCompany(3));
68+
}
69+
70+
public static class Company {
71+
72+
private final UUID id;
73+
private final String name;
74+
75+
public Company() {
76+
id = UUID.randomUUID();
77+
name = "Company " + id.toString().substring(0, 8);
78+
}
79+
80+
public UUID getId() {
81+
return id;
82+
}
83+
84+
public String getName() {
85+
return name;
86+
}
87+
88+
}
89+
90+
public static class Project {
91+
92+
private final UUID id;
93+
private final String name;
94+
private final UUID companyId;
95+
96+
public Project(UUID companyId) {
97+
id = UUID.randomUUID();
98+
name = "Project " + id.toString().substring(0, 8);
99+
this.companyId = requireNonNull(companyId);
100+
}
101+
102+
public UUID getId() {
103+
return id;
104+
}
105+
106+
public String getName() {
107+
return name;
108+
}
109+
110+
public UUID getCompanyId() {
111+
return companyId;
112+
}
113+
114+
}
115+
116+
}
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package graphql.execution.instrumentation.dataloader
2+
3+
import graphql.ExecutionInput
4+
import graphql.ExecutionResult
5+
import graphql.GraphQL
6+
import graphql.TestUtil
7+
import graphql.execution.AsyncExecutionStrategy
8+
import graphql.execution.AsyncSerialExecutionStrategy
9+
import org.dataloader.DataLoaderRegistry
10+
import spock.lang.Specification
11+
import spock.lang.Unroll
12+
13+
import java.util.concurrent.TimeUnit
14+
15+
import static graphql.schema.idl.RuntimeWiring.newRuntimeWiring
16+
import static graphql.schema.idl.TypeRuntimeWiring.newTypeWiring
17+
18+
class DataLoaderCompanyProductMutationTest extends Specification {
19+
20+
@Unroll
21+
def "bug #1099 test mutation completes as expected and does not hang - running #note"() {
22+
23+
DataLoaderCompanyProductBackend backend = new DataLoaderCompanyProductBackend(3, 5)
24+
25+
def spec = '''
26+
27+
type Project {
28+
id : ID!
29+
name : String!
30+
}
31+
32+
type Company {
33+
id : ID!
34+
name : String!
35+
projects : [Project!]
36+
}
37+
38+
type Query {
39+
companies : [Company!]
40+
}
41+
42+
type Mutation {
43+
addCompany : Company
44+
}
45+
'''
46+
47+
def wiring = newRuntimeWiring()
48+
.type(
49+
newTypeWiring("Company").dataFetcher("projects", {
50+
environment ->
51+
DataLoaderCompanyProductBackend.Company source = environment.getSource()
52+
return backend.getProjectsLoader().load(source.getId())
53+
}))
54+
.type(
55+
newTypeWiring("Query").dataFetcher("companies", {
56+
environment -> backend.getCompanies()
57+
}))
58+
.type(
59+
newTypeWiring("Mutation").dataFetcher("addCompany", {
60+
environment -> backend.addCompany()
61+
}))
62+
.build()
63+
64+
def schema = TestUtil.schema(spec, wiring)
65+
def registry = new DataLoaderRegistry()
66+
registry.register("projects-dl", backend.getProjectsLoader())
67+
68+
def graphQL = GraphQL.newGraphQL(schema)
69+
.queryExecutionStrategy(queryES)
70+
.mutationExecutionStrategy(mutationES)
71+
.instrumentation(new DataLoaderDispatcherInstrumentation(registry))
72+
.build()
73+
74+
ExecutionInput executionInput = ExecutionInput.newExecutionInput()
75+
.query(query)
76+
.build()
77+
78+
when:
79+
80+
ExecutionResult result = graphQL.executeAsync(executionInput).get(5, TimeUnit.SECONDS)
81+
82+
then:
83+
84+
result != null
85+
result.errors.isEmpty()
86+
result.data != null
87+
88+
where:
89+
90+
note | query | queryES | mutationES
91+
"mutation - spec compliant" | "mutation { addCompany { name projects { name }}}" | new AsyncExecutionStrategy() | new AsyncSerialExecutionStrategy()
92+
"mutation - all serial" | "mutation { addCompany { name projects { name }}}" | new AsyncSerialExecutionStrategy() | new AsyncSerialExecutionStrategy()
93+
"mutation - non spec compliant" | "mutation { addCompany { name projects { name }}}" | new AsyncExecutionStrategy() | new AsyncExecutionStrategy()
94+
95+
"query - spec compliant" | "query {companies { name projects { name }}}" | new AsyncExecutionStrategy() | new AsyncSerialExecutionStrategy()
96+
"query - all serial" | "query {companies { name projects { name }}}" | new AsyncSerialExecutionStrategy() | new AsyncSerialExecutionStrategy()
97+
"query - non spec compliant" | "query {companies { name projects { name }}}" | new AsyncExecutionStrategy() | new AsyncExecutionStrategy()
98+
}
99+
}

0 commit comments

Comments
 (0)