Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/main/java/graphql/VisibleForTesting.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import static java.lang.annotation.ElementType.CONSTRUCTOR;
import static java.lang.annotation.ElementType.FIELD;
import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.ElementType.TYPE;

/**
* Marks fields, methods etc as more visible than actually needed for testing purposes.
* Marks fields, methods, types etc as more visible than actually needed for testing purposes.
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(value = {CONSTRUCTOR, METHOD, FIELD})
@Target(value = {CONSTRUCTOR, METHOD, FIELD, TYPE})
@Internal
public @interface VisibleForTesting {
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import graphql.execution.incremental.AlternativeCallContext;
import org.dataloader.DataLoader;
import org.dataloader.DataLoaderRegistry;
import graphql.VisibleForTesting;
import org.jspecify.annotations.NullMarked;
import org.jspecify.annotations.Nullable;

Expand All @@ -31,7 +32,8 @@ public class ExhaustedDataLoaderDispatchStrategy implements DataLoaderDispatchSt
private final Map<AlternativeCallContext, CallStack> alternativeCallContextMap = new ConcurrentHashMap<>();


private static class CallStack {
@VisibleForTesting
static class CallStack {

// 30 bits for objectRunningCount
// 1 bit for dataLoaderToDispatch
Expand Down Expand Up @@ -127,7 +129,12 @@ public void clear() {
}

public ExhaustedDataLoaderDispatchStrategy(ExecutionContext executionContext) {
this.initialCallStack = new CallStack();
this(executionContext, new CallStack());
}

@VisibleForTesting
ExhaustedDataLoaderDispatchStrategy(ExecutionContext executionContext, CallStack callStack) {
this.initialCallStack = callStack;
this.executionContext = executionContext;

this.profiler = executionContext.getProfiler();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,4 +298,136 @@ class ExhaustedDataLoaderDispatchStrategyTest extends Specification {
// The deferred call stack dispatches independently
batchLoaderInvocations.get() == 1
}

/**
* A CallStack subclass that forces CAS failures to deterministically exercise
* the retry paths in dispatchImpl's CAS loop. Without this, CAS retries only
* happen under real thread contention, making coverage non-deterministic.
*
* Failures are targeted: only CAS attempts matching a specific state transition
* (identified by the newState pattern) are failed, so other CAS users like
* incrementObjectRunningCount/decrementObjectRunningCount are not affected.
*/
static class ContendedCallStack extends ExhaustedDataLoaderDispatchStrategy.CallStack {
// The newState value that should trigger a simulated CAS failure
volatile int failOnNewState = -1
final AtomicInteger casFailuresRemaining = new AtomicInteger(0)

@Override
boolean tryUpdateState(int oldState, int newState) {
if (newState == failOnNewState && casFailuresRemaining.getAndDecrement() > 0) {
return false
}
return super.tryUpdateState(oldState, newState)
}
}

private void setupStrategyWithCallStack(BatchLoader<String, String> batchLoader, ExhaustedDataLoaderDispatchStrategy.CallStack callStack) {
dataLoaderRegistry = new DataLoaderRegistry()
def dataLoader = DataLoaderFactory.newDataLoader(batchLoader)
dataLoaderRegistry.register("testLoader", dataLoader)

def executionInput = ExecutionInput.newExecutionInput()
.query("{ dummy }")
.build()
def engineRunningState = new EngineRunningState(executionInput, Profiler.NO_OP)

def executionStrategy = new AsyncExecutionStrategy()
executionContext = new ExecutionContextBuilder()
.executionId(ExecutionId.generate())
.graphQLSchema(GraphQLSchema.newSchema().query(
graphql.schema.GraphQLObjectType.newObject()
.name("Query")
.field({ f -> f.name("dummy").type(GraphQLString) })
.build()
).build())
.queryStrategy(executionStrategy)
.mutationStrategy(executionStrategy)
.subscriptionStrategy(executionStrategy)
.graphQLContext(GraphQLContext.newContext().build())
.coercedVariables(CoercedVariables.emptyVariables())
.dataLoaderRegistry(dataLoaderRegistry)
.executionInput(executionInput)
.profiler(Profiler.NO_OP)
.engineRunningState(engineRunningState)
.build()

strategy = new ExhaustedDataLoaderDispatchStrategy(executionContext, callStack)

rootParams = newParameters()
.executionStepInfo(newExecutionStepInfo()
.type(GraphQLString)
.path(ResultPath.rootPath())
.build())
.source(new Object())
.fields(graphql.execution.MergedSelectionSet.newMergedSelectionSet().build())
.nonNullFieldValidator(new NonNullableFieldValidator(executionContext))
.build()
}

def "CAS retry in dispatchImpl dispatch path is exercised under contention"() {
given:
def contendedCallStack = new ContendedCallStack()
setupStrategyWithCallStack(simpleBatchLoader(), contendedCallStack)
dataLoaderRegistry.getDataLoader("testLoader").load("key1")

when:
strategy.executionStrategy(executionContext, rootParams, 1)
strategy.newDataLoaderInvocation(null)
// The dispatch-setup CAS in dispatchImpl sets currentlyDispatching=true and
// dataLoaderToDispatch=false. With objectRunningCount=0, the target newState is:
// currentlyDispatching(bit0)=1, dataLoaderToDispatch(bit1)=0, objectRunningCount(bits2+)=0
// = 0b01 = 1
contendedCallStack.failOnNewState = ExhaustedDataLoaderDispatchStrategy.CallStack.setCurrentlyDispatching(
ExhaustedDataLoaderDispatchStrategy.CallStack.setDataLoaderToDispatch(0, false), true)
contendedCallStack.casFailuresRemaining.set(1)
strategy.finishedFetching(executionContext, rootParams)

Thread.sleep(200)

then:
batchLoaderInvocations.get() == 1
}

def "CAS retry in dispatchImpl early-exit path is exercised under contention"() {
given:
def doneLatch = new CountDownLatch(1)
AtomicInteger roundCount = new AtomicInteger()
def contendedCallStack = new ContendedCallStack()

def chainedBatchLoader = new BatchLoader<String, String>() {
@Override
CompletionStage<List<String>> load(List<String> keys) {
int round = roundCount.incrementAndGet()
if (round == 1) {
// During first batch, load another key to trigger second dispatch round
dataLoaderRegistry.getDataLoader("testLoader").load("key2")
strategy.newDataLoaderInvocation(null)
}
if (round == 2) {
// Inject a CAS failure targeting the early-exit path. After round 2
// completes, the recursive dispatchImpl sees dataLoaderToDispatch=false
// and tries to set currentlyDispatching=false. The target newState is 0
// (all bits cleared: no dispatching, no dataLoader pending, objectRunning=0).
contendedCallStack.failOnNewState = ExhaustedDataLoaderDispatchStrategy.CallStack.setCurrentlyDispatching(0, false)
contendedCallStack.casFailuresRemaining.set(1)
doneLatch.countDown()
}
return CompletableFuture.completedFuture(keys)
}
}
setupStrategyWithCallStack(chainedBatchLoader, contendedCallStack)
dataLoaderRegistry.getDataLoader("testLoader").load("key1")

when:
strategy.executionStrategy(executionContext, rootParams, 1)
strategy.newDataLoaderInvocation(null)
strategy.finishedFetching(executionContext, rootParams)

def completed = doneLatch.await(2, TimeUnit.SECONDS)

then:
completed
roundCount.get() == 2
}
}