diff --git a/src/main/java/graphql/VisibleForTesting.java b/src/main/java/graphql/VisibleForTesting.java index 864fa70d1..e8097a97a 100644 --- a/src/main/java/graphql/VisibleForTesting.java +++ b/src/main/java/graphql/VisibleForTesting.java @@ -7,12 +7,13 @@ import static java.lang.annotation.ElementType.CONSTRUCTOR; import static java.lang.annotation.ElementType.FIELD; import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.TYPE; /** - * Marks fields, methods etc as more visible than actually needed for testing purposes. + * Marks fields, methods, types etc as more visible than actually needed for testing purposes. */ @Retention(RetentionPolicy.RUNTIME) -@Target(value = {CONSTRUCTOR, METHOD, FIELD}) +@Target(value = {CONSTRUCTOR, METHOD, FIELD, TYPE}) @Internal public @interface VisibleForTesting { } diff --git a/src/main/java/graphql/execution/instrumentation/dataloader/ExhaustedDataLoaderDispatchStrategy.java b/src/main/java/graphql/execution/instrumentation/dataloader/ExhaustedDataLoaderDispatchStrategy.java index f237622ec..d7c7669f3 100644 --- a/src/main/java/graphql/execution/instrumentation/dataloader/ExhaustedDataLoaderDispatchStrategy.java +++ b/src/main/java/graphql/execution/instrumentation/dataloader/ExhaustedDataLoaderDispatchStrategy.java @@ -9,6 +9,7 @@ import graphql.execution.incremental.AlternativeCallContext; import org.dataloader.DataLoader; import org.dataloader.DataLoaderRegistry; +import graphql.VisibleForTesting; import org.jspecify.annotations.NullMarked; import org.jspecify.annotations.Nullable; @@ -31,7 +32,8 @@ public class ExhaustedDataLoaderDispatchStrategy implements DataLoaderDispatchSt private final Map alternativeCallContextMap = new ConcurrentHashMap<>(); - private static class CallStack { + @VisibleForTesting + static class CallStack { // 30 bits for objectRunningCount // 1 bit for dataLoaderToDispatch @@ -127,7 +129,12 @@ public void clear() { } public ExhaustedDataLoaderDispatchStrategy(ExecutionContext executionContext) { - this.initialCallStack = new CallStack(); + this(executionContext, new CallStack()); + } + + @VisibleForTesting + ExhaustedDataLoaderDispatchStrategy(ExecutionContext executionContext, CallStack callStack) { + this.initialCallStack = callStack; this.executionContext = executionContext; this.profiler = executionContext.getProfiler(); diff --git a/src/test/groovy/graphql/execution/instrumentation/dataloader/ExhaustedDataLoaderDispatchStrategyTest.groovy b/src/test/groovy/graphql/execution/instrumentation/dataloader/ExhaustedDataLoaderDispatchStrategyTest.groovy index f8413d64e..dd61b3907 100644 --- a/src/test/groovy/graphql/execution/instrumentation/dataloader/ExhaustedDataLoaderDispatchStrategyTest.groovy +++ b/src/test/groovy/graphql/execution/instrumentation/dataloader/ExhaustedDataLoaderDispatchStrategyTest.groovy @@ -298,4 +298,136 @@ class ExhaustedDataLoaderDispatchStrategyTest extends Specification { // The deferred call stack dispatches independently batchLoaderInvocations.get() == 1 } + + /** + * A CallStack subclass that forces CAS failures to deterministically exercise + * the retry paths in dispatchImpl's CAS loop. Without this, CAS retries only + * happen under real thread contention, making coverage non-deterministic. + * + * Failures are targeted: only CAS attempts matching a specific state transition + * (identified by the newState pattern) are failed, so other CAS users like + * incrementObjectRunningCount/decrementObjectRunningCount are not affected. + */ + static class ContendedCallStack extends ExhaustedDataLoaderDispatchStrategy.CallStack { + // The newState value that should trigger a simulated CAS failure + volatile int failOnNewState = -1 + final AtomicInteger casFailuresRemaining = new AtomicInteger(0) + + @Override + boolean tryUpdateState(int oldState, int newState) { + if (newState == failOnNewState && casFailuresRemaining.getAndDecrement() > 0) { + return false + } + return super.tryUpdateState(oldState, newState) + } + } + + private void setupStrategyWithCallStack(BatchLoader batchLoader, ExhaustedDataLoaderDispatchStrategy.CallStack callStack) { + dataLoaderRegistry = new DataLoaderRegistry() + def dataLoader = DataLoaderFactory.newDataLoader(batchLoader) + dataLoaderRegistry.register("testLoader", dataLoader) + + def executionInput = ExecutionInput.newExecutionInput() + .query("{ dummy }") + .build() + def engineRunningState = new EngineRunningState(executionInput, Profiler.NO_OP) + + def executionStrategy = new AsyncExecutionStrategy() + executionContext = new ExecutionContextBuilder() + .executionId(ExecutionId.generate()) + .graphQLSchema(GraphQLSchema.newSchema().query( + graphql.schema.GraphQLObjectType.newObject() + .name("Query") + .field({ f -> f.name("dummy").type(GraphQLString) }) + .build() + ).build()) + .queryStrategy(executionStrategy) + .mutationStrategy(executionStrategy) + .subscriptionStrategy(executionStrategy) + .graphQLContext(GraphQLContext.newContext().build()) + .coercedVariables(CoercedVariables.emptyVariables()) + .dataLoaderRegistry(dataLoaderRegistry) + .executionInput(executionInput) + .profiler(Profiler.NO_OP) + .engineRunningState(engineRunningState) + .build() + + strategy = new ExhaustedDataLoaderDispatchStrategy(executionContext, callStack) + + rootParams = newParameters() + .executionStepInfo(newExecutionStepInfo() + .type(GraphQLString) + .path(ResultPath.rootPath()) + .build()) + .source(new Object()) + .fields(graphql.execution.MergedSelectionSet.newMergedSelectionSet().build()) + .nonNullFieldValidator(new NonNullableFieldValidator(executionContext)) + .build() + } + + def "CAS retry in dispatchImpl dispatch path is exercised under contention"() { + given: + def contendedCallStack = new ContendedCallStack() + setupStrategyWithCallStack(simpleBatchLoader(), contendedCallStack) + dataLoaderRegistry.getDataLoader("testLoader").load("key1") + + when: + strategy.executionStrategy(executionContext, rootParams, 1) + strategy.newDataLoaderInvocation(null) + // The dispatch-setup CAS in dispatchImpl sets currentlyDispatching=true and + // dataLoaderToDispatch=false. With objectRunningCount=0, the target newState is: + // currentlyDispatching(bit0)=1, dataLoaderToDispatch(bit1)=0, objectRunningCount(bits2+)=0 + // = 0b01 = 1 + contendedCallStack.failOnNewState = ExhaustedDataLoaderDispatchStrategy.CallStack.setCurrentlyDispatching( + ExhaustedDataLoaderDispatchStrategy.CallStack.setDataLoaderToDispatch(0, false), true) + contendedCallStack.casFailuresRemaining.set(1) + strategy.finishedFetching(executionContext, rootParams) + + Thread.sleep(200) + + then: + batchLoaderInvocations.get() == 1 + } + + def "CAS retry in dispatchImpl early-exit path is exercised under contention"() { + given: + def doneLatch = new CountDownLatch(1) + AtomicInteger roundCount = new AtomicInteger() + def contendedCallStack = new ContendedCallStack() + + def chainedBatchLoader = new BatchLoader() { + @Override + CompletionStage> load(List keys) { + int round = roundCount.incrementAndGet() + if (round == 1) { + // During first batch, load another key to trigger second dispatch round + dataLoaderRegistry.getDataLoader("testLoader").load("key2") + strategy.newDataLoaderInvocation(null) + } + if (round == 2) { + // Inject a CAS failure targeting the early-exit path. After round 2 + // completes, the recursive dispatchImpl sees dataLoaderToDispatch=false + // and tries to set currentlyDispatching=false. The target newState is 0 + // (all bits cleared: no dispatching, no dataLoader pending, objectRunning=0). + contendedCallStack.failOnNewState = ExhaustedDataLoaderDispatchStrategy.CallStack.setCurrentlyDispatching(0, false) + contendedCallStack.casFailuresRemaining.set(1) + doneLatch.countDown() + } + return CompletableFuture.completedFuture(keys) + } + } + setupStrategyWithCallStack(chainedBatchLoader, contendedCallStack) + dataLoaderRegistry.getDataLoader("testLoader").load("key1") + + when: + strategy.executionStrategy(executionContext, rootParams, 1) + strategy.newDataLoaderInvocation(null) + strategy.finishedFetching(executionContext, rootParams) + + def completed = doneLatch.await(2, TimeUnit.SECONDS) + + then: + completed + roundCount.get() == 2 + } }