Skip to content

Commit 4d961d0

Browse files
authored
Don't send a RST frame when closing the stream in a write future whil… (#13973)
…e processing inbound frames. Motiviation: Due a bug in netty we would send a RST frame in some cases even tho we correctly received the endOfStream already. This is not necessary and might even confuse the remote peer. Modifications: - Keep track of if we received endOfStream and send endOfStream in our Channel implementation and only send a RST frame if this is not the case during close - Add unit tests Result: Don't send RST frame if we received endOfStream and send endOfStream.
1 parent e19c91f commit 4d961d0

3 files changed

Lines changed: 240 additions & 2 deletions

File tree

codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2StreamChannel.java

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,7 @@ void fireChildRead(Http2Frame frame) {
597597
// otherwise we would have drained it from the queue and processed it during the read cycle.
598598
assert inboundBuffer == null || inboundBuffer.isEmpty();
599599
final RecvByteBufAllocator.Handle allocHandle = unsafe.recvBufAllocHandle();
600+
600601
unsafe.doRead0(frame, allocHandle);
601602
// We currently don't need to check for readEOS because the parent channel and child channel are limited
602603
// to the same EventLoop thread. There are a limited number of frame types that may come after EOS is
@@ -635,6 +636,9 @@ private final class Http2ChannelUnsafe implements Unsafe {
635636
private boolean closeInitiated;
636637
private boolean readEOS;
637638

639+
private boolean receivedEndOfStream;
640+
private boolean sentEndOfStream;
641+
638642
@Override
639643
public void connect(final SocketAddress remoteAddress,
640644
SocketAddress localAddress, final ChannelPromise promise) {
@@ -731,7 +735,9 @@ public void operationComplete(ChannelFuture future) {
731735

732736
// Only ever send a reset frame if the connection is still alive and if the stream was created before
733737
// as otherwise we may send a RST on a stream in an invalid state and cause a connection error.
734-
if (parent().isActive() && !readEOS && isStreamIdValid(stream.id())) {
738+
if (parent().isActive() && isStreamIdValid(stream.id()) &&
739+
// Also ensure the stream was never "closed" before.
740+
!readEOS && !(receivedEndOfStream && sentEndOfStream)) {
735741
Http2StreamFrame resetFrame = new DefaultHttp2ResetFrame(error).stream(stream());
736742
write(resetFrame, unsafe().voidPromise());
737743
flush();
@@ -953,7 +959,6 @@ void doRead0(Http2Frame frame, RecvByteBufAllocator.Handle allocHandle) {
953959
final int bytes;
954960
if (frame instanceof Http2DataFrame) {
955961
bytes = ((Http2DataFrame) frame).initialFlowControlledBytes();
956-
957962
// It is important that we increment the flowControlledBytes before we call fireChannelRead(...)
958963
// as it may cause a read() that will call updateLocalWindowIfNeeded() and we need to ensure
959964
// in this case that we accounted for it.
@@ -963,6 +968,11 @@ void doRead0(Http2Frame frame, RecvByteBufAllocator.Handle allocHandle) {
963968
} else {
964969
bytes = MIN_HTTP2_FRAME_SIZE;
965970
}
971+
972+
// Let's keep track of what we received as the stream state itself will only be updated once the frame
973+
// was dispatched for reading which might cause problems if we try to close the channel in a write future.
974+
receivedEndOfStream |= isEndOfStream(frame);
975+
966976
// Update before firing event through the pipeline to be consistent with other Channel implementation.
967977
allocHandle.attemptedBytesRead(bytes);
968978
allocHandle.lastBytesRead(bytes);
@@ -1003,6 +1013,16 @@ public void write(Object msg, final ChannelPromise promise) {
10031013
}
10041014
}
10051015

1016+
private boolean isEndOfStream(Http2Frame frame) {
1017+
if (frame instanceof Http2HeadersFrame) {
1018+
return ((Http2HeadersFrame) frame).isEndStream();
1019+
}
1020+
if (frame instanceof Http2DataFrame) {
1021+
return ((Http2DataFrame) frame).isEndStream();
1022+
}
1023+
return false;
1024+
}
1025+
10061026
private void writeHttp2StreamFrame(Http2StreamFrame frame, final ChannelPromise promise) {
10071027
if (!firstFrameWritten && !isStreamIdValid(stream().id()) && !(frame instanceof Http2HeadersFrame)) {
10081028
ReferenceCountUtil.release(frame);
@@ -1019,6 +1039,9 @@ private void writeHttp2StreamFrame(Http2StreamFrame frame, final ChannelPromise
10191039
firstWrite = firstFrameWritten = true;
10201040
}
10211041

1042+
// Let's keep track of what we send as the stream state itself will only be updated once the frame
1043+
// was written which might cause problems if we try to close the channel in a write future.
1044+
sentEndOfStream |= isEndOfStream(frame);
10221045
ChannelFuture f = write0(parentContext(), frame);
10231046
if (f.isDone()) {
10241047
if (firstWrite) {

codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexTest.java

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import io.netty.channel.embedded.EmbeddedChannel;
2929
import io.netty.channel.socket.ChannelInputShutdownReadComplete;
3030
import io.netty.channel.socket.ChannelOutputShutdownEvent;
31+
import io.netty.handler.codec.UnsupportedMessageTypeException;
3132
import io.netty.handler.codec.http.HttpHeaderNames;
3233
import io.netty.handler.codec.http.HttpMethod;
3334
import io.netty.handler.codec.http.HttpScheme;
@@ -36,12 +37,15 @@
3637
import io.netty.handler.ssl.SslCloseCompletionEvent;
3738
import io.netty.util.AsciiString;
3839
import io.netty.util.AttributeKey;
40+
import io.netty.util.ReferenceCountUtil;
3941
import org.junit.jupiter.api.AfterEach;
4042
import org.junit.jupiter.api.BeforeEach;
4143
import org.junit.jupiter.api.Test;
4244
import org.junit.jupiter.api.function.Executable;
4345
import org.junit.jupiter.params.ParameterizedTest;
46+
import org.junit.jupiter.params.provider.EnumSource;
4447
import org.junit.jupiter.params.provider.MethodSource;
48+
import org.junit.jupiter.params.provider.ValueSource;
4549
import org.mockito.ArgumentMatcher;
4650
import org.mockito.Mockito;
4751
import org.mockito.invocation.InvocationOnMock;
@@ -71,6 +75,7 @@
7175
import static org.junit.jupiter.api.Assertions.assertNull;
7276
import static org.junit.jupiter.api.Assertions.assertThrows;
7377
import static org.junit.jupiter.api.Assertions.assertTrue;
78+
import static org.junit.jupiter.api.Assertions.fail;
7479
import static org.mockito.ArgumentMatchers.any;
7580
import static org.mockito.ArgumentMatchers.anyBoolean;
7681
import static org.mockito.ArgumentMatchers.anyInt;
@@ -229,6 +234,86 @@ public void headerAndDataFramesShouldBeDelivered() {
229234
assertNull(inboundHandler.readInbound());
230235
}
231236

237+
enum RstFrameTestMode {
238+
HEADERS_END_STREAM,
239+
DATA_END_STREAM,
240+
TRAILERS_END_STREAM;
241+
}
242+
@ParameterizedTest
243+
@EnumSource(RstFrameTestMode.class)
244+
void noRstFrameSentOnCloseViaListener(final RstFrameTestMode mode) throws Exception {
245+
LastInboundHandler inboundHandler = new LastInboundHandler() {
246+
private boolean headersReceived;
247+
@Override
248+
public void channelRead(ChannelHandlerContext ctx, Object msg) {
249+
try {
250+
final boolean endStream;
251+
if (msg instanceof Http2HeadersFrame) {
252+
endStream = ((Http2HeadersFrame) msg).isEndStream();
253+
switch (mode) {
254+
case HEADERS_END_STREAM:
255+
assertFalse(headersReceived);
256+
assertTrue(endStream);
257+
break;
258+
case TRAILERS_END_STREAM:
259+
if (headersReceived) {
260+
assertTrue(endStream);
261+
} else {
262+
assertFalse(endStream);
263+
}
264+
break;
265+
case DATA_END_STREAM:
266+
assertFalse(endStream);
267+
break;
268+
default:
269+
fail();
270+
}
271+
headersReceived = true;
272+
} else if (msg instanceof Http2DataFrame) {
273+
endStream = ((Http2DataFrame) msg).isEndStream();
274+
switch (mode) {
275+
case HEADERS_END_STREAM:
276+
fail();
277+
break;
278+
case TRAILERS_END_STREAM:
279+
assertFalse(endStream);
280+
break;
281+
case DATA_END_STREAM:
282+
assertTrue(endStream);
283+
break;
284+
default:
285+
fail();
286+
}
287+
} else {
288+
throw new UnsupportedMessageTypeException(msg);
289+
}
290+
if (endStream) {
291+
ctx.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers(), true, 0))
292+
.addListener(ChannelFutureListener.CLOSE);
293+
}
294+
} finally {
295+
ReferenceCountUtil.release(msg);
296+
}
297+
}
298+
};
299+
300+
Http2StreamChannel channel = newInboundStream(3, mode == RstFrameTestMode.HEADERS_END_STREAM, inboundHandler);
301+
if (mode != RstFrameTestMode.HEADERS_END_STREAM) {
302+
frameInboundWriter.writeInboundData(
303+
channel.stream().id(), bb("something"), 0, mode == RstFrameTestMode.DATA_END_STREAM);
304+
if (mode != RstFrameTestMode.DATA_END_STREAM) {
305+
frameInboundWriter.writeInboundHeaders(channel.stream().id(), new DefaultHttp2Headers(), 0, true);
306+
}
307+
}
308+
channel.closeFuture().syncUninterruptibly();
309+
310+
// We should never produce a RST frame in this case as we received the endOfStream before we write a frame
311+
// with the endOfStream flag.
312+
verify(frameWriter, never()).writeRstStream(eqCodecCtx(),
313+
eqStreamId(channel), anyLong(), anyChannelPromise());
314+
inboundHandler.checkException();
315+
}
316+
232317
@Test
233318
public void headerMultipleContentLengthValidationShouldPropagate() {
234319
headerMultipleContentLengthValidationShouldPropagate(false);

codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexTransportTest.java

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,24 @@
2727
import io.netty.channel.ChannelInboundHandlerAdapter;
2828
import io.netty.channel.ChannelInitializer;
2929
import io.netty.channel.ChannelOption;
30+
import io.netty.channel.ChannelPipeline;
31+
import io.netty.channel.DefaultEventLoop;
3032
import io.netty.channel.EventLoopGroup;
33+
import io.netty.channel.SimpleChannelInboundHandler;
34+
import io.netty.channel.local.LocalAddress;
35+
import io.netty.channel.local.LocalChannel;
36+
import io.netty.channel.local.LocalServerChannel;
3137
import io.netty.channel.nio.NioEventLoopGroup;
3238
import io.netty.channel.socket.nio.NioServerSocketChannel;
3339
import io.netty.channel.socket.nio.NioSocketChannel;
40+
import io.netty.handler.codec.http.DefaultFullHttpRequest;
41+
import io.netty.handler.codec.http.DefaultFullHttpResponse;
42+
import io.netty.handler.codec.http.FullHttpRequest;
43+
import io.netty.handler.codec.http.FullHttpResponse;
44+
import io.netty.handler.codec.http.HttpMethod;
45+
import io.netty.handler.codec.http.HttpObjectAggregator;
46+
import io.netty.handler.codec.http.HttpResponseStatus;
47+
import io.netty.handler.codec.http.HttpVersion;
3448
import io.netty.handler.ssl.ApplicationProtocolConfig;
3549
import io.netty.handler.ssl.ApplicationProtocolNames;
3650
import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler;
@@ -71,7 +85,10 @@
7185
import java.util.concurrent.atomic.AtomicInteger;
7286
import java.util.concurrent.atomic.AtomicReference;
7387

88+
import static io.netty.handler.codec.http2.Http2FrameCodecBuilder.forClient;
89+
import static io.netty.handler.codec.http2.Http2FrameCodecBuilder.forServer;
7490
import static java.util.concurrent.TimeUnit.MILLISECONDS;
91+
import static java.util.concurrent.TimeUnit.SECONDS;
7592
import static org.junit.jupiter.api.Assertions.assertEquals;
7693
import static org.junit.jupiter.api.Assertions.assertFalse;
7794
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -743,4 +760,117 @@ public boolean isSharable() {
743760
}
744761
}
745762
}
763+
764+
@Test
765+
public void testServerCloseShouldNotSendResetIfClientSentEOS() throws Exception {
766+
EventLoopGroup group = null;
767+
Channel serverChannel = null;
768+
Channel clientChannel = null;
769+
Channel clientStreamChannel = null;
770+
try {
771+
final CountDownLatch clientReceivedResponseLatch = new CountDownLatch(1);
772+
final CountDownLatch resetFrameLatch = new CountDownLatch(1);
773+
group = new DefaultEventLoop();
774+
LocalAddress serverAddress = new LocalAddress(getClass().getName());
775+
ServerBootstrap sb = new ServerBootstrap()
776+
.channel(LocalServerChannel.class)
777+
.group(group)
778+
.childHandler(new ChannelInitializer<Channel>() {
779+
@Override
780+
protected void initChannel(Channel ch) {
781+
ChannelPipeline pipeline = ch.pipeline();
782+
pipeline.addLast(forServer().build());
783+
pipeline.addLast(new Http2FrameIgnore<Http2SettingsFrame>(Http2SettingsFrame.class));
784+
pipeline.addLast(new Http2FrameIgnore<Http2SettingsAckFrame>(Http2SettingsAckFrame.class));
785+
pipeline.addLast(new Http2MultiplexHandler(new ChannelInitializer<Http2StreamChannel>() {
786+
@Override
787+
protected void initChannel(Http2StreamChannel ch) {
788+
ChannelPipeline pipeline = ch.pipeline();
789+
pipeline.addLast(new Http2StreamFrameToHttpObjectCodec(true, true));
790+
pipeline.addLast(new HttpObjectAggregator(16384));
791+
pipeline.addLast(new SimpleChannelInboundHandler<FullHttpRequest>() {
792+
@Override
793+
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) {
794+
ctx.writeAndFlush(
795+
new DefaultFullHttpResponse(
796+
msg.protocolVersion(), HttpResponseStatus.OK,
797+
Unpooled.copiedBuffer("hello", CharsetUtil.US_ASCII)))
798+
.addListeners(ChannelFutureListener.CLOSE);
799+
}
800+
});
801+
}
802+
}));
803+
}
804+
});
805+
serverChannel = sb.bind(serverAddress).sync().channel();
806+
807+
Bootstrap cb = new Bootstrap()
808+
.channel(LocalChannel.class)
809+
.group(group)
810+
.handler(new ChannelInitializer<Channel>() {
811+
@Override
812+
protected void initChannel(Channel ch) {
813+
ChannelPipeline pipeline = ch.pipeline();
814+
pipeline.addLast(forClient().build());
815+
pipeline.addLast(new Http2FrameIgnore<Http2SettingsFrame>(Http2SettingsFrame.class));
816+
pipeline.addLast(new Http2FrameIgnore<Http2SettingsAckFrame>(Http2SettingsAckFrame.class));
817+
pipeline.addLast(new Http2MultiplexHandler(new ChannelInitializer<Http2StreamChannel>() {
818+
@Override
819+
protected void initChannel(Http2StreamChannel ch) {
820+
// noop
821+
}
822+
}));
823+
}
824+
});
825+
826+
clientChannel = cb.connect(serverAddress).sync().channel();
827+
clientStreamChannel = new Http2StreamChannelBootstrap(clientChannel)
828+
.handler(new ChannelInitializer<Channel>() {
829+
@Override
830+
protected void initChannel(Channel ch) {
831+
ChannelPipeline pipeline = ch.pipeline();
832+
pipeline.addLast(new Http2StreamFrameToHttpObjectCodec(false, true));
833+
pipeline.addLast(new HttpObjectAggregator(16384));
834+
pipeline.addLast(new SimpleChannelInboundHandler<FullHttpResponse>() {
835+
@Override
836+
protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) {
837+
clientReceivedResponseLatch.countDown();
838+
}
839+
});
840+
}
841+
})
842+
.open().sync().get();
843+
844+
clientStreamChannel.writeAndFlush(
845+
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/test/")).sync();
846+
847+
assertTrue(clientReceivedResponseLatch.await(3, SECONDS));
848+
849+
// The server should NOT send any RST_STREAM frame.
850+
assertFalse(resetFrameLatch.await(1, SECONDS));
851+
} finally {
852+
if (clientStreamChannel != null) {
853+
clientStreamChannel.close().syncUninterruptibly();
854+
}
855+
if (clientChannel != null) {
856+
clientChannel.close().syncUninterruptibly();
857+
}
858+
if (serverChannel != null) {
859+
serverChannel.close().syncUninterruptibly();
860+
}
861+
if (group != null) {
862+
group.shutdownGracefully(0, 3, SECONDS);
863+
}
864+
}
865+
}
866+
867+
private static final class Http2FrameIgnore<T extends Http2Frame> extends SimpleChannelInboundHandler<T> {
868+
Http2FrameIgnore(Class<? extends T> inboundMessageType) {
869+
super(inboundMessageType);
870+
}
871+
872+
@Override
873+
protected void channelRead0(ChannelHandlerContext ctx, T msg) {
874+
}
875+
}
746876
}

0 commit comments

Comments
 (0)