Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.reactivestreams.Subscription;

import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

/**
Expand Down Expand Up @@ -33,9 +34,12 @@ public CompletionStageMappingPublisher(Publisher<U> upstreamPublisher, Function<
public void subscribe(Subscriber<? super D> downstreamSubscriber) {
upstreamPublisher.subscribe(new Subscriber<U>() {
Subscription delegatingSubscription;
private AtomicInteger inFlight;
private volatile Runnable finish;

@Override
public void onSubscribe(Subscription subscription) {
inFlight = new AtomicInteger();
delegatingSubscription = new DelegatingSubscription(subscription);
downstreamSubscriber.onSubscribe(delegatingSubscription);
}
Expand All @@ -45,11 +49,20 @@ public void onNext(U u) {
CompletionStage<D> completionStage;
try {
completionStage = mapper.apply(u);
inFlight.getAndIncrement();
completionStage.whenComplete((d, throwable) -> {
if (throwable != null) {
handleThrowable(throwable);
} else {
downstreamSubscriber.onNext(d);
try {
if (throwable != null) {
handleThrowable(throwable);
} else {
downstreamSubscriber.onNext(d);
}
}finally {
if(inFlight.intValue() == 1 && finish != null) {
finish.run();
finish = null;
}
inFlight.decrementAndGet();
}
});
} catch (RuntimeException throwable) {
Expand All @@ -71,12 +84,28 @@ private void handleThrowable(Throwable throwable) {

@Override
public void onError(Throwable t) {
downstreamSubscriber.onError(t);
if(inFlight.intValue() > 0) {
finish = () -> downstreamSubscriber.onError(t);
if(inFlight.intValue() == 0 && finish != null) {
//happened together
downstreamSubscriber.onError(t);
}
}else {
downstreamSubscriber.onError(t);
}
}

@Override
public void onComplete() {
downstreamSubscriber.onComplete();
if(inFlight.intValue() > 0) {
finish = () -> downstreamSubscriber.onComplete();
if(inFlight.intValue() == 0 && finish != null) {
//happened together
downstreamSubscriber.onComplete();
}
}else {
downstreamSubscriber.onComplete();
}
}
});
}
Expand Down
13 changes: 13 additions & 0 deletions src/test/groovy/graphql/execution/pubsub/CapturingSubscriber.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;

/**
Expand All @@ -15,6 +17,7 @@ public class CapturingSubscriber<T> implements Subscriber<T> {
private final AtomicBoolean done = new AtomicBoolean();
private Subscription subscription;
private Throwable throwable;
private CompletableFuture<Void> doneFuture = new CompletableFuture<>();


@Override
Expand All @@ -33,11 +36,13 @@ public void onNext(T t) {
public void onError(Throwable t) {
this.throwable = t;
done.set(true);
doneFuture.complete(null);
}

@Override
public void onComplete() {
done.set(true);
doneFuture.complete(null);
}

public List<T> getEvents() {
Expand All @@ -48,6 +53,14 @@ public Throwable getThrowable() {
return throwable;
}

public void done() {
try {
doneFuture.get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}

public AtomicBoolean isDone() {
return done;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,35 @@ import spock.lang.Specification

import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionStage
import java.util.concurrent.Executor
import java.util.concurrent.TimeUnit
import java.util.function.Function
import java.util.function.Supplier

class CompletionStageMappingPublisherTest extends Specification {

def "basic mapping"() {

when:
Publisher<Integer> rxIntegers = Flowable.range(0, 10)
Executor executor = CompletableFuture.delayedExecutor(50, TimeUnit.MILLISECONDS)

def mapper = new Function<Integer, CompletionStage<String>>() {
@Override
CompletionStage<String> apply(Integer integer) {
return CompletableFuture.completedFuture(String.valueOf(integer))
return CompletableFuture.supplyAsync(new Supplier<String>() {
@Override
String get() {
return String.valueOf(integer)
}
}, executor)
}
}
Publisher<String> rxStrings = new CompletionStageMappingPublisher<String, Integer>(rxIntegers, mapper)

def capturingSubscriber = new CapturingSubscriber<>()
rxStrings.subscribe(capturingSubscriber)

capturingSubscriber.done()
then:

capturingSubscriber.events.size() == 10
Expand All @@ -38,11 +47,17 @@ class CompletionStageMappingPublisherTest extends Specification {

when:
Publisher<Integer> rxIntegers = Flowable.range(0, 10)
Executor executor = CompletableFuture.delayedExecutor(50, TimeUnit.MILLISECONDS)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests are failing because CompletableFuture.delayedExecutor is a Java 9 API


def mapper = new Function<Integer, CompletionStage<String>>() {
@Override
CompletionStage<String> apply(Integer integer) {
return CompletableFuture.completedFuture(String.valueOf(integer))
return CompletableFuture.supplyAsync(new Supplier<String>() {
@Override
String get() {
return String.valueOf(integer)
}
}, executor)
}
}
Publisher<String> rxStrings = new CompletionStageMappingPublisher<String, Integer>(rxIntegers, mapper)
Expand All @@ -51,7 +66,8 @@ class CompletionStageMappingPublisherTest extends Specification {
def capturingSubscriber2 = new CapturingSubscriber<>()
rxStrings.subscribe(capturingSubscriber1)
rxStrings.subscribe(capturingSubscriber2)

capturingSubscriber1.done()
capturingSubscriber2.done()
then:

capturingSubscriber1.events.size() == 10
Expand Down