Skip to content

Commit d7b5467

Browse files
committed
improved buffer allocation, renamed variables, documented code and fixed potential threading issue in SSLSocketChannel2#processHandshake (TooTallNate#154)
1 parent 2c10852 commit d7b5467

File tree

3 files changed

+78
-42
lines changed

3 files changed

+78
-42
lines changed

src/main/java/org/java_websocket/SSLSocketChannel2.java

Lines changed: 64 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -44,33 +44,34 @@ public class SSLSocketChannel2 implements ByteChannel, WrappedByteChannel {
4444
/** encrypted data incoming */
4545
protected ByteBuffer inCrypt;
4646

47-
protected SocketChannel sc;
48-
protected SelectionKey key;
47+
/** the underlying channel */
48+
protected SocketChannel socketChannel;
49+
/** used to set interestOP SelectionKey.OP_WRITE for the underlying channel */
50+
protected SelectionKey selectionKey;
4951

50-
protected SSLEngineResult res;
52+
53+
protected SSLEngineResult engineResult;
5154
protected SSLEngine sslEngine;
52-
protected final boolean isblocking;
5355

54-
private Status status = Status.BUFFER_UNDERFLOW;
56+
57+
private Status engineStatus = Status.BUFFER_UNDERFLOW;
5558

5659
public SSLSocketChannel2( SocketChannel channel , SSLEngine sslEngine , ExecutorService exec , SelectionKey key ) throws IOException {
57-
this.sc = channel;
60+
if( channel == null || sslEngine == null || exec == null )
61+
throw new IllegalArgumentException( "parameter must not be null" );
5862

63+
this.socketChannel = channel;
5964
this.sslEngine = sslEngine;
6065
this.exec = exec;
6166

6267
tasks = new ArrayList<Future<?>>( 3 );
6368
if( key != null ) {
6469
key.interestOps( key.interestOps() | SelectionKey.OP_WRITE );
65-
this.key = key;
70+
this.selectionKey = key;
6671
}
67-
isblocking = channel.isBlocking();
68-
69-
sslEngine.setEnableSessionCreation( true );
70-
SSLSession session = sslEngine.getSession();
71-
createBuffers( session );
72-
73-
sc.write( wrap( emptybuffer ) );// initializes res
72+
createBuffers( sslEngine.getSession() );
73+
// kick off handshake
74+
socketChannel.write( wrap( emptybuffer ) );// initializes res
7475
processHandshake();
7576
}
7677

@@ -93,6 +94,8 @@ private void consumeFutureUninterruptible( Future<?> f ) {
9394
}
9495

9596
private synchronized void processHandshake() throws IOException {
97+
if( engineResult.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING )
98+
return; // since this may be called either from a reading or a writing thread and because this method is synchronized it is necessary to double check if we are still handshaking.
9699
if( !tasks.isEmpty() ) {
97100
Iterator<Future<?>> it = tasks.iterator();
98101
while ( it.hasNext() ) {
@@ -107,28 +110,35 @@ private synchronized void processHandshake() throws IOException {
107110
}
108111
}
109112

110-
if( res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP ) {
111-
if( !isblocking || status == Status.BUFFER_UNDERFLOW ) {
113+
if( engineResult.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP ) {
114+
if( !isBlocking() || engineStatus == Status.BUFFER_UNDERFLOW ) {
112115
inCrypt.compact();
113-
int read = sc.read( inCrypt );
116+
int read = socketChannel.read( inCrypt );
114117
if( read == -1 ) {
115118
throw new IOException( "connection closed unexpectedly by peer" );
116119
}
117120
inCrypt.flip();
118121
}
119122
inData.compact();
120123
unwrap();
124+
if( engineResult.getHandshakeStatus() == HandshakeStatus.FINISHED ) {
125+
createBuffers( sslEngine.getSession() );
126+
return;
127+
}
121128
}
122129
consumeDelegatedTasks();
123-
if( tasks.isEmpty() || res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_WRAP ) {
124-
sc.write( wrap( emptybuffer ) );
130+
assert ( engineResult.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING );
131+
if( tasks.isEmpty() || engineResult.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_WRAP ) {
132+
socketChannel.write( wrap( emptybuffer ) );
133+
if( engineResult.getHandshakeStatus() == HandshakeStatus.FINISHED ) {
134+
createBuffers( sslEngine.getSession() );
135+
}
125136
}
126-
127137
}
128138

129139
private synchronized ByteBuffer wrap( ByteBuffer b ) throws SSLException {
130140
outCrypt.compact();
131-
res = sslEngine.wrap( b, outCrypt );
141+
engineResult = sslEngine.wrap( b, outCrypt );
132142
outCrypt.flip();
133143
return outCrypt;
134144
}
@@ -137,9 +147,9 @@ private synchronized ByteBuffer unwrap() throws SSLException {
137147
int rem;
138148
do {
139149
rem = inData.remaining();
140-
res = sslEngine.unwrap( inCrypt, inData );
141-
status = res.getStatus();
142-
} while ( status == SSLEngineResult.Status.OK && ( rem != inData.remaining() || res.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP ) );
150+
engineResult = sslEngine.unwrap( inCrypt, inData );
151+
engineStatus = engineResult.getStatus();
152+
} while ( engineStatus == SSLEngineResult.Status.OK && ( rem != inData.remaining() || engineResult.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP ) );
143153

144154
inData.flip();
145155
return inData;
@@ -157,11 +167,23 @@ protected void createBuffers( SSLSession session ) {
157167
int appBufferMax = session.getApplicationBufferSize();
158168
int netBufferMax = session.getPacketBufferSize();
159169

160-
inData = ByteBuffer.allocate( appBufferMax );
161-
outCrypt = ByteBuffer.allocate( netBufferMax );
162-
inCrypt = ByteBuffer.allocate( netBufferMax );
170+
if( inData == null ) {
171+
inData = ByteBuffer.allocate( appBufferMax );
172+
outCrypt = ByteBuffer.allocate( netBufferMax );
173+
inCrypt = ByteBuffer.allocate( netBufferMax );
174+
} else {
175+
if( inData.capacity() != appBufferMax )
176+
inData = ByteBuffer.allocate( appBufferMax );
177+
if( outCrypt.capacity() != netBufferMax )
178+
outCrypt = ByteBuffer.allocate( netBufferMax );
179+
if( inCrypt.capacity() != netBufferMax )
180+
inCrypt = ByteBuffer.allocate( netBufferMax );
181+
}
182+
inData.rewind();
163183
inData.flip();
184+
inCrypt.rewind();
164185
inCrypt.flip();
186+
outCrypt.rewind();
165187
outCrypt.flip();
166188
}
167189

@@ -170,7 +192,7 @@ public int write( ByteBuffer src ) throws IOException {
170192
processHandshake();
171193
return 0;
172194
}
173-
int num = sc.write( wrap( src ) );
195+
int num = socketChannel.write( wrap( src ) );
174196
return num;
175197

176198
}
@@ -203,8 +225,8 @@ public int read( ByteBuffer dst ) throws IOException {
203225
else
204226
inCrypt.compact();
205227

206-
if( ( isblocking && inCrypt.position() == 0 ) || status == Status.BUFFER_UNDERFLOW )
207-
if( sc.read( inCrypt ) == -1 ) {
228+
if( ( isBlocking() && inCrypt.position() == 0 ) || engineStatus == Status.BUFFER_UNDERFLOW )
229+
if( socketChannel.read( inCrypt ) == -1 ) {
208230
return -1;
209231
}
210232
inCrypt.flip();
@@ -230,36 +252,36 @@ private int readRemaining( ByteBuffer dst ) throws SSLException {
230252
}
231253

232254
public boolean isConnected() {
233-
return sc.isConnected();
255+
return socketChannel.isConnected();
234256
}
235257

236258
public void close() throws IOException {
237259
sslEngine.closeOutbound();
238260
sslEngine.getSession().invalidate();
239-
if( sc.isOpen() )
240-
sc.write( wrap( emptybuffer ) );// FIXME what if not all bytes can be written
241-
sc.close();
261+
if( socketChannel.isOpen() )
262+
socketChannel.write( wrap( emptybuffer ) );// FIXME what if not all bytes can be written
263+
socketChannel.close();
242264
}
243265

244266
private boolean isHandShakeComplete() {
245-
HandshakeStatus status = res.getHandshakeStatus();
267+
HandshakeStatus status = engineResult.getHandshakeStatus();
246268
return status == SSLEngineResult.HandshakeStatus.FINISHED || status == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
247269
}
248270

249271
public SelectableChannel configureBlocking( boolean b ) throws IOException {
250-
return sc.configureBlocking( b );
272+
return socketChannel.configureBlocking( b );
251273
}
252274

253275
public boolean connect( SocketAddress remote ) throws IOException {
254-
return sc.connect( remote );
276+
return socketChannel.connect( remote );
255277
}
256278

257279
public boolean finishConnect() throws IOException {
258-
return sc.finishConnect();
280+
return socketChannel.finishConnect();
259281
}
260282

261283
public Socket socket() {
262-
return sc.socket();
284+
return socketChannel.socket();
263285
}
264286

265287
public boolean isInboundDone() {
@@ -268,7 +290,7 @@ public boolean isInboundDone() {
268290

269291
@Override
270292
public boolean isOpen() {
271-
return sc.isOpen();
293+
return socketChannel.isOpen();
272294
}
273295

274296
@Override
@@ -283,7 +305,7 @@ public void writeMore() throws IOException {
283305

284306
@Override
285307
public boolean isNeedRead() {
286-
return inData.hasRemaining() || ( inCrypt.hasRemaining() && res.getStatus() != Status.BUFFER_UNDERFLOW );
308+
return inData.hasRemaining() || ( inCrypt.hasRemaining() && engineResult.getStatus() != Status.BUFFER_UNDERFLOW );
287309
}
288310

289311
@Override
@@ -310,7 +332,7 @@ private int transfereTo( ByteBuffer from, ByteBuffer to ) {
310332

311333
@Override
312334
public boolean isBlocking() {
313-
return isblocking;
335+
return socketChannel.isBlocking();
314336
}
315337

316338
}

src/main/java/org/java_websocket/SocketChannelIOHelper.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ public static boolean read( final ByteBuffer buf, WebSocketImpl ws, ByteChannel
1919
return read != 0;
2020
}
2121

22+
/**
23+
* @see WrappedByteChannel#readMore(ByteBuffer)
24+
* @return returns whether there is more data left which can be obtained via {@link #readMore(ByteBuffer, WebSocketImpl, WrappedByteChannel)}
25+
**/
2226
public static boolean readMore( final ByteBuffer buf, WebSocketImpl ws, WrappedByteChannel channel ) throws IOException {
2327
buf.clear();
2428
int read = channel.readMore( buf );

src/main/java/org/java_websocket/WrappedByteChannel.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,17 @@ public interface WrappedByteChannel extends ByteChannel {
1010
public boolean isNeedWrite();
1111
public void writeMore() throws IOException;
1212

13+
/**
14+
* returns whether readMore should be called to fetch data which has been decoded but not yet been returned.
15+
*
16+
* @see #read(ByteBuffer)
17+
* @see #readMore(ByteBuffer)
18+
**/
1319
public boolean isNeedRead();
20+
/**
21+
* This function does not read data from the underlying channel at all. It is just a way to fetch data which has already be received or decoded but was but was not yet returned to the user.
22+
* This could be the case when the decoded data did not fit into the buffer the user passed to {@link #read(ByteBuffer)}.
23+
**/
1424
public int readMore( ByteBuffer dst ) throws SSLException;
1525
public boolean isBlocking();
1626
}

0 commit comments

Comments
 (0)