Skip to content

Commit db1c860

Browse files
committed
Add PerMessageDeflate Extension support, see TooTallNate#574
1 parent a2c63db commit db1c860

File tree

3 files changed

+375
-0
lines changed

3 files changed

+375
-0
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import org.java_websocket.WebSocket;
2+
import org.java_websocket.client.WebSocketClient;
3+
import org.java_websocket.drafts.Draft;
4+
import org.java_websocket.drafts.Draft_6455;
5+
import org.java_websocket.extensions.permessage_deflate.PerMessageDeflateExtension;
6+
import org.java_websocket.handshake.ClientHandshake;
7+
import org.java_websocket.handshake.ServerHandshake;
8+
import org.java_websocket.server.WebSocketServer;
9+
10+
import java.net.InetSocketAddress;
11+
import java.net.URI;
12+
import java.net.URISyntaxException;
13+
import java.util.Collections;
14+
15+
/**
16+
* This class only serves the purpose of showing how to enable PerMessageDeflateExtension for both server and client sockets.<br>
17+
* Extensions are required to be registered in
18+
* @see Draft objects and both
19+
* @see WebSocketClient and
20+
* @see WebSocketServer accept a
21+
* @see Draft object in their constructors.
22+
* This example shows how to achieve it for both server and client sockets.
23+
* Once the connection has been established, PerMessageDeflateExtension will be enabled
24+
* and any messages (binary or text) will be compressed/decompressed automatically.<br>
25+
* Since no additional code is required when sending or receiving messages, this example skips those parts.
26+
*/
27+
public class PerMessageDeflateExample {
28+
29+
private static final Draft perMessageDeflateDraft = new Draft_6455(new PerMessageDeflateExtension());
30+
private static final int PORT = 8887;
31+
32+
private static class DeflateClient extends WebSocketClient {
33+
34+
public DeflateClient() throws URISyntaxException {
35+
super(new URI("ws://localhost:" + PORT), perMessageDeflateDraft);
36+
}
37+
38+
@Override
39+
public void onOpen(ServerHandshake handshakedata) { }
40+
41+
@Override
42+
public void onMessage(String message) { }
43+
44+
@Override
45+
public void onClose(int code, String reason, boolean remote) { }
46+
47+
@Override
48+
public void onError(Exception ex) { }
49+
}
50+
51+
private static class DeflateServer extends WebSocketServer {
52+
53+
public DeflateServer() {
54+
super(new InetSocketAddress(PORT), Collections.singletonList(perMessageDeflateDraft));
55+
}
56+
57+
@Override
58+
public void onOpen(WebSocket conn, ClientHandshake handshake) { }
59+
60+
@Override
61+
public void onClose(WebSocket conn, int code, String reason, boolean remote) { }
62+
63+
@Override
64+
public void onMessage(WebSocket conn, String message) { }
65+
66+
@Override
67+
public void onError(WebSocket conn, Exception ex) { }
68+
69+
@Override
70+
public void onStart() { }
71+
}
72+
}
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
package org.java_websocket.extensions.permessage_deflate;
2+
3+
import org.java_websocket.enums.Opcode;
4+
import org.java_websocket.exceptions.InvalidDataException;
5+
import org.java_websocket.exceptions.InvalidFrameException;
6+
import org.java_websocket.extensions.CompressionExtension;
7+
import org.java_websocket.extensions.IExtension;
8+
import org.java_websocket.framing.*;
9+
10+
import java.io.ByteArrayOutputStream;
11+
import java.nio.ByteBuffer;
12+
import java.util.zip.DataFormatException;
13+
import java.util.zip.Deflater;
14+
import java.util.zip.Inflater;
15+
16+
public class PerMessageDeflateExtension extends CompressionExtension {
17+
18+
// Name of the extension as registered by IETF https://tools.ietf.org/html/rfc7692#section-9.
19+
private static final String EXTENSION_REGISTERED_NAME = "permessage-deflate";
20+
21+
// Below values are defined for convenience. They are not used in the compression/decompression phase.
22+
// They may be needed during the extension-negotiation offer in the future.
23+
private static final String SERVER_NO_CONTEXT_TAKEOVER = "server_no_context_takeover";
24+
private static final String CLIENT_NO_CONTEXT_TAKEOVER = "client_no_context_takeover";
25+
private static final String SERVER_MAX_WINDOW_BITS = "server_max_window_bits";
26+
private static final String CLIENT_MAX_WINDOW_BITS = "client_max_window_bits";
27+
private static final boolean serverNoContextTakeover = true;
28+
private static final boolean clientNoContextTakeover = true;
29+
private static final int serverMaxWindowBits = 1 << 15;
30+
private static final int clientMaxWindowBits = 1 << 15;
31+
32+
private static final byte[] TAIL_BYTES = {0x00, 0x00, (byte)0xFF, (byte)0xFF};
33+
private static final int BUFFER_SIZE = 1 << 10;
34+
35+
/*
36+
An endpoint uses the following algorithm to decompress a message.
37+
1. Append 4 octets of 0x00 0x00 0xff 0xff to the tail end of the
38+
payload of the message.
39+
2. Decompress the resulting data using DEFLATE.
40+
See, https://tools.ietf.org/html/rfc7692#section-7.2.2
41+
*/
42+
@Override
43+
public void decodeFrame(Framedata inputFrame) throws InvalidDataException {
44+
// Only DataFrames can be decompressed.
45+
if(!(inputFrame instanceof DataFrame))
46+
return;
47+
48+
// RSV1 bit must be set only for the first frame.
49+
if(inputFrame.getOpcode() == Opcode.CONTINUOUS && inputFrame.isRSV1())
50+
throw new InvalidDataException(CloseFrame.POLICY_VALIDATION, "RSV1 bit can only be set for the first frame.");
51+
52+
// Decompressed output buffer.
53+
ByteArrayOutputStream output = new ByteArrayOutputStream();
54+
Inflater inflater = new Inflater(true);
55+
try {
56+
decompress(inflater, inputFrame.getPayloadData().array(), output);
57+
// Decompress 4 bytes of 0x00 0x00 0xff 0xff as if they were appended to the end of message.
58+
if(inputFrame.isFin())
59+
decompress(inflater, TAIL_BYTES, output);
60+
} catch (DataFormatException e) {
61+
throw new InvalidDataException(CloseFrame.POLICY_VALIDATION, "Given data format couldn't be decompressed.");
62+
}finally {
63+
inflater.end();
64+
}
65+
66+
// Set frames payload to the new decompressed data.
67+
((FramedataImpl1) inputFrame).setPayload(ByteBuffer.wrap(output.toByteArray()));
68+
}
69+
70+
private void decompress(Inflater inflater, byte[] data, ByteArrayOutputStream outputBuffer) throws DataFormatException{
71+
inflater.setInput(data);
72+
byte[] buffer = new byte[BUFFER_SIZE];
73+
74+
int bytesInflated;
75+
while((bytesInflated = inflater.inflate(buffer)) > 0){
76+
outputBuffer.write(buffer, 0, bytesInflated);
77+
}
78+
}
79+
80+
@Override
81+
public void encodeFrame(Framedata inputFrame) {
82+
// Only DataFrames can be decompressed.
83+
if(!(inputFrame instanceof DataFrame))
84+
return;
85+
86+
// Only the first frame's RSV1 must be set.
87+
if(!(inputFrame instanceof ContinuousFrame))
88+
((DataFrame) inputFrame).setRSV1(true);
89+
90+
Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
91+
deflater.setInput(inputFrame.getPayloadData().array());
92+
deflater.finish();
93+
94+
// Compressed output buffer.
95+
ByteArrayOutputStream output = new ByteArrayOutputStream();
96+
// Temporary buffer to hold compressed output.
97+
byte[] buffer = new byte[1024];
98+
int bytesCompressed;
99+
while((bytesCompressed = deflater.deflate(buffer)) > 0) {
100+
output.write(buffer, 0, bytesCompressed);
101+
}
102+
deflater.end();
103+
104+
byte outputBytes[] = output.toByteArray();
105+
int outputLength = outputBytes.length;
106+
/*
107+
https://tools.ietf.org/html/rfc7692#section-7.2.1 states that if the final fragment's compressed
108+
payload ends with 0x00 0x00 0xff 0xff, they should be removed.
109+
To simulate removal, we just pass 4 bytes less to the new payload
110+
if the frame is final and outputBytes ends with 0x00 0x00 0xff 0xff.
111+
*/
112+
if(inputFrame.isFin() && endsWithTail(outputBytes))
113+
outputLength -= TAIL_BYTES.length;
114+
115+
// Set frames payload to the new compressed data.
116+
((FramedataImpl1) inputFrame).setPayload(ByteBuffer.wrap(outputBytes, 0, outputLength));
117+
}
118+
119+
private boolean endsWithTail(byte[] data){
120+
if(data.length < 4)
121+
return false;
122+
123+
int length = data.length;
124+
for(int i = 0; i <= TAIL_BYTES.length; i--){
125+
if(TAIL_BYTES[i] != data[length - TAIL_BYTES.length + i])
126+
return false;
127+
}
128+
129+
return true;
130+
}
131+
132+
@Override
133+
public boolean acceptProvidedExtensionAsServer(String inputExtension) {
134+
String[] requestedExtensions = inputExtension.split(",");
135+
for(String extension : requestedExtensions)
136+
if(EXTENSION_REGISTERED_NAME.equalsIgnoreCase(extension.trim()))
137+
return true;
138+
139+
return false;
140+
}
141+
142+
@Override
143+
public boolean acceptProvidedExtensionAsClient(String inputExtension) {
144+
String[] requestedExtensions = inputExtension.split(",");
145+
for(String extension : requestedExtensions)
146+
if(EXTENSION_REGISTERED_NAME.equalsIgnoreCase(extension.trim()))
147+
return true;
148+
149+
return false;
150+
}
151+
152+
@Override
153+
public String getProvidedExtensionAsClient() {
154+
return EXTENSION_REGISTERED_NAME;
155+
}
156+
157+
@Override
158+
public String getProvidedExtensionAsServer() {
159+
return EXTENSION_REGISTERED_NAME;
160+
}
161+
162+
@Override
163+
public IExtension copyInstance() {
164+
return new PerMessageDeflateExtension();
165+
}
166+
167+
/**
168+
* This extension requires the RSV1 bit to be set only for the first frame.
169+
* If the frame is type is CONTINUOUS, RSV1 bit must be unset.
170+
*/
171+
@Override
172+
public void isFrameValid(Framedata inputFrame) throws InvalidDataException {
173+
if((inputFrame instanceof TextFrame || inputFrame instanceof BinaryFrame) && !inputFrame.isRSV1())
174+
throw new InvalidFrameException("RSV1 bit must be set for DataFrames.");
175+
if((inputFrame instanceof ContinuousFrame) && (inputFrame.isRSV1() || inputFrame.isRSV2() || inputFrame.isRSV3()))
176+
throw new InvalidFrameException( "bad rsv RSV1: " + inputFrame.isRSV1() + " RSV2: " + inputFrame.isRSV2() + " RSV3: " + inputFrame.isRSV3() );
177+
super.isFrameValid(inputFrame);
178+
}
179+
180+
@Override
181+
public String toString() {
182+
return "PerMessageDeflateExtension";
183+
}
184+
185+
}
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package org.java_websocket.extensions;
2+
3+
import org.java_websocket.exceptions.InvalidDataException;
4+
import org.java_websocket.extensions.permessage_deflate.PerMessageDeflateExtension;
5+
import org.java_websocket.framing.ContinuousFrame;
6+
import org.java_websocket.framing.TextFrame;
7+
import org.junit.Test;
8+
9+
import java.nio.ByteBuffer;
10+
11+
import static org.junit.Assert.*;
12+
13+
public class PerMessageDeflateExtensionTest {
14+
15+
@Test
16+
public void testDecodeFrame() throws InvalidDataException {
17+
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
18+
String str = "This is a highly compressable text"
19+
+ "This is a highly compressable text"
20+
+ "This is a highly compressable text"
21+
+ "This is a highly compressable text"
22+
+ "This is a highly compressable text";
23+
byte[] message = str.getBytes();
24+
TextFrame frame = new TextFrame();
25+
frame.setPayload(ByteBuffer.wrap(message));
26+
deflateExtension.encodeFrame(frame);
27+
deflateExtension.decodeFrame(frame);
28+
assertArrayEquals(message, frame.getPayloadData().array());
29+
}
30+
31+
@Test
32+
public void testEncodeFrame() {
33+
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
34+
String str = "This is a highly compressable text"
35+
+ "This is a highly compressable text"
36+
+ "This is a highly compressable text"
37+
+ "This is a highly compressable text"
38+
+ "This is a highly compressable text";
39+
byte[] message = str.getBytes();
40+
TextFrame frame = new TextFrame();
41+
frame.setPayload(ByteBuffer.wrap(message));
42+
deflateExtension.encodeFrame(frame);
43+
assertTrue(message.length > frame.getPayloadData().array().length);
44+
}
45+
46+
@Test
47+
public void testAcceptProvidedExtensionAsServer() {
48+
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
49+
assertTrue(deflateExtension.acceptProvidedExtensionAsServer("permessage-deflate"));
50+
assertTrue(deflateExtension.acceptProvidedExtensionAsServer("some-other-extension, permessage-deflate"));
51+
assertFalse(deflateExtension.acceptProvidedExtensionAsServer("wrong-permessage-deflate"));
52+
}
53+
54+
@Test
55+
public void testAcceptProvidedExtensionAsClient() {
56+
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
57+
assertTrue(deflateExtension.acceptProvidedExtensionAsClient("permessage-deflate"));
58+
assertTrue(deflateExtension.acceptProvidedExtensionAsClient("some-other-extension, permessage-deflate"));
59+
assertFalse(deflateExtension.acceptProvidedExtensionAsClient("wrong-permessage-deflate"));
60+
}
61+
62+
@Test
63+
public void testIsFrameValid() {
64+
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
65+
TextFrame frame = new TextFrame();
66+
try {
67+
deflateExtension.isFrameValid(frame);
68+
fail("Frame not valid. RSV1 must be set.");
69+
} catch (Exception e) {
70+
//
71+
}
72+
frame.setRSV1(true);
73+
try {
74+
deflateExtension.isFrameValid(frame);
75+
} catch (Exception e) {
76+
fail("Frame is valid.");
77+
}
78+
frame.setRSV2(true);
79+
try {
80+
deflateExtension.isFrameValid(frame);
81+
fail("Only RSV1 bit must be set.");
82+
} catch (Exception e) {
83+
//
84+
}
85+
ContinuousFrame contFrame = new ContinuousFrame();
86+
contFrame.setRSV1(true);
87+
try {
88+
deflateExtension.isFrameValid(contFrame);
89+
fail("RSV1 must only be set for first fragments.Continuous frames can't have RSV1 bit set.");
90+
} catch (Exception e) {
91+
//
92+
}
93+
contFrame.setRSV1(false);
94+
try {
95+
deflateExtension.isFrameValid(contFrame);
96+
} catch (Exception e) {
97+
fail("Continuous frame is valid.");
98+
}
99+
}
100+
101+
@Test
102+
public void testGetProvidedExtensionAsClient() {
103+
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
104+
assertEquals( "permessage-deflate", deflateExtension.getProvidedExtensionAsClient() );
105+
}
106+
107+
@Test
108+
public void testGetProvidedExtensionAsServer() {
109+
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
110+
assertEquals( "permessage-deflate", deflateExtension.getProvidedExtensionAsServer() );
111+
}
112+
113+
@Test
114+
public void testToString() throws Exception {
115+
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
116+
assertEquals( "PerMessageDeflateExtension", deflateExtension.toString() );
117+
}
118+
}

0 commit comments

Comments
 (0)