diff --git a/src/main/java/graphql/execution/reactive/CompletionStageMappingPublisher.java b/src/main/java/graphql/execution/reactive/CompletionStageMappingPublisher.java index b686889795..4bfdcc1322 100644 --- a/src/main/java/graphql/execution/reactive/CompletionStageMappingPublisher.java +++ b/src/main/java/graphql/execution/reactive/CompletionStageMappingPublisher.java @@ -5,6 +5,7 @@ import org.reactivestreams.Subscription; import java.util.concurrent.CompletionStage; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; /** @@ -33,9 +34,12 @@ public CompletionStageMappingPublisher(Publisher upstreamPublisher, Function< public void subscribe(Subscriber downstreamSubscriber) { upstreamPublisher.subscribe(new Subscriber() { Subscription delegatingSubscription; + private AtomicInteger inFlight; + private volatile Runnable finish; @Override public void onSubscribe(Subscription subscription) { + inFlight = new AtomicInteger(); delegatingSubscription = new DelegatingSubscription(subscription); downstreamSubscriber.onSubscribe(delegatingSubscription); } @@ -45,11 +49,20 @@ public void onNext(U u) { CompletionStage completionStage; try { completionStage = mapper.apply(u); + inFlight.getAndIncrement(); completionStage.whenComplete((d, throwable) -> { - if (throwable != null) { - handleThrowable(throwable); - } else { - downstreamSubscriber.onNext(d); + try { + if (throwable != null) { + handleThrowable(throwable); + } else { + downstreamSubscriber.onNext(d); + } + }finally { + if(inFlight.intValue() == 1 && finish != null) { + finish.run(); + finish = null; + } + inFlight.decrementAndGet(); } }); } catch (RuntimeException throwable) { @@ -71,12 +84,28 @@ private void handleThrowable(Throwable throwable) { @Override public void onError(Throwable t) { - downstreamSubscriber.onError(t); + if(inFlight.intValue() > 0) { + finish = () -> downstreamSubscriber.onError(t); + if(inFlight.intValue() == 0 && finish != null) { + //happened together + downstreamSubscriber.onError(t); + } + }else { + downstreamSubscriber.onError(t); + } } @Override public void onComplete() { - downstreamSubscriber.onComplete(); + if(inFlight.intValue() > 0) { + finish = () -> downstreamSubscriber.onComplete(); + if(inFlight.intValue() == 0 && finish != null) { + //happened together + downstreamSubscriber.onComplete(); + } + }else { + downstreamSubscriber.onComplete(); + } } }); } diff --git a/src/test/groovy/graphql/execution/pubsub/CapturingSubscriber.java b/src/test/groovy/graphql/execution/pubsub/CapturingSubscriber.java index 97e7d48315..5336655ce6 100644 --- a/src/test/groovy/graphql/execution/pubsub/CapturingSubscriber.java +++ b/src/test/groovy/graphql/execution/pubsub/CapturingSubscriber.java @@ -5,6 +5,8 @@ import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -15,6 +17,7 @@ public class CapturingSubscriber implements Subscriber { private final AtomicBoolean done = new AtomicBoolean(); private Subscription subscription; private Throwable throwable; + private CompletableFuture doneFuture = new CompletableFuture<>(); @Override @@ -33,11 +36,13 @@ public void onNext(T t) { public void onError(Throwable t) { this.throwable = t; done.set(true); + doneFuture.complete(null); } @Override public void onComplete() { done.set(true); + doneFuture.complete(null); } public List getEvents() { @@ -48,6 +53,14 @@ public Throwable getThrowable() { return throwable; } + public void done() { + try { + doneFuture.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + public AtomicBoolean isDone() { return done; } diff --git a/src/test/groovy/graphql/execution/reactive/CompletionStageMappingPublisherTest.groovy b/src/test/groovy/graphql/execution/reactive/CompletionStageMappingPublisherTest.groovy index cd2771b605..fed59744cd 100644 --- a/src/test/groovy/graphql/execution/reactive/CompletionStageMappingPublisherTest.groovy +++ b/src/test/groovy/graphql/execution/reactive/CompletionStageMappingPublisherTest.groovy @@ -7,7 +7,10 @@ import spock.lang.Specification import java.util.concurrent.CompletableFuture import java.util.concurrent.CompletionStage +import java.util.concurrent.Executor +import java.util.concurrent.TimeUnit import java.util.function.Function +import java.util.function.Supplier class CompletionStageMappingPublisherTest extends Specification { @@ -15,18 +18,24 @@ class CompletionStageMappingPublisherTest extends Specification { when: Publisher rxIntegers = Flowable.range(0, 10) + Executor executor = CompletableFuture.delayedExecutor(50, TimeUnit.MILLISECONDS) def mapper = new Function>() { @Override CompletionStage apply(Integer integer) { - return CompletableFuture.completedFuture(String.valueOf(integer)) + return CompletableFuture.supplyAsync(new Supplier() { + @Override + String get() { + return String.valueOf(integer) + } + }, executor) } } Publisher rxStrings = new CompletionStageMappingPublisher(rxIntegers, mapper) def capturingSubscriber = new CapturingSubscriber<>() rxStrings.subscribe(capturingSubscriber) - + capturingSubscriber.done() then: capturingSubscriber.events.size() == 10 @@ -38,11 +47,17 @@ class CompletionStageMappingPublisherTest extends Specification { when: Publisher rxIntegers = Flowable.range(0, 10) + Executor executor = CompletableFuture.delayedExecutor(50, TimeUnit.MILLISECONDS) def mapper = new Function>() { @Override CompletionStage apply(Integer integer) { - return CompletableFuture.completedFuture(String.valueOf(integer)) + return CompletableFuture.supplyAsync(new Supplier() { + @Override + String get() { + return String.valueOf(integer) + } + }, executor) } } Publisher rxStrings = new CompletionStageMappingPublisher(rxIntegers, mapper) @@ -51,7 +66,8 @@ class CompletionStageMappingPublisherTest extends Specification { def capturingSubscriber2 = new CapturingSubscriber<>() rxStrings.subscribe(capturingSubscriber1) rxStrings.subscribe(capturingSubscriber2) - + capturingSubscriber1.done() + capturingSubscriber2.done() then: capturingSubscriber1.events.size() == 10