Skip to content

Commit 6f26665

Browse files
authored
better ssl write handling (#6763)
* ssl_write * Fix thread count timing
1 parent e0479fe commit 6f26665

File tree

4 files changed

+356
-115
lines changed

4 files changed

+356
-115
lines changed

crates/stdlib/src/ssl.rs

Lines changed: 121 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ mod _ssl {
5050

5151
// Import error types used in this module (others are exposed via pymodule(with(...)))
5252
use super::error::{
53-
PySSLEOFError, PySSLError, create_ssl_want_read_error, create_ssl_want_write_error,
53+
PySSLError, create_ssl_eof_error, create_ssl_want_read_error, create_ssl_want_write_error,
5454
};
5555
use alloc::sync::Arc;
5656
use core::{
@@ -1903,6 +1903,7 @@ mod _ssl {
19031903
client_hello_buffer: PyMutex::new(None),
19041904
shutdown_state: PyMutex::new(ShutdownState::NotStarted),
19051905
pending_tls_output: PyMutex::new(Vec::new()),
1906+
write_buffered_len: PyMutex::new(0),
19061907
deferred_cert_error: Arc::new(ParkingRwLock::new(None)),
19071908
};
19081909

@@ -1974,6 +1975,7 @@ mod _ssl {
19741975
client_hello_buffer: PyMutex::new(None),
19751976
shutdown_state: PyMutex::new(ShutdownState::NotStarted),
19761977
pending_tls_output: PyMutex::new(Vec::new()),
1978+
write_buffered_len: PyMutex::new(0),
19771979
deferred_cert_error: Arc::new(ParkingRwLock::new(None)),
19781980
};
19791981

@@ -2345,6 +2347,10 @@ mod _ssl {
23452347
// but the socket cannot accept all the data immediately
23462348
#[pytraverse(skip)]
23472349
pub(crate) pending_tls_output: PyMutex<Vec<u8>>,
2350+
// Tracks bytes already buffered in rustls for the current write operation
2351+
// Prevents duplicate writes when retrying after WantWrite/WantRead
2352+
#[pytraverse(skip)]
2353+
pub(crate) write_buffered_len: PyMutex<usize>,
23482354
// Deferred client certificate verification error (for TLS 1.3)
23492355
// Stores error message if client cert verification failed during handshake
23502356
// Error is raised on first I/O operation after handshake
@@ -2604,6 +2610,36 @@ mod _ssl {
26042610
Ok(timed_out)
26052611
}
26062612

2613+
// Internal implementation with explicit timeout override
2614+
pub(crate) fn sock_wait_for_io_with_timeout(
2615+
&self,
2616+
kind: SelectKind,
2617+
timeout: Option<std::time::Duration>,
2618+
vm: &VirtualMachine,
2619+
) -> PyResult<bool> {
2620+
if self.is_bio_mode() {
2621+
// BIO mode doesn't use select
2622+
return Ok(false);
2623+
}
2624+
2625+
if let Some(t) = timeout
2626+
&& t.is_zero()
2627+
{
2628+
// Non-blocking mode - don't use select
2629+
return Ok(false);
2630+
}
2631+
2632+
let py_socket: PyRef<PySocket> = self.sock.clone().try_into_value(vm)?;
2633+
let socket = py_socket
2634+
.sock()
2635+
.map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?;
2636+
2637+
let timed_out = sock_select(&socket, kind, timeout)
2638+
.map_err(|e| vm.new_os_error(format!("select failed: {e}")))?;
2639+
2640+
Ok(timed_out)
2641+
}
2642+
26072643
// SNI (Server Name Indication) Helper Methods:
26082644
// These methods support the server-side handshake SNI callback mechanism
26092645

@@ -2783,6 +2819,7 @@ mod _ssl {
27832819
let is_non_blocking = socket_timeout.map(|t| t.is_zero()).unwrap_or(false);
27842820

27852821
let mut sent_total = 0;
2822+
27862823
while sent_total < pending.len() {
27872824
// Calculate timeout: use deadline if provided, otherwise use socket timeout
27882825
let timeout_to_use = if let Some(dl) = deadline {
@@ -2810,6 +2847,9 @@ mod _ssl {
28102847
if timed_out {
28112848
// Keep unsent data in pending buffer
28122849
*pending = pending[sent_total..].to_vec();
2850+
if is_non_blocking {
2851+
return Err(create_ssl_want_write_error(vm).upcast());
2852+
}
28132853
return Err(
28142854
timeout_error_msg(vm, "The write operation timed out".to_string()).upcast(),
28152855
);
@@ -2824,6 +2864,7 @@ mod _ssl {
28242864
*pending = pending[sent_total..].to_vec();
28252865
return Err(create_ssl_want_write_error(vm).upcast());
28262866
}
2867+
// Socket said ready but sent 0 bytes - retry
28272868
continue;
28282869
}
28292870
sent_total += sent;
@@ -2916,6 +2957,9 @@ mod _ssl {
29162957
pub(crate) fn blocking_flush_all_pending(&self, vm: &VirtualMachine) -> PyResult<()> {
29172958
// Get socket timeout to respect during flush
29182959
let timeout = self.get_socket_timeout(vm)?;
2960+
if timeout.map(|t| t.is_zero()).unwrap_or(false) {
2961+
return self.flush_pending_tls_output(vm, None);
2962+
}
29192963

29202964
loop {
29212965
let pending_data = {
@@ -2948,8 +2992,7 @@ mod _ssl {
29482992
let mut pending = self.pending_tls_output.lock();
29492993
pending.drain(..sent);
29502994
}
2951-
// If sent == 0, socket wasn't ready despite select() saying so
2952-
// Continue loop to retry - this avoids infinite loops
2995+
// If sent == 0, loop will retry with sock_select
29532996
}
29542997
Err(e) => {
29552998
if is_blocking_io_error(&e, vm) {
@@ -3515,16 +3558,60 @@ mod _ssl {
35153558
return_data(buf, &buffer, vm)
35163559
}
35173560
Err(crate::ssl::compat::SslError::Eof) => {
3561+
// If plaintext is still buffered, return it before EOF.
3562+
let pending = {
3563+
let mut conn_guard = self.connection.lock();
3564+
let conn = match conn_guard.as_mut() {
3565+
Some(conn) => conn,
3566+
None => return Err(create_ssl_eof_error(vm).upcast()),
3567+
};
3568+
use std::io::BufRead;
3569+
let mut reader = conn.reader();
3570+
reader.fill_buf().map(|buf| buf.len()).unwrap_or(0)
3571+
};
3572+
if pending > 0 {
3573+
let mut buf = vec![0u8; pending.min(len)];
3574+
let read_retry = {
3575+
let mut conn_guard = self.connection.lock();
3576+
let conn = conn_guard
3577+
.as_mut()
3578+
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
3579+
crate::ssl::compat::ssl_read(conn, &mut buf, self, vm)
3580+
};
3581+
if let Ok(n) = read_retry {
3582+
buf.truncate(n);
3583+
return return_data(buf, &buffer, vm);
3584+
}
3585+
}
35183586
// EOF occurred in violation of protocol (unexpected closure)
3519-
Err(vm
3520-
.new_os_subtype_error(
3521-
PySSLEOFError::class(&vm.ctx).to_owned(),
3522-
None,
3523-
"EOF occurred in violation of protocol",
3524-
)
3525-
.upcast())
3587+
Err(create_ssl_eof_error(vm).upcast())
35263588
}
35273589
Err(crate::ssl::compat::SslError::ZeroReturn) => {
3590+
// If plaintext is still buffered, return it before clean EOF.
3591+
let pending = {
3592+
let mut conn_guard = self.connection.lock();
3593+
let conn = match conn_guard.as_mut() {
3594+
Some(conn) => conn,
3595+
None => return return_data(vec![], &buffer, vm),
3596+
};
3597+
use std::io::BufRead;
3598+
let mut reader = conn.reader();
3599+
reader.fill_buf().map(|buf| buf.len()).unwrap_or(0)
3600+
};
3601+
if pending > 0 {
3602+
let mut buf = vec![0u8; pending.min(len)];
3603+
let read_retry = {
3604+
let mut conn_guard = self.connection.lock();
3605+
let conn = conn_guard
3606+
.as_mut()
3607+
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
3608+
crate::ssl::compat::ssl_read(conn, &mut buf, self, vm)
3609+
};
3610+
if let Ok(n) = read_retry {
3611+
buf.truncate(n);
3612+
return return_data(buf, &buffer, vm);
3613+
}
3614+
}
35283615
// Clean closure with close_notify - return empty data
35293616
return_data(vec![], &buffer, vm)
35303617
}
@@ -3580,21 +3667,17 @@ mod _ssl {
35803667
let data_bytes = data.borrow_buf();
35813668
let data_len = data_bytes.len();
35823669

3583-
// return 0 immediately for empty write
35843670
if data_len == 0 {
35853671
return Ok(0);
35863672
}
35873673

3588-
// Ensure handshake is done - if not, complete it first
3589-
// This matches OpenSSL behavior where SSL_write() auto-completes handshake
3674+
// Ensure handshake is done (SSL_write auto-completes handshake)
35903675
if !*self.handshake_done.lock() {
35913676
self.do_handshake(vm)?;
35923677
}
35933678

3594-
// Check if connection has been shut down
3595-
// After unwrap()/shutdown(), write operations should fail with SSLError
3596-
let shutdown_state = *self.shutdown_state.lock();
3597-
if shutdown_state != ShutdownState::NotStarted {
3679+
// Check shutdown state
3680+
if *self.shutdown_state.lock() != ShutdownState::NotStarted {
35983681
return Err(vm
35993682
.new_os_subtype_error(
36003683
PySSLError::class(&vm.ctx).to_owned(),
@@ -3604,76 +3687,32 @@ mod _ssl {
36043687
.upcast());
36053688
}
36063689

3607-
{
3690+
// Call ssl_write (matches CPython's SSL_write_ex loop)
3691+
let result = {
36083692
let mut conn_guard = self.connection.lock();
36093693
let conn = conn_guard
36103694
.as_mut()
36113695
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
36123696

3613-
let is_bio = self.is_bio_mode();
3614-
let data: &[u8] = data_bytes.as_ref();
3697+
crate::ssl::compat::ssl_write(conn, data_bytes.as_ref(), self, vm)
3698+
};
36153699

3616-
// CRITICAL: Flush any pending TLS data before writing new data
3617-
// This ensures TLS 1.3 Finished message reaches server before application data
3618-
// Without this, server may not be ready to process our data
3619-
if !is_bio {
3620-
self.flush_pending_tls_output(vm, None)?;
3700+
match result {
3701+
Ok(n) => {
3702+
self.check_deferred_cert_error(vm)?;
3703+
Ok(n)
36213704
}
3622-
3623-
// Write data in chunks to avoid filling the internal TLS buffer
3624-
// rustls has a limited internal buffer, so we need to flush periodically
3625-
const CHUNK_SIZE: usize = 16384; // 16KB chunks (typical TLS record size)
3626-
let mut written = 0;
3627-
3628-
while written < data.len() {
3629-
let chunk_end = core::cmp::min(written + CHUNK_SIZE, data.len());
3630-
let chunk = &data[written..chunk_end];
3631-
3632-
// Write chunk to TLS layer
3633-
{
3634-
let mut writer = conn.writer();
3635-
use std::io::Write;
3636-
writer
3637-
.write_all(chunk)
3638-
.map_err(|e| vm.new_os_error(format!("Write failed: {e}")))?;
3639-
// Flush to ensure data is converted to TLS records
3640-
writer
3641-
.flush()
3642-
.map_err(|e| vm.new_os_error(format!("Flush failed: {e}")))?;
3643-
}
3644-
3645-
written = chunk_end;
3646-
3647-
// Flush TLS data to socket after each chunk
3648-
if conn.wants_write() {
3649-
if is_bio {
3650-
self.write_pending_tls(conn, vm)?;
3651-
} else {
3652-
// Socket mode: flush all pending TLS data
3653-
// First, try to send any previously pending data
3654-
self.flush_pending_tls_output(vm, None)?;
3655-
3656-
while conn.wants_write() {
3657-
let mut buf = Vec::new();
3658-
conn.write_tls(&mut buf).map_err(|e| {
3659-
vm.new_os_error(format!("TLS write failed: {e}"))
3660-
})?;
3661-
3662-
if !buf.is_empty() {
3663-
// Try to send TLS data, saving unsent bytes to pending buffer
3664-
self.send_tls_output(buf, vm)?;
3665-
}
3666-
}
3667-
}
3668-
}
3705+
Err(crate::ssl::compat::SslError::WantRead) => {
3706+
Err(create_ssl_want_read_error(vm).upcast())
3707+
}
3708+
Err(crate::ssl::compat::SslError::WantWrite) => {
3709+
Err(create_ssl_want_write_error(vm).upcast())
3710+
}
3711+
Err(crate::ssl::compat::SslError::Timeout(msg)) => {
3712+
Err(timeout_error_msg(vm, msg).upcast())
36693713
}
3714+
Err(e) => Err(e.into_py_err(vm)),
36703715
}
3671-
3672-
// Check for deferred certificate verification errors (TLS 1.3)
3673-
// Must be checked AFTER write completes, as the error may be set during I/O
3674-
self.check_deferred_cert_error(vm)?;
3675-
3676-
Ok(data_len)
36773716
}
36783717

36793718
#[pymethod]
@@ -4013,6 +4052,10 @@ mod _ssl {
40134052

40144053
// Write close_notify to outgoing buffer/BIO
40154054
self.write_pending_tls(conn, vm)?;
4055+
// Ensure close_notify and any pending TLS data are flushed
4056+
if !is_bio {
4057+
self.flush_pending_tls_output(vm, None)?;
4058+
}
40164059

40174060
// Update state
40184061
*self.shutdown_state.lock() = ShutdownState::SentCloseNotify;

0 commit comments

Comments
 (0)