Skip to content

Commit f761ab5

Browse files
authored
Correctly keep track of validExtensions per request / response (#13180)
Motivation: At the moment we not correctly reset state between different request / response pairs. This can lead to situations when invalid extensions are used. Modifications: - Use a Queue to keep track of extensions per request / response Result: Fixes #13176
1 parent fd9694e commit f761ab5

2 files changed

Lines changed: 76 additions & 10 deletions

File tree

codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandler.java

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,13 @@
2828
import io.netty.handler.codec.http.HttpResponse;
2929
import io.netty.handler.codec.http.HttpResponseStatus;
3030

31+
import java.util.ArrayDeque;
3132
import java.util.ArrayList;
3233
import java.util.Arrays;
34+
import java.util.Collections;
3335
import java.util.Iterator;
3436
import java.util.List;
37+
import java.util.Queue;
3538

3639
/**
3740
* This handler negotiates and initializes the WebSocket Extensions.
@@ -47,7 +50,8 @@ public class WebSocketServerExtensionHandler extends ChannelDuplexHandler {
4750

4851
private final List<WebSocketServerExtensionHandshaker> extensionHandshakers;
4952

50-
private List<WebSocketServerExtension> validExtensions;
53+
private final Queue<List<WebSocketServerExtension>> validExtensions =
54+
new ArrayDeque<List<WebSocketServerExtension>>(4);
5155

5256
/**
5357
* Constructor
@@ -63,6 +67,7 @@ public WebSocketServerExtensionHandler(WebSocketServerExtensionHandshaker... ext
6367
@Override
6468
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
6569
if (msg instanceof HttpRequest) {
70+
List<WebSocketServerExtension> validExtensionsList = null;
6671
HttpRequest request = (HttpRequest) msg;
6772

6873
if (WebSocketExtensionUtil.isWebsocketUpgrade(request.headers())) {
@@ -85,15 +90,20 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
8590
}
8691

8792
if (validExtension != null && ((validExtension.rsv() & rsv) == 0)) {
88-
if (validExtensions == null) {
89-
validExtensions = new ArrayList<WebSocketServerExtension>(1);
93+
if (validExtensionsList == null) {
94+
validExtensionsList = new ArrayList<WebSocketServerExtension>(1);
9095
}
9196
rsv = rsv | validExtension.rsv();
92-
validExtensions.add(validExtension);
97+
validExtensionsList.add(validExtension);
9398
}
9499
}
95100
}
96101
}
102+
103+
if (validExtensionsList == null) {
104+
validExtensionsList = Collections.emptyList();
105+
}
106+
validExtensions.offer(validExtensionsList);
97107
}
98108

99109
super.channelRead(ctx, msg);
@@ -102,28 +112,29 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
102112
@Override
103113
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
104114
if (msg instanceof HttpResponse) {
115+
List<WebSocketServerExtension> validExtensionsList = validExtensions.poll();
105116
HttpResponse httpResponse = (HttpResponse) msg;
106117
//checking the status is faster than looking at headers
107118
//so we do this first
108119
if (HttpResponseStatus.SWITCHING_PROTOCOLS.equals(httpResponse.status())) {
109-
handlePotentialUpgrade(ctx, promise, httpResponse);
120+
handlePotentialUpgrade(ctx, promise, httpResponse, validExtensionsList);
110121
}
111122
}
112123

113124
super.write(ctx, msg, promise);
114125
}
115126

116127
private void handlePotentialUpgrade(final ChannelHandlerContext ctx,
117-
ChannelPromise promise, HttpResponse httpResponse) {
128+
ChannelPromise promise, HttpResponse httpResponse,
129+
final List<WebSocketServerExtension> validExtensionsList) {
118130
HttpHeaders headers = httpResponse.headers();
119131

120132
if (WebSocketExtensionUtil.isWebsocketUpgrade(headers)) {
121-
122-
if (validExtensions != null) {
133+
if (validExtensionsList != null && !validExtensionsList.isEmpty()) {
123134
String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
124135
List<WebSocketExtensionData> extraExtensions =
125136
new ArrayList<WebSocketExtensionData>(extensionHandshakers.size());
126-
for (WebSocketServerExtension extension : validExtensions) {
137+
for (WebSocketServerExtension extension : validExtensionsList) {
127138
extraExtensions.add(extension.newReponseData());
128139
}
129140
String newHeaderValue = WebSocketExtensionUtil
@@ -132,7 +143,7 @@ private void handlePotentialUpgrade(final ChannelHandlerContext ctx,
132143
@Override
133144
public void operationComplete(ChannelFuture future) {
134145
if (future.isSuccess()) {
135-
for (WebSocketServerExtension extension : validExtensions) {
146+
for (WebSocketServerExtension extension : validExtensionsList) {
136147
WebSocketExtensionDecoder decoder = extension.newExtensionDecoder();
137148
WebSocketExtensionEncoder encoder = extension.newExtensionEncoder();
138149
String name = ctx.name();

codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandlerTest.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Collections;
2626
import java.util.List;
2727

28+
import io.netty.handler.codec.http.LastHttpContent;
2829
import org.junit.jupiter.api.Test;
2930

3031
import static io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionTestUtil.*;
@@ -41,11 +42,18 @@ public class WebSocketServerExtensionHandlerTest {
4142
mock(WebSocketServerExtensionHandshaker.class, "mainHandshaker");
4243
WebSocketServerExtensionHandshaker fallbackHandshakerMock =
4344
mock(WebSocketServerExtensionHandshaker.class, "fallbackHandshaker");
45+
46+
WebSocketServerExtensionHandshaker main2HandshakerMock =
47+
mock(WebSocketServerExtensionHandshaker.class, "main2Handshaker");
4448
WebSocketServerExtension mainExtensionMock =
4549
mock(WebSocketServerExtension.class, "mainExtension");
50+
4651
WebSocketServerExtension fallbackExtensionMock =
4752
mock(WebSocketServerExtension.class, "fallbackExtension");
4853

54+
WebSocketServerExtension main2ExtensionMock =
55+
mock(WebSocketServerExtension.class, "main2Extension");
56+
4957
@Test
5058
public void testMainSuccess() {
5159
// initialize
@@ -229,4 +237,51 @@ public void testExtensionHandlerNotRemovedByFailureWritePromise() {
229237
assertNotNull(ch.pipeline().context(extensionHandler));
230238
assertTrue(ch.finish());
231239
}
240+
241+
@Test
242+
public void testExtensionMultipleRequests() {
243+
// initialize
244+
when(mainHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("main")))
245+
.thenReturn(mainExtensionMock);
246+
247+
when(mainExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1);
248+
when(mainExtensionMock.newReponseData()).thenReturn(
249+
new WebSocketExtensionData("main", Collections.<String, String>emptyMap()));
250+
when(mainExtensionMock.newExtensionEncoder()).thenReturn(new DummyEncoder());
251+
when(mainExtensionMock.newExtensionDecoder()).thenReturn(new DummyDecoder());
252+
253+
when(main2HandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("main2")))
254+
.thenReturn(main2ExtensionMock);
255+
256+
when(main2ExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1);
257+
when(main2ExtensionMock.newReponseData()).thenReturn(
258+
new WebSocketExtensionData("main2", Collections.<String, String>emptyMap()));
259+
when(main2ExtensionMock.newExtensionEncoder()).thenReturn(new DummyEncoder());
260+
when(main2ExtensionMock.newExtensionDecoder()).thenReturn(new DummyDecoder());
261+
262+
// execute
263+
WebSocketServerExtensionHandler extensionHandler =
264+
new WebSocketServerExtensionHandler(mainHandshakerMock, main2HandshakerMock);
265+
EmbeddedChannel ch = new EmbeddedChannel(extensionHandler);
266+
267+
HttpRequest req = newUpgradeRequest("main");
268+
assertTrue(ch.writeInbound(req));
269+
assertTrue(ch.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT));
270+
271+
HttpRequest req2 = newUpgradeRequest("main2");
272+
assertTrue(ch.writeInbound(req2));
273+
assertTrue(ch.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT));
274+
275+
HttpResponse res = newUpgradeResponse(null);
276+
assertTrue(ch.writeOutbound(res));
277+
assertTrue(ch.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT));
278+
279+
res = ch.readOutbound();
280+
assertEquals("main", res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
281+
LastHttpContent content = ch.readOutbound();
282+
content.release();
283+
284+
assertNull(ch.pipeline().context(extensionHandler));
285+
assertTrue(ch.finishAndReleaseAll());
286+
}
232287
}

0 commit comments

Comments
 (0)