Skip to content

Commit ae10dc1

Browse files
authored
Don't unwrap multiple records until we notified the caller about the finished handshake (#13314)
Motiviation: We should ensure we only unwrap data to the destination buffer once we did notify the user about the completion of the handshake. Failing to do so might result in corrupted state machines. Modifications: Always use the jdkCompat mode until the handshake was complete Result: Correct state machine.
1 parent 94ab6f3 commit ae10dc1

3 files changed

Lines changed: 187 additions & 2 deletions

File tree

handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,8 @@ public final SSLEngineResult wrap(
803803

804804
// Flush any data that may be implicitly generated by OpenSSL (handshake, close, etc..).
805805
SSLEngineResult.HandshakeStatus status = NOT_HANDSHAKING;
806+
HandshakeState oldHandshakeState = handshakeState;
807+
806808
// Prepare OpenSSL to work in server mode and receive handshake
807809
if (handshakeState != HandshakeState.FINISHED) {
808810
if (handshakeState != HandshakeState.STARTED_EXPLICITLY) {
@@ -868,7 +870,11 @@ public final SSLEngineResult wrap(
868870
}
869871

870872
final int endOffset = offset + length;
871-
if (jdkCompatibilityMode) {
873+
if (jdkCompatibilityMode ||
874+
// If the handshake was not finished before we entered the method, we also ensure we only
875+
// wrap one record. We do this to ensure we not produce any extra data before the caller
876+
// of the method is able to observe handshake completion and react on it.
877+
oldHandshakeState != HandshakeState.FINISHED) {
872878
int srcsLen = 0;
873879
for (int i = offset; i < endOffset; ++i) {
874880
final ByteBuffer src = srcs[i];
@@ -1143,6 +1149,7 @@ public final SSLEngineResult unwrap(
11431149
}
11441150

11451151
SSLEngineResult.HandshakeStatus status = NOT_HANDSHAKING;
1152+
HandshakeState oldHandshakeState = handshakeState;
11461153
// Prepare OpenSSL to work in server mode and receive handshake
11471154
if (handshakeState != HandshakeState.FINISHED) {
11481155
if (handshakeState != HandshakeState.STARTED_EXPLICITLY) {
@@ -1171,7 +1178,11 @@ public final SSLEngineResult unwrap(
11711178
// JDK compatibility mode then we should honor this, but if not we just wrap as much as possible. If there
11721179
// are multiple records or partial records this may reduce thrashing events through the pipeline.
11731180
// [1] https://docs.oracle.com/javase/7/docs/api/javax/net/ssl/SSLEngine.html
1174-
if (jdkCompatibilityMode) {
1181+
if (jdkCompatibilityMode ||
1182+
// If the handshake was not finished before we entered the method, we also ensure we only
1183+
// unwrap one record. We do this to ensure we not produce any extra data before the caller
1184+
// of the method is able to observe handshake completion and react on it.
1185+
oldHandshakeState != HandshakeState.FINISHED) {
11751186
if (len < SSL_RECORD_HEADER_LENGTH) {
11761187
return newResultMayFinishHandshake(BUFFER_UNDERFLOW, status, 0, 0);
11771188
}

handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import io.netty.util.internal.PlatformDependent;
2828
import org.junit.jupiter.api.AfterEach;
2929
import org.junit.jupiter.api.BeforeAll;
30+
import org.junit.jupiter.api.Test;
3031
import org.junit.jupiter.api.function.Executable;
3132
import org.junit.jupiter.params.ParameterizedTest;
3233
import org.junit.jupiter.params.provider.MethodSource;
@@ -1577,4 +1578,27 @@ public void testRSASSAPSS(SSLEngineTestParam param) throws Exception {
15771578
checkShouldUseKeyManagerFactory();
15781579
super.testRSASSAPSS(param);
15791580
}
1581+
1582+
@Test
1583+
public void testExtraDataInLastSrcBufferForClientUnwrapNonjdkCompatabilityMode() throws Exception {
1584+
SSLEngineTestParam param = new SSLEngineTestParam(BufferType.Direct, ProtocolCipherCombo.tlsv12(), false);
1585+
SelfSignedCertificate ssc = new SelfSignedCertificate();
1586+
clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
1587+
.trustManager(InsecureTrustManagerFactory.INSTANCE)
1588+
.sslProvider(sslClientProvider())
1589+
.sslContextProvider(clientSslContextProvider())
1590+
.protocols(param.protocols())
1591+
.ciphers(param.ciphers())
1592+
.build());
1593+
serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
1594+
.sslProvider(sslServerProvider())
1595+
.sslContextProvider(serverSslContextProvider())
1596+
.protocols(param.protocols())
1597+
.ciphers(param.ciphers())
1598+
.clientAuth(ClientAuth.NONE)
1599+
.build());
1600+
testExtraDataInLastSrcBufferForClientUnwrap(param,
1601+
wrapEngine(clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()),
1602+
wrapEngine(serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()));
1603+
}
15801604
}

handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4324,6 +4324,156 @@ public void testBufferUnderflowPacketSizeDependency(SSLEngineTestParam param) th
43244324
}
43254325
}
43264326

4327+
@Test
4328+
public void testExtraDataInLastSrcBufferForClientUnwrap() throws Exception {
4329+
SSLEngineTestParam param = new SSLEngineTestParam(BufferType.Direct, ProtocolCipherCombo.tlsv12(), false);
4330+
SelfSignedCertificate ssc = new SelfSignedCertificate();
4331+
clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
4332+
.trustManager(InsecureTrustManagerFactory.INSTANCE)
4333+
.sslProvider(sslClientProvider())
4334+
.sslContextProvider(clientSslContextProvider())
4335+
.protocols(param.protocols())
4336+
.ciphers(param.ciphers())
4337+
.build());
4338+
serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
4339+
.sslProvider(sslServerProvider())
4340+
.sslContextProvider(serverSslContextProvider())
4341+
.protocols(param.protocols())
4342+
.ciphers(param.ciphers())
4343+
.clientAuth(ClientAuth.NONE)
4344+
.build());
4345+
testExtraDataInLastSrcBufferForClientUnwrap(param,
4346+
wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)),
4347+
wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)));
4348+
}
4349+
4350+
protected final void testExtraDataInLastSrcBufferForClientUnwrap(
4351+
SSLEngineTestParam param, SSLEngine clientEngine, SSLEngine serverEngine) throws Exception {
4352+
try {
4353+
ByteBuffer cTOs = allocateBuffer(param.type(), clientEngine.getSession().getPacketBufferSize());
4354+
// Ensure we can fit two records as we want to include two records once the handshake completes on the
4355+
// server side.
4356+
ByteBuffer sTOc = allocateBuffer(param.type(), serverEngine.getSession().getPacketBufferSize() * 2);
4357+
4358+
ByteBuffer serverAppReadBuffer =
4359+
allocateBuffer(param.type(), serverEngine.getSession().getApplicationBufferSize());
4360+
ByteBuffer clientAppReadBuffer =
4361+
allocateBuffer(param.type(), clientEngine.getSession().getApplicationBufferSize());
4362+
4363+
ByteBuffer empty = allocateBuffer(param.type(), 0);
4364+
4365+
SSLEngineResult clientResult;
4366+
SSLEngineResult serverResult;
4367+
4368+
boolean clientHandshakeFinished = false;
4369+
boolean serverHandshakeFinished = false;
4370+
4371+
do {
4372+
int cTOsPos = cTOs.position();
4373+
int sTOcPos = sTOc.position();
4374+
4375+
if (!clientHandshakeFinished) {
4376+
clientResult = clientEngine.wrap(empty, cTOs);
4377+
runDelegatedTasks(param.delegate(), clientResult, clientEngine);
4378+
assertEquals(empty.remaining(), clientResult.bytesConsumed());
4379+
assertEquals(cTOs.position() - cTOsPos, clientResult.bytesProduced());
4380+
4381+
if (isHandshakeFinished(clientResult)) {
4382+
clientHandshakeFinished = true;
4383+
}
4384+
4385+
if (clientResult.getStatus() == Status.BUFFER_OVERFLOW) {
4386+
cTOs = increaseDstBuffer(clientEngine.getSession().getPacketBufferSize(), param.type(), cTOs);
4387+
}
4388+
}
4389+
4390+
if (!serverHandshakeFinished) {
4391+
serverResult = serverEngine.wrap(empty, sTOc);
4392+
runDelegatedTasks(param.delegate(), serverResult, serverEngine);
4393+
assertEquals(empty.remaining(), serverResult.bytesConsumed());
4394+
assertEquals(sTOc.position() - sTOcPos, serverResult.bytesProduced());
4395+
4396+
if (isHandshakeFinished(serverResult)) {
4397+
serverHandshakeFinished = true;
4398+
// We finished the handshake on the server side, lets add another record to the sTOc buffer
4399+
// so we can test that we will not unwrap extra data before we actually consider the handshake
4400+
// complete on the client side as well.
4401+
serverResult = serverEngine.wrap(ByteBuffer.wrap(new byte[8]), sTOc);
4402+
assertEquals(8, serverResult.bytesConsumed());
4403+
}
4404+
4405+
if (serverResult.getStatus() == Status.BUFFER_OVERFLOW) {
4406+
sTOc = increaseDstBuffer(serverEngine.getSession().getPacketBufferSize(), param.type(), sTOc);
4407+
}
4408+
}
4409+
4410+
cTOs.flip();
4411+
sTOc.flip();
4412+
4413+
cTOsPos = cTOs.position();
4414+
sTOcPos = sTOc.position();
4415+
4416+
if (!clientHandshakeFinished) {
4417+
int clientAppReadBufferPos = clientAppReadBuffer.position();
4418+
clientResult = clientEngine.unwrap(sTOc, clientAppReadBuffer);
4419+
4420+
runDelegatedTasks(param.delegate(), clientResult, clientEngine);
4421+
assertEquals(sTOc.position() - sTOcPos, clientResult.bytesConsumed());
4422+
assertEquals(clientAppReadBuffer.position() - clientAppReadBufferPos, clientResult.bytesProduced());
4423+
assertEquals(0, clientAppReadBuffer.position());
4424+
4425+
if (isHandshakeFinished(clientResult)) {
4426+
clientHandshakeFinished = true;
4427+
} else {
4428+
assertEquals(0, clientAppReadBuffer.position() - clientAppReadBufferPos);
4429+
}
4430+
4431+
if (clientResult.getStatus() == Status.BUFFER_OVERFLOW) {
4432+
clientAppReadBuffer = increaseDstBuffer(
4433+
clientEngine.getSession().getApplicationBufferSize(),
4434+
param.type(), clientAppReadBuffer);
4435+
}
4436+
}
4437+
4438+
if (!serverHandshakeFinished) {
4439+
int serverAppReadBufferPos = serverAppReadBuffer.position();
4440+
serverResult = serverEngine.unwrap(cTOs, serverAppReadBuffer);
4441+
runDelegatedTasks(param.delegate(), serverResult, serverEngine);
4442+
assertEquals(cTOs.position() - cTOsPos, serverResult.bytesConsumed());
4443+
assertEquals(serverAppReadBuffer.position() - serverAppReadBufferPos, serverResult.bytesProduced());
4444+
assertEquals(0, serverAppReadBuffer.position());
4445+
4446+
if (isHandshakeFinished(serverResult)) {
4447+
serverHandshakeFinished = true;
4448+
}
4449+
4450+
if (serverResult.getStatus() == Status.BUFFER_OVERFLOW) {
4451+
serverAppReadBuffer = increaseDstBuffer(
4452+
serverEngine.getSession().getApplicationBufferSize(),
4453+
param.type(), serverAppReadBuffer);
4454+
}
4455+
}
4456+
4457+
compactOrClear(cTOs);
4458+
compactOrClear(sTOc);
4459+
4460+
serverAppReadBuffer.clear();
4461+
clientAppReadBuffer.clear();
4462+
4463+
if (clientEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
4464+
clientHandshakeFinished = true;
4465+
}
4466+
4467+
if (serverEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
4468+
serverHandshakeFinished = true;
4469+
}
4470+
} while (!clientHandshakeFinished || !serverHandshakeFinished);
4471+
} finally {
4472+
cleanupClientSslEngine(clientEngine);
4473+
cleanupServerSslEngine(serverEngine);
4474+
}
4475+
}
4476+
43274477
protected SSLEngine wrapEngine(SSLEngine engine) {
43284478
return engine;
43294479
}

0 commit comments

Comments
 (0)