From ad3060ca823fd78081ef8249e6e492d77336db23 Mon Sep 17 00:00:00 2001 From: Ivan Mironov Date: Tue, 19 May 2026 01:40:56 +0200 Subject: [PATCH 01/11] Do not call `import socket` on each send()/recv() when using rustls Use method references cached during socket creation. --- crates/stdlib/src/ssl.rs | 53 +++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index 6e06e4e9efb..ddff6a14d39 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -1825,6 +1825,9 @@ mod _ssl { args: WrapSocketArgs, vm: &VirtualMachine, ) -> PyResult> { + let socket_mod = vm.import("socket", 0)?; + let socket_class = socket_mod.get_attr("socket", vm)?; + // Convert server_hostname to Option // Handle both missing argument and None value let hostname = match args.server_hostname.into_option().flatten() { @@ -1877,6 +1880,8 @@ mod _ssl { // Create _SSLSocket instance let ssl_socket = PySSLSocket { sock: args.sock.clone(), + sock_send_method: socket_class.get_attr("send", vm)?, + sock_recv_method: socket_class.get_attr("recv", vm)?, context: PyRwLock::new(zelf), server_side: args.server_side, server_hostname: PyRwLock::new(hostname), @@ -1948,7 +1953,11 @@ mod _ssl { // Create _SSLSocket instance with BIO mode let ssl_socket = PySSLSocket { - sock: vm.ctx.none(), // No socket in BIO mode + // No socket in BIO mode + sock: vm.ctx.none(), + sock_send_method: vm.ctx.none(), + sock_recv_method: vm.ctx.none(), + context: PyRwLock::new(zelf), server_side, server_hostname: PyRwLock::new(hostname), @@ -2302,6 +2311,12 @@ mod _ssl { pub(crate) struct PySSLSocket { // Underlying socket sock: PyObjectRef, + // Cached socket.socket.send + #[pytraverse(skip)] + sock_send_method: PyObjectRef, + // Cached socket.socket.recv + #[pytraverse(skip)] + sock_recv_method: PyObjectRef, // SSL context context: PyRwLock>, // Server-side or client-side @@ -2771,23 +2786,26 @@ mod _ssl { return 
read_method.call((vm.ctx.new_int(size),), vm); } - // Normal socket mode - let socket_mod = vm.import("socket", 0)?; - let socket_class = socket_mod.get_attr("socket", vm)?; - - // Call socket.socket.recv(self.sock, size, flags) - let recv_method = socket_class.get_attr("recv", vm)?; - recv_method.call((self.sock.clone(), vm.ctx.new_int(size)), vm) + self.sock_recv_method + .call((self.sock.clone(), vm.ctx.new_int(size)), vm) } /// Peek at socket data without consuming it (MSG_PEEK). /// Used during TLS shutdown to avoid consuming post-TLS cleartext data. pub(crate) fn sock_peek(&self, size: usize, vm: &VirtualMachine) -> PyResult { - let socket_mod = vm.import("socket", 0)?; - let socket_class = socket_mod.get_attr("socket", vm)?; - let recv_method = socket_class.get_attr("recv", vm)?; - let msg_peek = socket_mod.get_attr("MSG_PEEK", vm)?; - recv_method.call((self.sock.clone(), vm.ctx.new_int(size), msg_peek), vm) + #[cfg(not(windows))] + use libc::MSG_PEEK; + #[cfg(windows)] + use windows_sys::Win32::Networking::WinSock::MSG_PEEK; + + self.sock_recv_method.call( + ( + self.sock.clone(), + vm.ctx.new_int(size), + vm.new_pyobj(MSG_PEEK), + ), + vm, + ) } /// Socket send - just sends data, caller must handle pending flush @@ -2800,13 +2818,8 @@ mod _ssl { return write_method.call((vm.ctx.new_bytes(data.to_vec()),), vm); } - // Normal socket mode - let socket_mod = vm.import("socket", 0)?; - let socket_class = socket_mod.get_attr("socket", vm)?; - - // Call socket.socket.send(self.sock, data) - let send_method = socket_class.get_attr("send", vm)?; - send_method.call((self.sock.clone(), vm.ctx.new_bytes(data.to_vec())), vm) + self.sock_send_method + .call((self.sock.clone(), vm.ctx.new_bytes(data.to_vec())), vm) } /// Flush any pending TLS output data to the socket From 7fa757dec36bd7a780562ea1316332eb54416942 Mon Sep 17 00:00:00 2001 From: Ivan Mironov Date: Tue, 19 May 2026 01:51:20 +0200 Subject: [PATCH 02/11] Implement reading of at most one TLS record from socket 
Previous algorithm didn't take into account that recv() may return less data than requested even for blocking sockets. --- crates/stdlib/src/ssl.rs | 182 +++++++++++++++++++------------- crates/stdlib/src/ssl/compat.rs | 108 ++++++------------- 2 files changed, 140 insertions(+), 150 deletions(-) diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index ddff6a14d39..eff3ec8c421 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -41,8 +41,8 @@ mod _ssl { AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::{ - PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, PyTypeRef, - PyUtf8StrRef, + PyBaseExceptionRef, PyByteArray, PyBytesRef, PyListRef, PyStrRef, PyType, + PyTypeRef, PyUtf8StrRef, }, convert::IntoPyException, function::{ @@ -1882,6 +1882,10 @@ mod _ssl { sock: args.sock.clone(), sock_send_method: socket_class.get_attr("send", vm)?, sock_recv_method: socket_class.get_attr("recv", vm)?, + tls_record_header_buf: vm + .ctx + .new_bytearray(Vec::with_capacity(TLS_RECORD_HEADER_SIZE)) + .into(), context: PyRwLock::new(zelf), server_side: args.server_side, server_hostname: PyRwLock::new(hostname), @@ -1897,6 +1901,7 @@ mod _ssl { sni_state: PyRwLock::new(None), pending_context: PyRwLock::new(None), client_hello_buffer: PyMutex::new(None), + sni_callback_processed: PyMutex::new(false), shutdown_state: PyMutex::new(ShutdownState::NotStarted), pending_tls_output: PyMutex::new(Vec::new()), write_buffered_len: PyMutex::new(0), @@ -1958,6 +1963,7 @@ mod _ssl { sock_send_method: vm.ctx.none(), sock_recv_method: vm.ctx.none(), + tls_record_header_buf: vm.ctx.none(), context: PyRwLock::new(zelf), server_side, server_hostname: PyRwLock::new(hostname), @@ -1973,6 +1979,7 @@ mod _ssl { sni_state: PyRwLock::new(None), pending_context: PyRwLock::new(None), client_hello_buffer: PyMutex::new(None), + sni_callback_processed: PyMutex::new(false), shutdown_state: 
PyMutex::new(ShutdownState::NotStarted), pending_tls_output: PyMutex::new(Vec::new()), write_buffered_len: PyMutex::new(0), @@ -2317,6 +2324,9 @@ mod _ssl { // Cached socket.socket.recv #[pytraverse(skip)] sock_recv_method: PyObjectRef, + // Header of currently read TLS record. + #[pytraverse(skip)] + tls_record_header_buf: PyObjectRef, // SSL context context: PyRwLock>, // Server-side or client-side @@ -2353,6 +2363,9 @@ mod _ssl { // Buffer to store ClientHello for connection recreation #[pytraverse(skip)] client_hello_buffer: PyMutex>>, + // Whether the Python SNI callback has already been run for this handshake + #[pytraverse(skip)] + sni_callback_processed: PyMutex, // Shutdown state for tracking close-notify exchange #[pytraverse(skip)] shutdown_state: PyMutex, @@ -2382,6 +2395,9 @@ mod _ssl { Completed, // unwrap() completed successfully } + /// TLS record header size (content_type + version + length). + const TLS_RECORD_HEADER_SIZE: usize = 5; + #[pyclass(with(Constructor, Representable), flags(BASETYPE))] impl PySSLSocket { // Check if this is BIO mode @@ -2655,9 +2671,9 @@ mod _ssl { // These methods support the server-side handshake SNI callback mechanism /// Check if this is the first read during handshake (for SNI callback) - /// Returns true if we haven't processed ClientHello yet, regardless of SNI presence + /// Returns true until the SNI callback has been processed. pub(crate) fn is_first_sni_read(&self) -> bool { - self.client_hello_buffer.lock().is_none() + !*self.sni_callback_processed.lock() } /// Check if SNI callback is configured @@ -2666,9 +2682,13 @@ mod _ssl { self.context.read().sni_callback.read().is_some() } - /// Save ClientHello data from PyObjectRef for potential connection recreation + /// Save ClientHello data for potential connection recreation. 
pub(crate) fn save_client_hello_from_bytes(&self, bytes_data: &[u8]) { - *self.client_hello_buffer.lock() = Some(bytes_data.to_vec()); + let mut buffer = self.client_hello_buffer.lock(); + match buffer.as_mut() { + Some(existing) => existing.extend_from_slice(bytes_data), + None => *buffer = Some(bytes_data.to_vec()), + } } /// Get the extracted SNI name from resolver @@ -2790,22 +2810,81 @@ mod _ssl { .call((self.sock.clone(), vm.ctx.new_int(size)), vm) } - /// Peek at socket data without consuming it (MSG_PEEK). - /// Used during TLS shutdown to avoid consuming post-TLS cleartext data. - pub(crate) fn sock_peek(&self, size: usize, vm: &VirtualMachine) -> PyResult { - #[cfg(not(windows))] - use libc::MSG_PEEK; - #[cfg(windows)] - use windows_sys::Win32::Networking::WinSock::MSG_PEEK; + // Helper to receive data for at most one TLS record. + // May return incomplete data but never returns more when completes a + // previously incomplete TLS record. + pub(crate) fn sock_recv_at_most_one_tls_record( + &self, + vm: &VirtualMachine, + ) -> PyResult { + let obj_to_bytes = |bytes_obj| { + PyBytesRef::try_from_object(vm, bytes_obj) + .map_err(|_| vm.new_os_error("Expected bytes from recv".to_string())) + }; - self.sock_recv_method.call( - ( - self.sock.clone(), - vm.ctx.new_int(size), - vm.new_pyobj(MSG_PEEK), - ), - vm, - ) + let tls_record_header_buf = self + .tls_record_header_buf + .clone() + .downcast::() + .expect("BUG: tls_record_header_buf is not PyByteArray"); + + let buf_len = tls_record_header_buf.borrow_buf().len(); + + let (mut with_header, mut remaining_record_body_len) = + if buf_len < TLS_RECORD_HEADER_SIZE { + // We do not have a full TLS record header, start receiving one. 
+ let bytes_obj = self.sock_recv(TLS_RECORD_HEADER_SIZE - buf_len, vm)?; + let bytes = obj_to_bytes(bytes_obj)?; + + let mut buf = tls_record_header_buf.borrow_buf_mut(); + buf.extend_from_slice(bytes.as_bytes()); + + if buf.len() < TLS_RECORD_HEADER_SIZE { + return Ok(bytes); + } + + // Parse the remaining length. + let record_body_len = u16::from_be_bytes([buf[3], buf[4]]); + // Validity of length value will be checked by rustls. + + // Zero-length TLS record. + if record_body_len == 0 { + buf.clear(); + return Ok(bytes); + } + + let mut bytes_vec = bytes.as_bytes().to_vec(); + bytes_vec.reserve(record_body_len as usize); + (Some(bytes_vec), record_body_len) + } else { + let buf = tls_record_header_buf.borrow_buf(); + let remaining_record_body_len = u16::from_be_bytes([buf[3], buf[4]]); + (None, remaining_record_body_len) + }; + + // We have full record header and are in a process of receiving a record. + let bytes_obj = self.sock_recv(remaining_record_body_len as usize, vm)?; + let bytes = obj_to_bytes(bytes_obj)?; + + if let Some(with_header) = with_header.as_mut() { + with_header.extend_from_slice(bytes.as_bytes()); + } + + let mut buf = tls_record_header_buf.borrow_buf_mut(); + remaining_record_body_len -= bytes.len() as u16; + if remaining_record_body_len == 0 { + // Record received completely, need to start a new one beginning with its header. + buf.clear(); + } else { + // Update remaining length in the header. 
+ buf.as_mut_slice()[3..5].copy_from_slice(&remaining_record_body_len.to_be_bytes()); + } + + if let Some(with_header) = with_header { + Ok(vm.ctx.new_bytes(with_header)) + } else { + Ok(bytes) + } } /// Socket send - just sends data, caller must handle pending flush @@ -3455,6 +3534,7 @@ mod _ssl { // Now safe to call Python callback (no locks held) self.invoke_sni_callback(sni_name.as_deref(), vm)?; + *self.sni_callback_processed.lock() = true; // Clear connection to trigger recreation *self.connection.lock() = None; @@ -4293,7 +4373,7 @@ mod _ssl { // transitions to cleartext. Without peeking, sock_recv may consume // cleartext data meant for the application after unwrap(). if self.incoming_bio.is_none() { - return self.try_read_close_notify_socket(conn, vm); + return Ok(self.try_read_close_notify_socket(conn, vm)); } // BIO mode: read from incoming BIO @@ -4340,67 +4420,25 @@ mod _ssl { &self, conn: &mut TlsConnection, vm: &VirtualMachine, - ) -> PyResult { - // Peek at the first 5 bytes (TLS record header size) - let peeked_obj = match self.sock_peek(5, vm) { - Ok(obj) => obj, - Err(e) => { - if is_blocking_io_error(&e, vm) { - return Ok(false); - } - return Ok(true); - } - }; - - let peeked = ArgBytesLike::try_from_object(vm, peeked_obj)?; - let peek_data = peeked.borrow_buf(); - - if peek_data.is_empty() { - return Ok(true); // EOF - } - - // TLS record content types: ChangeCipherSpec(20), Alert(21), - // Handshake(22), ApplicationData(23) - let content_type = peek_data[0]; - if !(20..=23).contains(&content_type) { - // Not a TLS record - post-TLS cleartext data. - // Peer has completed TLS shutdown; don't consume this data. 
- return Ok(true); - } - - // Determine how many bytes to read for exactly one TLS record - let recv_size = if peek_data.len() >= 5 { - let record_length = u16::from_be_bytes([peek_data[3], peek_data[4]]) as usize; - 5 + record_length - } else { - // Partial header available - read just these bytes for now - peek_data.len() - }; - - drop(peek_data); - drop(peeked); - - // Now consume exactly one TLS record from the socket - match self.sock_recv(recv_size, vm) { - Ok(bytes_obj) => { - let bytes = ArgBytesLike::try_from_object(vm, bytes_obj)?; - let data = bytes.borrow_buf(); - + ) -> bool { + // Consume at most one TLS record from the socket + match self.sock_recv_at_most_one_tls_record(vm) { + Ok(data) => { if data.is_empty() { - return Ok(true); + return true; } let data_slice: &[u8] = data.as_ref(); let mut cursor = std::io::Cursor::new(data_slice); let _ = conn.read_tls(&mut cursor); let _ = conn.process_new_packets(); - Ok(false) + false } Err(e) => { if is_blocking_io_error(&e, vm) { - return Ok(false); + return false; } - Ok(true) + true } } } diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index ed3880940b9..62efe2c653a 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -1163,14 +1163,10 @@ fn handshake_write_loop( Ok(made_progress) } -/// Read TLS handshake data from socket/BIO +/// Read at most one TLS record from the TCP socket. /// -/// Waits for and reads TLS records from the peer, handling SNI callback setup. -/// Returns (made_progress, is_first_sni_read). -/// TLS record header size (content_type + version + length). -const TLS_RECORD_HEADER_SIZE: usize = 5; - -/// Read exactly one TLS record from the TCP socket. +/// May return incomplete data but never returns more when completes a +/// previously incomplete TLS record. /// /// OpenSSL reads one TLS record at a time (no read-ahead by default). /// Rustls, however, consumes all available TCP data when fed via read_tls(). 
@@ -1183,77 +1179,32 @@ const TLS_RECORD_HEADER_SIZE: usize = 5; /// Fix: peek at the TCP buffer to find the first complete TLS record boundary /// and recv() only that many bytes. Any remaining data stays in the kernel /// buffer and remains visible to select(). -fn recv_one_tls_record(socket: &PySSLSocket, vm: &VirtualMachine) -> SslResult { - // Peek at what is available without consuming it. - let peeked_obj = match socket.sock_peek(SSL3_RT_MAX_PLAIN_LENGTH, vm) { - Ok(d) => d, - Err(e) => { - if is_blocking_io_error(&e, vm) { - return Err(SslError::WantRead); - } - return Err(SslError::Py(e)); - } - }; - - let peeked = ArgBytesLike::try_from_object(vm, peeked_obj) - .map_err(|_| SslError::Syscall("Expected bytes-like object from peek".to_string()))?; - let peeked_bytes = peeked.borrow_buf(); - - if peeked_bytes.is_empty() { - // Empty peek means the peer has closed the TCP connection (FIN). - // Non-blocking sockets would have returned EAGAIN/EWOULDBLOCK - // (caught above as WantRead), so empty bytes here always means EOF. - return Err(SslError::Eof); - } - - if peeked_bytes.len() < TLS_RECORD_HEADER_SIZE { - // Not enough data for a TLS record header yet. - // Read all available bytes so rustls can buffer the partial header; - // this avoids busy-waiting because the kernel buffer is now empty - // and select() will only wake us when new data arrives. - return socket.sock_recv(peeked_bytes.len(), vm).map_err(|e| { - if is_blocking_io_error(&e, vm) { - SslError::WantRead - } else { - SslError::Py(e) - } - }); - } - - // Parse the TLS record length from the header. - let record_body_len = u16::from_be_bytes([peeked_bytes[3], peeked_bytes[4]]) as usize; - let total_record_size = TLS_RECORD_HEADER_SIZE + record_body_len; - - let recv_size = if peeked_bytes.len() >= total_record_size { - // Complete record available — consume exactly one record. 
- total_record_size - } else { - // Incomplete record — consume everything so the kernel buffer is - // drained and select() will block until more data arrives. - peeked_bytes.len() - }; - - // Must drop the borrow before calling sock_recv (which re-enters Python). - drop(peeked_bytes); - drop(peeked); - - socket.sock_recv(recv_size, vm).map_err(|e| { +fn recv_at_most_one_tls_record( + socket: &PySSLSocket, + vm: &VirtualMachine, +) -> SslResult { + let bytes = socket.sock_recv_at_most_one_tls_record(vm).map_err(|e| { if is_blocking_io_error(&e, vm) { SslError::WantRead } else { SslError::Py(e) } - }) + })?; + if bytes.is_empty() { + Err(SslError::Eof) + } else { + Ok(bytes.into()) + } } -/// Read a single TLS record for post-handshake I/O while preserving the +/// Read up to a single TLS record for post-handshake I/O while preserving the /// SSL-vs-socket error precedence from the old sock_recv() path. -fn recv_one_tls_record_for_data( +fn recv_at_most_one_tls_record_for_data( conn: &mut TlsConnection, socket: &PySSLSocket, vm: &VirtualMachine, ) -> SslResult { - match recv_one_tls_record(socket, vm) { + match recv_at_most_one_tls_record(socket, vm) { Ok(data) => Ok(data), Err(SslError::Eof) => { if let Err(rustls_err) = conn.process_new_packets() { @@ -1285,9 +1236,10 @@ fn handshake_read_data( return Ok((false, false)); } - // SERVER-SPECIFIC: Check if this is the first read (for SNI callback) - // Must check BEFORE reading data, so we can detect first time - let is_first_sni_read = is_server && socket.is_first_sni_read(); + // SERVER-SPECIFIC: Check if this is before the SNI callback. + // sock_recv() may return only part of a TLS record, so keep capturing + // ClientHello fragments until process_new_packets() has produced a response. + let is_first_sni_read = is_server && socket.has_sni_callback() && socket.is_first_sni_read(); // Wait for data in socket mode if !is_bio { @@ -1308,7 +1260,7 @@ fn handshake_read_data( // record. 
This matches OpenSSL's default (no read-ahead) behaviour // and keeps remaining data in the kernel buffer where select() can // detect it. - recv_one_tls_record(socket, vm)? + recv_at_most_one_tls_record(socket, vm)? } else { match socket.sock_recv(SSL3_RT_MAX_PLAIN_LENGTH, vm) { Ok(d) => d, @@ -1324,7 +1276,7 @@ fn handshake_read_data( } }; - // SERVER-SPECIFIC: Save ClientHello on first read for potential connection recreation + // SERVER-SPECIFIC: Save ClientHello fragments for potential connection recreation. if is_first_sni_read { // Extract bytes from PyObjectRef use rustpython_vm::builtins::PyBytes; @@ -1506,10 +1458,10 @@ pub(super) fn ssl_do_handshake( return Err(SslError::from_rustls(e)); } - // SERVER-SPECIFIC: Check SNI callback after processing packets - // SNI name is extracted during process_new_packets() - // Invoke callback on FIRST read if callback is configured, regardless of SNI presence - if is_server && is_first_sni_read && socket.has_sni_callback() { + // SERVER-SPECIFIC: Check SNI callback after processing packets. + // A partial TLS record can be read without producing any handshake + // response. Wait until rustls has processed a complete ClientHello. + if is_server && is_first_sni_read && socket.has_sni_callback() && conn.wants_write() { // IMPORTANT: Do NOT call the callback here! // The connection lock is still held, which would cause deadlock. // Return SniCallbackRestart to signal do_handshake to: @@ -1753,7 +1705,7 @@ pub(super) fn ssl_read( // Blocking socket or socket with timeout: try to read more data from socket. // Even though rustls says it doesn't want to read, more TLS records may arrive. // Use single-record reading to avoid consuming close_notify alongside data. 
- let data = recv_one_tls_record_for_data(conn, socket, vm)?; + let data = recv_at_most_one_tls_record_for_data(conn, socket, vm)?; let bytes_read = data .clone() @@ -1944,7 +1896,7 @@ pub(super) fn ssl_write( return Err(SslError::WantRead); } // For socket mode, try to read TLS data - let recv_result = socket.sock_recv(4096, vm).map_err(SslError::Py)?; + let recv_result = recv_at_most_one_tls_record_for_data(conn, socket, vm)?; ssl_read_tls_records(conn, recv_result, false, vm)?; conn.process_new_packets().map_err(SslError::from_rustls)?; // Continue loop @@ -2157,7 +2109,7 @@ fn ssl_ensure_data_available( // consuming a close_notify that arrives alongside application data, // keeping it in the kernel buffer where select() can detect it. let data = if !is_bio { - recv_one_tls_record_for_data(conn, socket, vm)? + recv_at_most_one_tls_record_for_data(conn, socket, vm)? } else { match socket.sock_recv(2048, vm) { Ok(data) => data, From bc2dbf98f69e0168debc052a64dfb9a91251edd4 Mon Sep 17 00:00:00 2001 From: Ivan Mironov Date: Tue, 19 May 2026 04:54:17 +0200 Subject: [PATCH 03/11] Remove special handling of rustls "buffer full" errors First of all, the existing code does not really work and this leads to an infinite loop: https://github.com/RustPython/RustPython/issues/7891 Second, this should never happen when rustls is used properly (wrt wants_read() and wants_write()) and thus all such errors are implementation bugs that must be properly fixed. 
--- crates/stdlib/src/ssl/compat.rs | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index 62efe2c653a..6b9cfc26ffc 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -1740,11 +1740,6 @@ pub(super) fn ssl_read( // Successfully read and processed TLS data // Continue loop to try reading plaintext } - Err(SslError::Io(ref io_err)) if io_err.to_string().contains("message buffer full") => { - // This case should be rare now that ssl_read_tls_records handles buffer full - // Just continue loop to try again - continue; - } Err(e) => { // Other errors - check for buffered plaintext before propagating match try_read_plaintext(conn, buf)? { @@ -2020,6 +2015,9 @@ fn ssl_read_tls_records( } Ok(n) => { offset += n; + if offset < bytes_data.len() { + conn.process_new_packets().map_err(SslError::from_rustls)?; + } } Err(e) => { return Err(SslError::Io(e)); @@ -2027,14 +2025,12 @@ fn ssl_read_tls_records( } } else { offset += read_bytes; + if offset < bytes_data.len() { + conn.process_new_packets().map_err(SslError::from_rustls)?; + } } } Err(e) => { - // Check if it's a buffer full error (unlikely but handle it) - if e.to_string().contains("buffer full") { - conn.process_new_packets().map_err(SslError::from_rustls)?; - continue; - } // Real error - propagate it return Err(SslError::Io(e)); } From 023c75440db352a9f35a42d7f360d3b30ff2ce06 Mon Sep 17 00:00:00 2001 From: Ivan Mironov Date: Tue, 19 May 2026 05:01:11 +0200 Subject: [PATCH 04/11] Replace own TlsConnection with rustls::Connection --- crates/stdlib/src/ssl.rs | 38 ++++----- crates/stdlib/src/ssl/compat.rs | 147 +++----------------------------- 2 files changed, 30 insertions(+), 155 deletions(-) diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index eff3ec8c421..b9182f0effa 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -75,7 +75,8 @@ mod _ssl { 
use parking_lot::{Mutex as ParkingMutex, RwLock as ParkingRwLock}; use pem_rfc7468::{LineEnding, encode_string}; use rustls::{ - ClientConfig, ClientConnection, RootCertStore, ServerConfig, ServerConnection, + ClientConfig, ClientConnection, Connection, HandshakeKind, RootCertStore, ServerConfig, + ServerConnection, client::{ClientSessionMemoryCache, ClientSessionStore}, crypto::SupportedKxGroup, pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName}, @@ -94,9 +95,8 @@ mod _ssl { // Import compat module (OpenSSL compatibility layer) use super::compat::{ ClientConfigOptions, MultiCertResolver, ProtocolSettings, ServerConfigOptions, SslError, - TlsConnection, create_client_config, create_server_config, curve_name_to_kx_group, - extract_cipher_info, get_cipher_encryption_desc, is_blocking_io_error, - normalize_cipher_name, ssl_do_handshake, + create_client_config, create_server_config, curve_name_to_kx_group, extract_cipher_info, + get_cipher_encryption_desc, is_blocking_io_error, normalize_cipher_name, ssl_do_handshake, }; // Type aliases for better readability @@ -2337,7 +2337,7 @@ mod _ssl { server_hostname: PyRwLock>, // TLS connection state #[pytraverse(skip)] - connection: PyMutex>, + connection: PyMutex>, // Handshake completed flag #[pytraverse(skip)] handshake_done: PyMutex, @@ -2578,7 +2578,7 @@ mod _ssl { .connection .lock() .as_ref() - .is_some_and(|conn| conn.is_session_resumed()); + .is_some_and(|conn| conn.handshake_kind() == Some(HandshakeKind::Resumed)); *self.session_was_reused.lock() = was_resumed; @@ -3204,7 +3204,7 @@ mod _ssl { /// Returns the configured ServerConnection. 
fn initialize_server_connection( &self, - conn_guard: &mut Option, + conn_guard: &mut Option, vm: &VirtualMachine, ) -> PyResult<()> { let ctx = self.context.read(); @@ -3374,11 +3374,11 @@ mod _ssl { vm.new_value_error(format!("Failed to create server connection: {e}")) })?; - *conn_guard = Some(TlsConnection::Server(conn)); + *conn_guard = Some(Connection::Server(conn)); // If ClientHello buffer exists (from SNI callback), re-inject it if let Some(ref hello_data) = *self.client_hello_buffer.lock() - && let Some(TlsConnection::Server(ref mut server)) = *conn_guard + && let Some(Connection::Server(ref mut server)) = *conn_guard { let mut cursor = std::io::Cursor::new(hello_data.as_slice()); let _ = server.read_tls(&mut cursor); @@ -3501,14 +3501,14 @@ mod _ssl { vm.new_value_error(format!("Failed to create client connection: {e}")) })?; - *conn_guard = Some(TlsConnection::Client(conn)); + *conn_guard = Some(Connection::Client(conn)); } } // Perform the actual handshake by exchanging data with the socket/BIO let conn = conn_guard.as_mut().expect("unreachable"); - let is_client = matches!(conn, TlsConnection::Client(_)); + let is_client = matches!(conn, Connection::Client(_)); let handshake_result = ssl_do_handshake(conn, self, vm); drop(conn_guard); @@ -4335,7 +4335,7 @@ mod _ssl { } // Helper: Write all pending TLS data (including close_notify) to outgoing buffer/BIO - fn write_pending_tls(&self, conn: &mut TlsConnection, vm: &VirtualMachine) -> PyResult<()> { + fn write_pending_tls(&self, conn: &mut Connection, vm: &VirtualMachine) -> PyResult<()> { // First, flush any previously pending TLS output // Must succeed before sending new data to maintain order self.flush_pending_tls_output(vm, None)?; @@ -4365,7 +4365,7 @@ mod _ssl { // Returns true if peer closed connection (with or without close_notify) fn try_read_close_notify( &self, - conn: &mut TlsConnection, + conn: &mut Connection, vm: &VirtualMachine, ) -> PyResult { // In socket mode, peek first to avoid 
consuming post-TLS cleartext @@ -4418,7 +4418,7 @@ mod _ssl { /// such knob, so we enforce record-level reads manually via peek. fn try_read_close_notify_socket( &self, - conn: &mut TlsConnection, + conn: &mut Connection, vm: &VirtualMachine, ) -> bool { // Consume at most one TLS record from the socket @@ -4444,11 +4444,7 @@ mod _ssl { } // Helper: Check if peer has sent close_notify - fn check_peer_closed( - &self, - conn: &mut TlsConnection, - vm: &VirtualMachine, - ) -> PyResult { + fn check_peer_closed(&self, conn: &mut Connection, vm: &VirtualMachine) -> PyResult { // Process any remaining packets and check peer_has_closed let io_state = conn .process_new_packets() @@ -4510,12 +4506,12 @@ mod _ssl { let conn_guard = self.connection.lock(); if let Some(conn) = conn_guard.as_ref() { let version = match conn { - TlsConnection::Client(_) => { + Connection::Client(_) => { return Err(vm.new_value_error( "Post-handshake authentication requires server socket", )); } - TlsConnection::Server(server) => server.protocol_version(), + Connection::Server(server) => server.protocol_version(), }; // Post-handshake auth is only available in TLS 1.3 diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index 6b9cfc26ffc..1b7f79e3577 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -19,14 +19,13 @@ use crate::socket::{SelectKind, timeout_error_msg}; use crate::vm::VirtualMachine; use alloc::sync::Arc; use parking_lot::RwLock as ParkingRwLock; +use rustls::Connection; use rustls::RootCertStore; use rustls::client::ClientConfig; -use rustls::client::ClientConnection; use rustls::crypto::SupportedKxGroup; use rustls::pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer}; use rustls::server::ResolvesServerCert; use rustls::server::ServerConfig; -use rustls::server::ServerConnection; use rustls::sign::CertifiedKey; use rustpython_vm::builtins::{PyBaseException, PyBaseExceptionRef}; use 
rustpython_vm::convert::IntoPyException; @@ -263,126 +262,6 @@ pub(super) fn create_ssl_cert_verification_error( Ok(exc.upcast()) } -/// Unified TLS connection type (client or server) -#[derive(Debug)] -pub(super) enum TlsConnection { - Client(ClientConnection), - Server(ServerConnection), -} - -impl TlsConnection { - /// Check if handshake is in progress - pub(super) fn is_handshaking(&self) -> bool { - match self { - Self::Client(conn) => conn.is_handshaking(), - Self::Server(conn) => conn.is_handshaking(), - } - } - - /// Check if connection wants to read data - pub(super) fn wants_read(&self) -> bool { - match self { - Self::Client(conn) => conn.wants_read(), - Self::Server(conn) => conn.wants_read(), - } - } - - /// Check if connection wants to write data - pub(super) fn wants_write(&self) -> bool { - match self { - Self::Client(conn) => conn.wants_write(), - Self::Server(conn) => conn.wants_write(), - } - } - - /// Read TLS data from socket - pub(super) fn read_tls(&mut self, reader: &mut dyn std::io::Read) -> std::io::Result { - match self { - Self::Client(conn) => conn.read_tls(reader), - Self::Server(conn) => conn.read_tls(reader), - } - } - - /// Write TLS data to socket - pub(super) fn write_tls(&mut self, writer: &mut dyn std::io::Write) -> std::io::Result { - match self { - Self::Client(conn) => conn.write_tls(writer), - Self::Server(conn) => conn.write_tls(writer), - } - } - - /// Process new TLS packets - pub(super) fn process_new_packets(&mut self) -> Result { - match self { - Self::Client(conn) => conn.process_new_packets(), - Self::Server(conn) => conn.process_new_packets(), - } - } - - /// Get reader for plaintext data (rustls native type) - pub(super) fn reader(&mut self) -> rustls::Reader<'_> { - match self { - Self::Client(conn) => conn.reader(), - Self::Server(conn) => conn.reader(), - } - } - - /// Get writer for plaintext data (rustls native type) - pub(super) fn writer(&mut self) -> rustls::Writer<'_> { - match self { - Self::Client(conn) 
=> conn.writer(), - Self::Server(conn) => conn.writer(), - } - } - - /// Check if session was resumed - pub(super) fn is_session_resumed(&self) -> bool { - use rustls::HandshakeKind; - match self { - Self::Client(conn) => { - matches!(conn.handshake_kind(), Some(HandshakeKind::Resumed)) - } - Self::Server(conn) => { - matches!(conn.handshake_kind(), Some(HandshakeKind::Resumed)) - } - } - } - - /// Send close_notify alert - pub(super) fn send_close_notify(&mut self) { - match self { - Self::Client(conn) => conn.send_close_notify(), - Self::Server(conn) => conn.send_close_notify(), - } - } - - /// Get negotiated ALPN protocol - pub(super) fn alpn_protocol(&self) -> Option<&[u8]> { - match self { - Self::Client(conn) => conn.alpn_protocol(), - Self::Server(conn) => conn.alpn_protocol(), - } - } - - /// Get negotiated cipher suite - pub(super) fn negotiated_cipher_suite(&self) -> Option { - match self { - Self::Client(conn) => conn.negotiated_cipher_suite(), - Self::Server(conn) => conn.negotiated_cipher_suite(), - } - } - - /// Get peer certificates - pub(super) fn peer_certificates( - &self, - ) -> Option<&[rustls::pki_types::CertificateDer<'static>]> { - match self { - Self::Client(conn) => conn.peer_certificates(), - Self::Server(conn) => conn.peer_certificates(), - } - } -} - /// Error types matching OpenSSL error codes #[derive(Debug)] pub(super) enum SslError { @@ -1120,7 +999,7 @@ fn send_all_bytes( /// Drains all pending TLS data from rustls and sends it to the peer. /// Returns whether any progress was made. fn handshake_write_loop( - conn: &mut TlsConnection, + conn: &mut Connection, socket: &PySSLSocket, force_initial_write: bool, vm: &VirtualMachine, @@ -1200,7 +1079,7 @@ fn recv_at_most_one_tls_record( /// Read up to a single TLS record for post-handshake I/O while preserving the /// SSL-vs-socket error precedence from the old sock_recv() path. 
fn recv_at_most_one_tls_record_for_data( - conn: &mut TlsConnection, + conn: &mut Connection, socket: &PySSLSocket, vm: &VirtualMachine, ) -> SslResult { @@ -1226,7 +1105,7 @@ fn recv_at_most_one_tls_record_for_data( } fn handshake_read_data( - conn: &mut TlsConnection, + conn: &mut Connection, socket: &PySSLSocket, is_bio: bool, is_server: bool, @@ -1296,7 +1175,7 @@ fn handshake_read_data( /// Tries to send NewSessionTicket in non-blocking mode to avoid deadlocks. /// Returns true if handshake is complete and we should exit. fn handle_handshake_complete( - conn: &mut TlsConnection, + conn: &mut Connection, socket: &PySSLSocket, _is_server: bool, vm: &VirtualMachine, @@ -1371,7 +1250,7 @@ fn handle_handshake_complete( /// /// Returns Ok(Some(n)) if n bytes were read, Ok(None) if would block, /// or Err on real errors. -fn try_read_plaintext(conn: &mut TlsConnection, buf: &mut [u8]) -> SslResult> { +fn try_read_plaintext(conn: &mut Connection, buf: &mut [u8]) -> SslResult> { let mut reader = conn.reader(); match reader.read(buf) { Ok(0) => { @@ -1400,7 +1279,7 @@ fn try_read_plaintext(conn: &mut TlsConnection, buf: &mut [u8]) -> SslResult SslResult<()> { @@ -1410,7 +1289,7 @@ pub(super) fn ssl_do_handshake( } let is_bio = socket.is_bio_mode(); - let is_server = matches!(conn, TlsConnection::Server(_)); + let is_server = matches!(conn, Connection::Server(_)); let mut first_iteration = true; // Track if this is the first loop iteration let mut iteration_count = 0; @@ -1567,7 +1446,7 @@ pub(super) fn ssl_do_handshake( /// /// = SSL_read_ex() pub(super) fn ssl_read( - conn: &mut TlsConnection, + conn: &mut Connection, buf: &mut [u8], socket: &PySSLSocket, vm: &VirtualMachine, @@ -1764,7 +1643,7 @@ pub(super) fn ssl_read( /// /// = SSL_write_ex() pub(super) fn ssl_write( - conn: &mut TlsConnection, + conn: &mut Connection, data: &[u8], socket: &PySSLSocket, vm: &VirtualMachine, @@ -1941,7 +1820,7 @@ pub(super) fn ssl_write( // Helper functions (private-ish, used by 
public SSL functions) /// Write TLS records from rustls to socket -fn ssl_write_tls_records(conn: &mut TlsConnection) -> SslResult> { +fn ssl_write_tls_records(conn: &mut Connection) -> SslResult> { let mut buf = Vec::new(); let n = conn .write_tls(&mut buf as &mut dyn std::io::Write) @@ -1952,7 +1831,7 @@ fn ssl_write_tls_records(conn: &mut TlsConnection) -> SslResult> { /// Read TLS records from socket to rustls fn ssl_read_tls_records( - conn: &mut TlsConnection, + conn: &mut Connection, data: PyObjectRef, is_bio: bool, vm: &VirtualMachine, @@ -2066,7 +1945,7 @@ fn is_connection_closed_error(exc: &Py, vm: &VirtualMachine) -> /// Ensure TLS data is available for reading /// Returns the number of bytes read from the socket fn ssl_ensure_data_available( - conn: &mut TlsConnection, + conn: &mut Connection, socket: &PySSLSocket, vm: &VirtualMachine, ) -> SslResult { From e99f0c4103812562dc2af5f3817b972758bade31 Mon Sep 17 00:00:00 2001 From: Ivan Mironov Date: Tue, 19 May 2026 08:18:08 +0200 Subject: [PATCH 05/11] Fix waiting on a socket 1) Ensure that socket_wait() called from TLS glue code allows threads 2) Ensure that socket_wait() called from TLS glue code properly handles EINTR on *nix 3) Ensure that select() or poll() error conditions are checked 4) Use poll() on *nix so socket descriptor values are not limited --- crates/host_env/src/socket.rs | 33 ------ crates/stdlib/src/socket.rs | 204 +++++++++++++++++++++----------- crates/stdlib/src/ssl.rs | 30 ++--- crates/stdlib/src/ssl/compat.rs | 10 +- 4 files changed, 148 insertions(+), 129 deletions(-) diff --git a/crates/host_env/src/socket.rs b/crates/host_env/src/socket.rs index d6f1e078c72..c409132a725 100644 --- a/crates/host_env/src/socket.rs +++ b/crates/host_env/src/socket.rs @@ -3,20 +3,10 @@ use crate::os::CheckLibcResult; #[cfg(unix)] use core::ffi::CStr; #[cfg(unix)] -use core::time::Duration; -#[cfg(unix)] use std::os::fd::AsRawFd; #[cfg(unix)] use std::{io, os::fd::BorrowedFd}; -#[cfg(unix)] 
-#[derive(Copy, Clone)] -pub enum PollKind { - Read, - Write, - Connect, -} - #[cfg(all(unix, not(target_os = "redox")))] pub fn sethostname(hostname: &str) -> io::Result<()> { nix::unistd::sethostname(hostname).map_err(io::Error::from) @@ -111,29 +101,6 @@ pub fn setsockopt_none(fd: libc::c_int, level: i32, name: i32, optlen: u32) -> i Ok(()) } -#[cfg(unix)] -pub fn poll_socket( - fd: BorrowedFd<'_>, - kind: PollKind, - interval: Option, -) -> io::Result { - use nix::poll::{PollFd, PollFlags, PollTimeout, poll}; - - let events = match kind { - PollKind::Read => PollFlags::POLLIN, - PollKind::Write => PollFlags::POLLOUT, - PollKind::Connect => PollFlags::POLLOUT | PollFlags::POLLERR, - }; - let mut pollfd = [PollFd::new(fd, events)]; - let timeout = match interval { - Some(d) => d.try_into().unwrap_or(PollTimeout::MAX), - None => PollTimeout::NONE, - }; - poll(&mut pollfd, timeout) - .map(|ret| ret == 0) - .map_err(io::Error::from) -} - #[cfg(any( target_os = "dragonfly", target_os = "freebsd", diff --git a/crates/stdlib/src/socket.rs b/crates/stdlib/src/socket.rs index 55aae79df61..6aff5d452c4 100644 --- a/crates/stdlib/src/socket.rs +++ b/crates/stdlib/src/socket.rs @@ -3,7 +3,7 @@ pub(crate) use _socket::module_def; #[cfg(feature = "ssl")] -pub(super) use _socket::{PySocket, SelectKind, sock_select, timeout_error_msg}; +pub(super) use _socket::{PySocket, SockWaitKind, sock_wait, timeout_error_msg}; #[pymodule] mod _socket { @@ -1060,20 +1060,20 @@ mod _socket { fn sock_op( &self, vm: &VirtualMachine, - select: SelectKind, + wait_kind: SockWaitKind, f: F, ) -> Result where F: FnMut() -> io::Result, { let timeout = self.get_timeout().ok(); - self.sock_op_timeout_err(vm, select, timeout, f) + self.sock_op_timeout_err(vm, wait_kind, timeout, f) } fn sock_op_timeout_err( &self, vm: &VirtualMachine, - select: SelectKind, + wait_kind: SockWaitKind, timeout: Option, mut f: F, ) -> Result @@ -1083,19 +1083,9 @@ mod _socket { let deadline = timeout.map(Deadline::new); 
loop { - if deadline.is_some() || matches!(select, SelectKind::Connect) { - let interval = deadline.as_ref().map(|d| d.time_until()).transpose()?; + if deadline.is_some() || matches!(wait_kind, SockWaitKind::Connect) { let sock = self.sock()?; - let res = vm.allow_threads(|| sock_select(&sock, select, interval)); - match res { - Ok(true) => return Err(IoOrPyException::Timeout), - Err(e) if e.kind() == io::ErrorKind::Interrupted => { - vm.check_signals()?; - continue; - } - Err(e) => return Err(e.into()), - Ok(false) => {} // no timeout, continue as normal - } + sock_wait_deadline(&sock, wait_kind, &deadline, vm)?; } let err = loop { @@ -1339,16 +1329,11 @@ mod _socket { }; if wait_connect { - // basically, connect() is async, and it registers an "error" on the socket when it's - // done connecting. SelectKind::Connect fills the errorfds fd_set, so if we wake up - // from poll and the error is EISCONN then we know that the connect is done - self.sock_op(vm, SelectKind::Connect, || { + self.sock_op(vm, SockWaitKind::Connect, || { let sock = self.sock()?; let err = sock.take_error()?; match err { - Some(e) if e.posix_errno() == libc::EISCONN => Ok(()), Some(e) => Err(e), - // TODO: is this accurate? 
None => Ok(()), } }) @@ -1587,7 +1572,8 @@ mod _socket { ) -> Result<(RawSocket, PyObjectRef), IoOrPyException> { // Use accept_raw() instead of accept() to avoid socket2's set_common_flags() // which tries to set SO_NOSIGPIPE and fails with EINVAL on Unix domain sockets on macOS - let (sock, addr) = self.sock_op(vm, SelectKind::Read, || self.sock()?.accept_raw())?; + let (sock, addr) = + self.sock_op(vm, SockWaitKind::Read, || self.sock()?.accept_raw())?; let fd = into_sock_fileno(sock); Ok((fd, get_addr_tuple(&addr, vm))) } @@ -1602,7 +1588,7 @@ mod _socket { let flags = flags.unwrap_or(0); let mut buffer = Vec::with_capacity(bufsize); let sock = self.sock()?; - let n = self.sock_op(vm, SelectKind::Read, || { + let n = self.sock_op(vm, SockWaitKind::Read, || { sock.recv_with_flags(buffer.spare_capacity_mut(), flags) })?; unsafe { buffer.set_len(n) }; @@ -1633,7 +1619,7 @@ mod _socket { }; let buf = &mut buf[..read_len]; - self.sock_op(vm, SelectKind::Read, || { + self.sock_op(vm, SockWaitKind::Read, || { sock.recv_with_flags(unsafe { slice_as_uninit(buf) }, flags) }) } @@ -1650,7 +1636,7 @@ mod _socket { .to_usize() .ok_or_else(|| vm.new_value_error("negative buffersize in recvfrom"))?; let mut buffer = Vec::with_capacity(bufsize); - let (n, addr) = self.sock_op(vm, SelectKind::Read, || { + let (n, addr) = self.sock_op(vm, SockWaitKind::Read, || { self.sock()? 
.recv_from_with_flags(buffer.spare_capacity_mut(), flags) })?; @@ -1681,7 +1667,7 @@ mod _socket { }; let flags = flags.unwrap_or(0); let sock = self.sock()?; - let (n, addr) = self.sock_op(vm, SelectKind::Read, || { + let (n, addr) = self.sock_op(vm, SockWaitKind::Read, || { sock.recv_from_with_flags(unsafe { slice_as_uninit(buf) }, flags) })?; Ok((n, get_addr_tuple(&addr, vm))) @@ -1697,7 +1683,7 @@ mod _socket { let flags = flags.unwrap_or(0); let buf = bytes.borrow_buf(); let buf = &*buf; - self.sock_op(vm, SelectKind::Write, || { + self.sock_op(vm, SockWaitKind::Write, || { self.sock()?.send_with_flags(buf, flags) }) } @@ -1721,7 +1707,7 @@ mod _socket { // now we have like 3 layers of interrupt loop :) while buf_offset < buf.len() { let interval = deadline.as_ref().map(|d| d.time_until()).transpose()?; - self.sock_op_timeout_err(vm, SelectKind::Write, interval, || { + self.sock_op_timeout_err(vm, SockWaitKind::Write, interval, || { let subbuf = &buf[buf_offset..]; buf_offset += self.sock()?.send_with_flags(subbuf, flags)?; Ok(()) @@ -1754,7 +1740,7 @@ mod _socket { let addr = self.extract_address(address, "sendto", vm)?; let buf = bytes.borrow_buf(); let buf = &*buf; - self.sock_op(vm, SelectKind::Write, || { + self.sock_op(vm, SockWaitKind::Write, || { self.sock()?.send_to_with_flags(buf, &addr, flags) }) } @@ -1812,7 +1798,7 @@ mod _socket { } } - self.sock_op(vm, SelectKind::Write, || { + self.sock_op(vm, SockWaitKind::Write, || { let sock = self.sock()?; sock.sendmsg(&msg, flags) }) @@ -1848,7 +1834,7 @@ mod _socket { .collect::>(); let iv = iv.map(|iv| iv.borrow_buf().to_vec()); - self.sock_op(vm, SelectKind::Write, || { + self.sock_op(vm, SockWaitKind::Write, || { let sock = self.sock()?; let fd = unsafe { BorrowedFd::borrow_raw(sock_fileno(&sock)) }; host_socket::sendmsg_afalg(fd, &buffers, op, iv.as_deref(), assoclen, flags) @@ -1881,7 +1867,7 @@ mod _socket { let flags = flags.unwrap_or(0); let msg = self - .sock_op(vm, SelectKind::Read, || { + 
.sock_op(vm, SockWaitKind::Read, || { let sock = self.sock()?; let fd = unsafe { std::os::fd::BorrowedFd::borrow_raw(sock_fileno(&sock)) }; host_socket::recvmsg(fd, bufsize, ancbufsize, flags) @@ -2436,61 +2422,135 @@ mod _socket { } #[derive(Copy, Clone)] - pub(crate) enum SelectKind { + pub(crate) enum SockWaitKind { Read, Write, Connect, } - /// returns true if timed out - pub(crate) fn sock_select( + /// returns Ok(true) on timeout + pub(crate) fn sock_wait( + sock: &Socket, + wait_kind: SockWaitKind, + timeout: Option, + vm: &VirtualMachine, + ) -> PyResult { + match sock_wait_deadline(sock, wait_kind, &timeout.map(Deadline::new), vm) { + Ok(()) => Ok(false), + Err(IoOrPyException::Timeout) => Ok(true), + Err(e) => Err(e.into_pyexception(vm)), + } + } + + /// returns Err(IoOrPyException::Timeout) on timeout + fn sock_wait_deadline( sock: &Socket, - kind: SelectKind, - interval: Option, - ) -> io::Result { + wait_kind: SockWaitKind, + deadline: &Option, + vm: &VirtualMachine, + ) -> Result<(), IoOrPyException> { #[cfg(unix)] { - use std::os::fd::AsFd; - let kind = match kind { - SelectKind::Read => host_socket::PollKind::Read, - SelectKind::Write => host_socket::PollKind::Write, - SelectKind::Connect => host_socket::PollKind::Connect, - }; - host_socket::poll_socket(sock.as_fd(), kind, interval) + use rustpython_host_env::select::{PollFd, poll_fds}; + + let mut events = 0; + if matches!(wait_kind, SockWaitKind::Read) { + events |= libc::POLLIN | libc::POLLPRI; + } + if matches!(wait_kind, SockWaitKind::Write | SockWaitKind::Connect) { + events |= libc::POLLOUT; + } + let mut fds = [PollFd { + fd: sock_fileno(sock), + events, + revents: 0, + }; 1]; + + loop { + let (timeout, is_capped) = deadline + .as_ref() + .map(|d| { + d.time_until().map(|t| { + let timeout_ms = t.as_millis(); + let is_capped = timeout_ms > i32::MAX as u128; + let timeout = if is_capped { + i32::MAX + } else { + timeout_ms as i32 + }; + (timeout, is_capped) + }) + }) + .transpose()? 
+ .unwrap_or((-1, false)); + + match vm.allow_threads(|| poll_fds(&mut fds, timeout)) { + Ok(0) => { + if is_capped { + continue; + } + break Err(IoOrPyException::Timeout); + } + + Ok(_) => { + if fds[0].revents & libc::POLLNVAL != 0 { + break Err(io::Error::from_raw_os_error(libc::EBADF).into()); + } + break Ok(()); + } + + Err(e) => { + if e.kind() == io::ErrorKind::Interrupted { + vm.check_signals()?; + continue; + } + break Err(e.into()); + } + } + } } #[cfg(windows)] { - use rustpython_host_env::select as host_select; + use rustpython_host_env::select::{FdSet, select, timeval}; - let fd = sock_fileno(sock); + let fd = sock_fileno(sock) as usize; - let mut reads = host_select::FdSet::new(); - let mut writes = host_select::FdSet::new(); - let mut errs = host_select::FdSet::new(); + let mut reads = FdSet::new(); + let mut writes = FdSet::new(); + let mut errs = FdSet::new(); - let fd = fd as usize; - match kind { - SelectKind::Read => reads.insert(fd), - SelectKind::Write => writes.insert(fd), - SelectKind::Connect => { - writes.insert(fd); - errs.insert(fd); - } + if matches!(wait_kind, SockWaitKind::Read) { + reads.insert(fd); + errs.insert(fd); + } + if matches!(wait_kind, SockWaitKind::Write | SockWaitKind::Connect) { + writes.insert(fd); + errs.insert(fd); } - let mut interval = interval.map(|dur| host_select::timeval { - tv_sec: dur.as_secs() as _, - tv_usec: dur.subsec_micros() as _, - }); - - host_select::select( - fd as i32 + 1, - &mut reads, - &mut writes, - &mut errs, - interval.as_mut(), - ) - .map(|ret| ret == 0) + let mut timeout = deadline + .as_ref() + .map(|d| { + d.time_until().map(|dur| timeval { + tv_sec: dur.as_secs() as _, + tv_usec: dur.subsec_micros() as _, + }) + }) + .transpose()?; + + match vm.allow_threads(|| { + select( + 0, // nfds is ignored on windows + &mut reads, + &mut writes, + &mut errs, + timeout.as_mut(), + ) + }) { + Ok(0) => Err(IoOrPyException::Timeout), + Ok(_) => Ok(()), + Err(e) => Err(e.into()), + } } } diff --git 
a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index b9182f0effa..a87abce68d8 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -36,7 +36,7 @@ mod _ssl { hash::PyHash, lock::{PyMutex, PyRwLock}, }, - socket::{PySocket, SelectKind, sock_select, timeout_error_msg}, + socket::{PySocket, SockWaitKind, sock_wait, timeout_error_msg}, vm::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, @@ -2606,7 +2606,7 @@ mod _ssl { // Internal implementation with timeout control pub(crate) fn sock_wait_for_io_impl( &self, - kind: SelectKind, + wait_kind: SockWaitKind, vm: &VirtualMachine, ) -> PyResult { if self.is_bio_mode() { @@ -2631,16 +2631,13 @@ mod _ssl { .sock() .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; - let timed_out = sock_select(&socket, kind, timeout) - .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?; - - Ok(timed_out) + sock_wait(&socket, wait_kind, timeout, vm) } // Internal implementation with explicit timeout override pub(crate) fn sock_wait_for_io_with_timeout( &self, - kind: SelectKind, + wait_kind: SockWaitKind, timeout: Option, vm: &VirtualMachine, ) -> PyResult { @@ -2661,10 +2658,7 @@ mod _ssl { .sock() .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; - let timed_out = sock_select(&socket, kind, timeout) - .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?; - - Ok(timed_out) + sock_wait(&socket, wait_kind, timeout, vm).map_err(|e| e.into_pyexception(vm)) } // SNI (Server Name Indication) Helper Methods: @@ -2934,13 +2928,12 @@ mod _ssl { socket_timeout }; - // Use sock_select directly with calculated timeout + // Use sock_wait directly with calculated timeout let py_socket: PyRef = self.sock.clone().try_into_value(vm)?; let socket = py_socket .sock() .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; - let timed_out = sock_select(&socket, SelectKind::Write, timeout_to_use) - .map_err(|e| 
vm.new_os_error(format!("select failed: {e}")))?; + let timed_out = sock_wait(&socket, SockWaitKind::Write, timeout_to_use, vm)?; if timed_out { // Keep unsent data in pending buffer @@ -3001,7 +2994,7 @@ mod _ssl { let mut sent_total = 0; while sent_total < buf.len() { - let timed_out = self.sock_wait_for_io_impl(SelectKind::Write, vm)?; + let timed_out = self.sock_wait_for_io_impl(SockWaitKind::Write, vm)?; if timed_out { // Save unsent data to pending buffer self.pending_tls_output @@ -3073,8 +3066,7 @@ mod _ssl { let socket = py_socket .sock() .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; - let timed_out = sock_select(&socket, SelectKind::Write, timeout) - .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?; + let timed_out = sock_wait(&socket, SockWaitKind::Write, timeout, vm)?; if timed_out { return Err( @@ -3090,7 +3082,7 @@ mod _ssl { let mut pending = self.pending_tls_output.lock(); pending.drain(..sent); } - // If sent == 0, loop will retry with sock_select + // If sent == 0, loop will retry with sock_wait } Err(e) => { if is_blocking_io_error(&e, vm) { @@ -4264,7 +4256,7 @@ mod _ssl { // Wait for socket to be readable let timed_out = self.sock_wait_for_io_with_timeout( - SelectKind::Read, + SockWaitKind::Read, remaining_timeout, vm, )?; diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index 1b7f79e3577..b30719e5c84 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -15,7 +15,7 @@ #[path = "../openssl/ssl_data_31.rs"] mod ssl_data; -use crate::socket::{SelectKind, timeout_error_msg}; +use crate::socket::{SockWaitKind, timeout_error_msg}; use crate::vm::VirtualMachine; use alloc::sync::Arc; use parking_lot::RwLock as ParkingRwLock; @@ -937,11 +937,11 @@ fn send_all_bytes( )); } socket - .sock_wait_for_io_with_timeout(SelectKind::Write, Some(dl - now), vm) + .sock_wait_for_io_with_timeout(SockWaitKind::Write, Some(dl - now), vm) .map_err(SslError::Py)? 
} else { socket - .sock_wait_for_io_impl(SelectKind::Write, vm) + .sock_wait_for_io_impl(SockWaitKind::Write, vm) .map_err(SslError::Py)? }; if timed_out { @@ -1123,7 +1123,7 @@ fn handshake_read_data( // Wait for data in socket mode if !is_bio { let timed_out = socket - .sock_wait_for_io_impl(SelectKind::Read, vm) + .sock_wait_for_io_impl(SockWaitKind::Read, vm) .map_err(SslError::Py)?; if timed_out { @@ -1967,7 +1967,7 @@ fn ssl_ensure_data_available( { // Socket has timeout - use select to enforce it let timed_out = socket - .sock_wait_for_io_impl(SelectKind::Read, vm) + .sock_wait_for_io_impl(SockWaitKind::Read, vm) .map_err(SslError::Py)?; if timed_out { // Socket not ready within timeout - raise socket.timeout From e9e359cb00b51e4a4c7af89c03318a016992bd5b Mon Sep 17 00:00:00 2001 From: Ivan Mironov Date: Wed, 20 May 2026 22:06:18 +0200 Subject: [PATCH 06/11] Remove dead code from rustls glue --- crates/stdlib/src/ssl.rs | 30 ++----------- crates/stdlib/src/ssl/compat.rs | 77 ++++----------------------------- crates/stdlib/src/ssl/error.rs | 2 - 3 files changed, 13 insertions(+), 96 deletions(-) diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index a87abce68d8..9a812b67c92 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -75,8 +75,7 @@ mod _ssl { use parking_lot::{Mutex as ParkingMutex, RwLock as ParkingRwLock}; use pem_rfc7468::{LineEnding, encode_string}; use rustls::{ - ClientConfig, ClientConnection, Connection, HandshakeKind, RootCertStore, ServerConfig, - ServerConnection, + ClientConnection, Connection, HandshakeKind, RootCertStore, ServerConfig, ServerConnection, client::{ClientSessionMemoryCache, ClientSessionStore}, crypto::SupportedKxGroup, pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName}, @@ -398,8 +397,7 @@ mod _ssl { // Session data structure for tracking TLS sessions #[derive(Debug, Clone)] struct SessionData { - #[allow(dead_code)] - server_name: String, + _server_name: 
String, session_id: Vec, creation_time: SystemTime, lifetime: u64, @@ -477,7 +475,7 @@ mod _ssl { let creation_time = SystemTime::now(); let server_name_str = server_name.to_str(); let session_data = SessionData { - server_name: server_name_str.as_ref().to_string(), + _server_name: server_name_str.as_ref().to_string(), session_id: generate_session_id_from_metadata( server_name_str.as_ref(), &creation_time, @@ -521,7 +519,7 @@ mod _ssl { let creation_time = SystemTime::now(); let server_name_str = server_name.to_str(); let session_data = SessionData { - server_name: server_name_str.to_string(), + _server_name: server_name_str.to_string(), session_id: generate_session_id_from_metadata( server_name_str.as_ref(), &creation_time, @@ -720,10 +718,6 @@ mod _ssl { #[pytraverse(skip)] verify_flags: PyRwLock, // Rustls configuration (built lazily) - #[allow(dead_code)] - #[pytraverse(skip)] - client_config: PyRwLock>>, - #[allow(dead_code)] #[pytraverse(skip)] server_config: PyRwLock>>, // Certificate store @@ -747,19 +741,11 @@ mod _ssl { #[pytraverse(skip)] cert_keys: PyRwLock>, // Options - #[allow(dead_code)] #[pytraverse(skip)] options: PyRwLock, // ALPN protocols - #[allow(dead_code)] #[pytraverse(skip)] alpn_protocols: PyRwLock>>, - // ALPN strict matching flag - // When false (default), mimics OpenSSL behavior: no ALPN negotiation failure - // When true, requires ALPN match (Rustls default behavior) - #[allow(dead_code)] - #[pytraverse(skip)] - require_alpn_match: PyRwLock, // TLS 1.3 features #[pytraverse(skip)] post_handshake_auth: PyRwLock, @@ -1895,7 +1881,6 @@ mod _ssl { owner: PyRwLock::new(args.owner.into_option()), // Filter out Python None objects - only store actual SSLSession objects session: PyRwLock::new(args.session.into_option().filter(|s| !vm.is_none(s))), - verified_chain: PyRwLock::new(None), incoming_bio: None, outgoing_bio: None, sni_state: PyRwLock::new(None), @@ -1973,7 +1958,6 @@ mod _ssl { owner: PyRwLock::new(args.owner.into_option()), // 
Filter out Python None objects - only store actual SSLSession objects session: PyRwLock::new(args.session.into_option().filter(|s| !vm.is_none(s))), - verified_chain: PyRwLock::new(None), incoming_bio: Some(args.incoming), outgoing_bio: Some(args.outgoing), sni_state: PyRwLock::new(None), @@ -2277,7 +2261,6 @@ mod _ssl { check_hostname: PyRwLock::new(protocol == PROTOCOL_TLS_CLIENT), verify_mode: PyRwLock::new(default_verify_mode), verify_flags: PyRwLock::new(default_verify_flags), - client_config: PyRwLock::new(None), server_config: PyRwLock::new(None), root_certs: PyRwLock::new(RootCertStore::empty()), ca_certs_der: PyRwLock::new(Vec::new()), @@ -2286,7 +2269,6 @@ mod _ssl { cert_keys: PyRwLock::new(Vec::new()), options: PyRwLock::new(default_options), alpn_protocols: PyRwLock::new(Vec::new()), - require_alpn_match: PyRwLock::new(false), post_handshake_auth: PyRwLock::new(false), num_tickets: PyRwLock::new(2), // TLS 1.3 default minimum_version: PyRwLock::new(min_version), @@ -2348,10 +2330,6 @@ mod _ssl { owner: PyRwLock>, // Session for resumption session: PyRwLock>, - // Verified certificate chain (built during verification) - #[allow(dead_code)] - #[pytraverse(skip)] - verified_chain: PyRwLock>>>, // MemoryBIO mode (optional) incoming_bio: Option>, outgoing_bio: Option>, diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index b30719e5c84..68b17ab1aff 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -94,74 +94,15 @@ const X509_V_FLAG_CRL_CHECK: i32 = 4; // verification. They are used to map rustls certificate errors to OpenSSL // error codes for compatibility. 
-pub(super) use x509::{ - X509_V_ERR_CERT_HAS_EXPIRED, X509_V_ERR_CERT_NOT_YET_VALID, X509_V_ERR_CERT_REVOKED, - X509_V_ERR_HOSTNAME_MISMATCH, X509_V_ERR_INVALID_PURPOSE, X509_V_ERR_IP_ADDRESS_MISMATCH, - X509_V_ERR_UNABLE_TO_GET_CRL, X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY, - X509_V_ERR_UNSPECIFIED, -}; - -#[allow(dead_code)] -mod x509 { - pub(super) const X509_V_OK: i32 = 0; - pub(crate) const X509_V_ERR_UNSPECIFIED: i32 = 1; - pub(super) const X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT: i32 = 2; - pub(crate) const X509_V_ERR_UNABLE_TO_GET_CRL: i32 = 3; - pub(super) const X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE: i32 = 4; - pub(super) const X509_V_ERR_UNABLE_TO_DECRYPT_CRL_SIGNATURE: i32 = 5; - pub(super) const X509_V_ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY: i32 = 6; - pub(super) const X509_V_ERR_CERT_SIGNATURE_FAILURE: i32 = 7; - pub(super) const X509_V_ERR_CRL_SIGNATURE_FAILURE: i32 = 8; - pub(crate) const X509_V_ERR_CERT_NOT_YET_VALID: i32 = 9; - pub(crate) const X509_V_ERR_CERT_HAS_EXPIRED: i32 = 10; - pub(super) const X509_V_ERR_CRL_NOT_YET_VALID: i32 = 11; - pub(super) const X509_V_ERR_CRL_HAS_EXPIRED: i32 = 12; - pub(super) const X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD: i32 = 13; - pub(super) const X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD: i32 = 14; - pub(super) const X509_V_ERR_ERROR_IN_CRL_LAST_UPDATE_FIELD: i32 = 15; - pub(super) const X509_V_ERR_ERROR_IN_CRL_NEXT_UPDATE_FIELD: i32 = 16; - pub(super) const X509_V_ERR_OUT_OF_MEM: i32 = 17; - pub(super) const X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT: i32 = 18; - pub(super) const X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN: i32 = 19; - pub(crate) const X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY: i32 = 20; - pub(super) const X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE: i32 = 21; - pub(super) const X509_V_ERR_CERT_CHAIN_TOO_LONG: i32 = 22; - pub(crate) const X509_V_ERR_CERT_REVOKED: i32 = 23; - pub(super) const X509_V_ERR_INVALID_CA: i32 = 24; - pub(super) const X509_V_ERR_PATH_LENGTH_EXCEEDED: i32 = 25; - 
pub(crate) const X509_V_ERR_INVALID_PURPOSE: i32 = 26; - pub(super) const X509_V_ERR_CERT_UNTRUSTED: i32 = 27; - pub(super) const X509_V_ERR_CERT_REJECTED: i32 = 28; - pub(super) const X509_V_ERR_SUBJECT_ISSUER_MISMATCH: i32 = 29; - pub(super) const X509_V_ERR_AKID_SKID_MISMATCH: i32 = 30; - pub(super) const X509_V_ERR_AKID_ISSUER_SERIAL_MISMATCH: i32 = 31; - pub(super) const X509_V_ERR_KEYUSAGE_NO_CERTSIGN: i32 = 32; - pub(super) const X509_V_ERR_UNABLE_TO_GET_CRL_ISSUER: i32 = 33; - pub(super) const X509_V_ERR_UNHANDLED_CRITICAL_EXTENSION: i32 = 34; - pub(super) const X509_V_ERR_KEYUSAGE_NO_CRL_SIGN: i32 = 35; - pub(super) const X509_V_ERR_UNHANDLED_CRITICAL_CRL_EXTENSION: i32 = 36; - pub(super) const X509_V_ERR_INVALID_NON_CA: i32 = 37; - pub(super) const X509_V_ERR_PROXY_PATH_LENGTH_EXCEEDED: i32 = 38; - pub(super) const X509_V_ERR_KEYUSAGE_NO_DIGITAL_SIGNATURE: i32 = 39; - pub(super) const X509_V_ERR_PROXY_CERTIFICATES_NOT_ALLOWED: i32 = 40; - pub(super) const X509_V_ERR_INVALID_EXTENSION: i32 = 41; - pub(super) const X509_V_ERR_INVALID_POLICY_EXTENSION: i32 = 42; - pub(super) const X509_V_ERR_NO_EXPLICIT_POLICY: i32 = 43; - pub(super) const X509_V_ERR_DIFFERENT_CRL_SCOPE: i32 = 44; - pub(super) const X509_V_ERR_UNSUPPORTED_EXTENSION_FEATURE: i32 = 45; - pub(super) const X509_V_ERR_UNNESTED_RESOURCE: i32 = 46; - pub(super) const X509_V_ERR_PERMITTED_VIOLATION: i32 = 47; - pub(super) const X509_V_ERR_EXCLUDED_VIOLATION: i32 = 48; - pub(super) const X509_V_ERR_SUBTREE_MINMAX: i32 = 49; - pub(super) const X509_V_ERR_APPLICATION_VERIFICATION: i32 = 50; - pub(super) const X509_V_ERR_UNSUPPORTED_CONSTRAINT_TYPE: i32 = 51; - pub(super) const X509_V_ERR_UNSUPPORTED_CONSTRAINT_SYNTAX: i32 = 52; - pub(super) const X509_V_ERR_UNSUPPORTED_NAME_SYNTAX: i32 = 53; - pub(super) const X509_V_ERR_CRL_PATH_VALIDATION_ERROR: i32 = 54; - pub(crate) const X509_V_ERR_HOSTNAME_MISMATCH: i32 = 62; - pub(super) const X509_V_ERR_EMAIL_MISMATCH: i32 = 63; - pub(crate) const 
X509_V_ERR_IP_ADDRESS_MISMATCH: i32 = 64; -} +pub(super) const X509_V_ERR_UNSPECIFIED: i32 = 1; +pub(super) const X509_V_ERR_UNABLE_TO_GET_CRL: i32 = 3; +pub(super) const X509_V_ERR_CERT_NOT_YET_VALID: i32 = 9; +pub(super) const X509_V_ERR_CERT_HAS_EXPIRED: i32 = 10; +pub(super) const X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY: i32 = 20; +pub(super) const X509_V_ERR_CERT_REVOKED: i32 = 23; +pub(super) const X509_V_ERR_INVALID_PURPOSE: i32 = 26; +pub(super) const X509_V_ERR_HOSTNAME_MISMATCH: i32 = 62; +pub(super) const X509_V_ERR_IP_ADDRESS_MISMATCH: i32 = 64; // Certificate Error Conversion Functions: diff --git a/crates/stdlib/src/ssl/error.rs b/crates/stdlib/src/ssl/error.rs index d12cd834d1b..07ff4488698 100644 --- a/crates/stdlib/src/ssl/error.rs +++ b/crates/stdlib/src/ssl/error.rs @@ -125,7 +125,6 @@ pub(crate) mod ssl_error { ) } - #[allow(dead_code, reason = "This seems like a false positive")] pub(crate) fn create_ssl_zero_return_error(vm: &VirtualMachine) -> PyRef { vm.new_os_subtype_error( PySSLZeroReturnError::class(&vm.ctx).to_owned(), @@ -134,7 +133,6 @@ pub(crate) mod ssl_error { ) } - #[allow(dead_code, reason = "This seems like a false positive")] pub(crate) fn create_ssl_syscall_error( vm: &VirtualMachine, msg: impl Into, From 7a10b53ba0369d56ef0adebc8b530929079e6937 Mon Sep 17 00:00:00 2001 From: Ivan Mironov Date: Thu, 21 May 2026 03:00:55 +0200 Subject: [PATCH 07/11] Do not present rustls errors as OSError(0, "Success") --- crates/stdlib/src/ssl/compat.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index 68b17ab1aff..ae4110eab2e 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -404,6 +404,16 @@ impl SslError { // Use the proper cert verification error creator create_ssl_cert_verification_error(vm, &cert_err).expect("unlikely to happen") } + Self::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => { + 
create_ssl_eof_error(vm).upcast() + } + Self::Io(err) if err.raw_os_error().is_none() => vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + format!("SSL error: {err}"), + ) + .upcast(), Self::Io(err) => err.into_pyexception(vm), Self::SniCallbackRestart => { // This should be handled at PySSLSocket level From 86ec36b3111c1fde7f022818db06b055f36303f9 Mon Sep 17 00:00:00 2001 From: Ivan Mironov Date: Thu, 21 May 2026 07:57:39 +0200 Subject: [PATCH 08/11] Remove infinite loop "detection" from rustls glue TLS handshake cannot be infinite. Any infinite loop here is a serious bug in implementation and should be fixed properly. This code triggers in some cases (very short reads) with misleading `ssl_error.SSLWantReadError: The operation did not complete (read)`. --- crates/stdlib/src/ssl/compat.rs | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index ae4110eab2e..ef284991495 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -1242,10 +1242,7 @@ pub(super) fn ssl_do_handshake( let is_bio = socket.is_bio_mode(); let is_server = matches!(conn, Connection::Server(_)); let mut first_iteration = true; // Track if this is the first loop iteration - let mut iteration_count = 0; - loop { - iteration_count += 1; let mut made_progress = false; // IMPORTANT: In BIO mode, force initial write even if wants_write() is false @@ -1363,11 +1360,6 @@ pub(super) fn ssl_do_handshake( if !should_continue { break; } - - // Safety check: prevent truly infinite loops (should never happen) - if iteration_count > 1000 { - break; - } } // If we exit the loop without completing handshake, return appropriate error @@ -1381,9 +1373,9 @@ pub(super) fn ssl_do_handshake( return Err(SslError::WantRead); } // Neither wants_read nor wants_write - this is a real error - Err(SslError::Syscall(format!( - "SSL handshake failed: incomplete after 
{iteration_count} iterations", - ))) + Err(SslError::Syscall( + "SSL handshake failed: incomplete handshake".to_string(), + )) } else { // Handshake completed successfully (shouldn't reach here normally) Ok(()) From fbe1f8068a18e6103f1705ee0449f02bf21313c5 Mon Sep 17 00:00:00 2001 From: Ivan Mironov Date: Thu, 21 May 2026 07:40:32 +0200 Subject: [PATCH 09/11] Add test for 1-byte max recv in TLS client --- extra_tests/snippets/stdlib_ssl_short_recv.py | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 extra_tests/snippets/stdlib_ssl_short_recv.py diff --git a/extra_tests/snippets/stdlib_ssl_short_recv.py b/extra_tests/snippets/stdlib_ssl_short_recv.py new file mode 100644 index 00000000000..4ec36e5b7e0 --- /dev/null +++ b/extra_tests/snippets/stdlib_ssl_short_recv.py @@ -0,0 +1,88 @@ +import os +import socket +import ssl +import sys +import threading + +if sys.implementation.name.lower() != "rustpython": + print("Ignored: stdlib_ssl_short_recv (RustPython only)") + raise SystemExit + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +CERTFILE = os.path.join(ROOT_DIR, "Lib/test/certdata/keycert.pem") +DATA = b"x" * 128 + +orig_recv = socket.socket.recv +client_sockname = None +recv_n = {} + + +def new_recv(sock, bufsize, flags=0): + sockname = sock.getsockname() + if sockname not in recv_n: + recv_n[sockname] = 0 + + bufsize = 1 + + if flags & socket.MSG_PEEK == 0: + recv_n[sockname] += 1 + return orig_recv(sock, bufsize, flags) + + +socket.socket.recv = new_recv + +listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +listener.bind(("127.0.0.1", 0)) +listener.listen(1) +addr, port = listener.getsockname() +server_errors = [] + + +def server(): + try: + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_context.load_cert_chain(CERTFILE) + + sock, _ = listener.accept() + sock.settimeout(5.0) + + ssock = 
server_context.wrap_socket(sock, server_side=True) + try: + ssock.sendall(DATA) + finally: + ssock.close() + except BaseException as exc: + server_errors.append(exc) + finally: + listener.close() + + +thread = threading.Thread(target=server) +thread.start() + +raw = socket.create_connection((addr, port), timeout=5.0) +client_sockname = raw.getsockname() +raw.settimeout(5.0) + +client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +client_context.check_hostname = False +client_context.verify_mode = ssl.CERT_NONE + +client = client_context.wrap_socket(raw, server_hostname=None) +try: + chunks = [] + while sum(len(chunk) for chunk in chunks) < len(DATA): + chunk = client.recv(20000) + if not chunk: + break + chunks.append(chunk) +finally: + client.close() + +thread.join(10.0) +assert not thread.is_alive(), "server thread did not stop" +assert not server_errors, server_errors +assert b"".join(chunks) == DATA +assert len(recv_n) == 2 +assert all(n > 100 for n in recv_n.values()) From 6905963c9f86ad5c87744ad23c975db1f27c1166 Mon Sep 17 00:00:00 2001 From: Ivan Mironov Date: Thu, 21 May 2026 10:19:16 +0200 Subject: [PATCH 10/11] Add regression test for https://github.com/RustPython/RustPython/issues/7891 --- .../stdlib_urllib_https_misaligned_recv.py | 246 ++++++++++++++++++ 1 file changed, 246 insertions(+) create mode 100644 extra_tests/snippets/stdlib_urllib_https_misaligned_recv.py diff --git a/extra_tests/snippets/stdlib_urllib_https_misaligned_recv.py b/extra_tests/snippets/stdlib_urllib_https_misaligned_recv.py new file mode 100644 index 00000000000..18ac9ab010a --- /dev/null +++ b/extra_tests/snippets/stdlib_urllib_https_misaligned_recv.py @@ -0,0 +1,246 @@ +import os +import socket +import ssl +import sys +import threading +import time +import urllib.request + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +CERTFILE = os.path.join(ROOT_DIR, "Lib/test/certdata/keycert.pem") +BODY = b"x" * 407_676 + +# TLS record body 
sizes observed from https://crates.io/api/v1/crates/tokio. +TLS_RECORD_BODY_SIZES = [ + 2855, + 281, + 53, + 218, + 1095, + 1395, + 1395, + 483, + 1395, + 1395, + 1395, + 1395, + 48, + 1360, + 1354, + 1395, + 1395, + 1395, + 1367, + 1395, + 1395, + 1395, + 1395, + 1326, + 1395, + 1395, + 1395, + 47, + 1395, + 1395, + 1395, + 1395, + 95, + 1395, + 1332, + 1287, + 1388, + 1395, + 1395, + 1374, + 1395, + 1380, + 794, + 791, + 1395, + 1381, + 1395, + 1395, + 1395, + 1333, + 1395, + 1395, + 1395, + 1395, + 1395, + 1395, + 965, + 16401, + 3914, + 2526, + 1041, + 8209, + 9233, + 16401, + 11650, + 10262, + 7486, + 3468, + 692, + 1041, + 16401, + 12242, + 9466, + 1041, + 8209, + 9233, + 8209, + 9233, + 16401, + 1041, + 8209, + 9233, + 6161, + 2065, + 9233, + 16401, + 16358, + 10806, + 1041, + 8209, + 16401, + 3914, + 16401, + 16401, + 3089, + 9233, + 4642, + 478, + 8209, + 3140, + 1752, + 9233, + 8209, + 8209, + 16401, + 16064, + 14676, + 13288, + 2065, + 16401, + 1041, + 8209, + 16401, + 1041, + 6374, + 1007, +] + +server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) +server_context.load_cert_chain(CERTFILE) +listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +listener.bind(("127.0.0.1", 0)) +listener.listen(1) +addr, port = listener.getsockname() +server_errors = [] +finished = False + + +def guard_timeout(): + time.sleep(20) + if not finished: + print( + "stdlib_urllib_https_misaligned_recv.py timed out", + file=sys.stderr, + flush=True, + ) + os.abort() + + +threading.Thread(target=guard_timeout, daemon=True).start() + + +def drain_outgoing(outgoing, conn): + while True: + try: + data = outgoing.read() + except ssl.SSLWantReadError: + return + if not data: + return + conn.sendall(data) + + +def run_server(): + try: + conn, _ = listener.accept() + conn.settimeout(5.0) + conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + tls = 
server_context.wrap_bio(incoming, outgoing, server_side=True) + + while True: + try: + tls.do_handshake() + break + except ssl.SSLWantReadError: + drain_outgoing(outgoing, conn) + incoming.write(conn.recv(65536)) + except ssl.SSLWantWriteError: + pass + drain_outgoing(outgoing, conn) + + request = b"" + while b"\r\n\r\n" not in request: + try: + request += tls.read(65536) + except ssl.SSLWantReadError: + drain_outgoing(outgoing, conn) + incoming.write(conn.recv(65536)) + drain_outgoing(outgoing, conn) + + response = ( + b"HTTP/1.1 200 OK\r\n" + b"Connection: close\r\n" + + b"Content-Length: " + + str(len(BODY)).encode() + + b"\r\n" + + b"Content-Type: application/json\r\n" + + b"\r\n" + + BODY + ) + plaintext_sizes = [max(1, n - 17) for n in TLS_RECORD_BODY_SIZES] + pos = 0 + while pos < len(response): + size = plaintext_sizes.pop(0) if plaintext_sizes else 16384 + end = min(len(response), pos + size) + while pos < end: + try: + pos += tls.write(response[pos:end]) + except ssl.SSLWantWriteError: + pass + drain_outgoing(outgoing, conn) + conn.close() + except BaseException as exc: + server_errors.append(exc) + finally: + listener.close() + + +thread = threading.Thread(target=run_server) +thread.start() + +client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +client_context.check_hostname = False +client_context.verify_mode = ssl.CERT_NONE +opener = urllib.request.build_opener( + urllib.request.ProxyHandler({}), + urllib.request.HTTPSHandler(context=client_context), +) +try: + with opener.open(f"https://{addr}:{port}/", timeout=5.0) as response: + body = response.read() + + thread.join(10.0) + assert not thread.is_alive(), "server thread did not stop" + assert not server_errors, server_errors + assert body == BODY +finally: + finished = True From 6c0e29287f71ef84ec9b84039bcfe7c141e5172a Mon Sep 17 00:00:00 2001 From: Ivan Mironov Date: Thu, 21 May 2026 12:54:23 +0200 Subject: [PATCH 11/11] Fix constants in rustls glue code * Deduplicate verify flags / record-size 
constants * Larger "max encrypted TLS record length" --- crates/stdlib/src/ssl.rs | 41 +++++++++++++++++++++++++-------- crates/stdlib/src/ssl/cert.rs | 2 +- crates/stdlib/src/ssl/compat.rs | 22 ++++-------------- 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index 9a812b67c92..8548220380f 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -153,8 +153,28 @@ mod _ssl { // Buffer sizes and limits (OpenSSL/CPython compatibility) const PEM_BUFSIZE: usize = 1024; + // OpenSSL: ssl/ssl_local.h + const SSL3_RT_HEADER_LENGTH: usize = 5; + // This is the maximum MAC (digest) size used by the SSL library. Currently + // maximum of 20 is used by SHA1, but we reserve for future extension for + // 512-bit hashes. + const SSL3_RT_MAX_MD_SIZE: usize = 64; + // Maximum plaintext length: defined by SSL/TLS standards const SSL3_RT_MAX_PLAIN_LENGTH: usize = 16384; + // Maximum compression overhead: defined by SSL/TLS standards + const SSL3_RT_MAX_COMPRESSED_OVERHEAD: usize = 1024; + // The standards give a maximum encryption overhead of 1024 bytes. In + // practice the value is lower than this. The overhead is the maximum number + // of padding bytes (256) plus the mac size. 
+ const SSL3_RT_MAX_ENCRYPTED_OVERHEAD: usize = 256 + SSL3_RT_MAX_MD_SIZE; + const SSL3_RT_MAX_COMPRESSED_LENGTH: usize = + SSL3_RT_MAX_PLAIN_LENGTH + SSL3_RT_MAX_COMPRESSED_OVERHEAD; + const SSL3_RT_MAX_ENCRYPTED_LENGTH: usize = + SSL3_RT_MAX_ENCRYPTED_OVERHEAD + SSL3_RT_MAX_COMPRESSED_LENGTH; + pub(crate) const SSL3_RT_MAX_PACKET_SIZE: usize = + SSL3_RT_MAX_ENCRYPTED_LENGTH + SSL3_RT_HEADER_LENGTH; + // SSL session cache size (common practice, similar to OpenSSL defaults) const SSL_SESSION_CACHE_SIZE: usize = 256; @@ -166,21 +186,26 @@ mod _ssl { #[pyattr] const CERT_REQUIRED: i32 = 2; - // Certificate requirements + // SSL Verification Flags / Certificate requirements #[pyattr] const VERIFY_DEFAULT: i32 = 0; #[pyattr] const VERIFY_CRL_CHECK_LEAF: i32 = 4; #[pyattr] const VERIFY_CRL_CHECK_CHAIN: i32 = 12; + /// VERIFY_X509_STRICT flag for RFC 5280 strict compliance + /// When set, performs additional validation including AKI extension checks #[pyattr] - const VERIFY_X509_STRICT: i32 = 32; + pub(crate) const VERIFY_X509_STRICT: i32 = 32; #[pyattr] const VERIFY_ALLOW_PROXY_CERTS: i32 = 64; #[pyattr] const VERIFY_X509_TRUSTED_FIRST: i32 = 32768; + /// VERIFY_X509_PARTIAL_CHAIN flag for partial chain validation + /// When set, accept certificates if any certificate in the chain is in the trust store + /// (not just root CAs). This matches OpenSSL's X509_V_FLAG_PARTIAL_CHAIN behavior. 
#[pyattr] - const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000; + pub(crate) const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000; // Options (OpenSSL-compatible flags, mostly no-op in rustls) #[pyattr] @@ -4315,7 +4340,7 @@ mod _ssl { break; } - let mut buf = vec![0u8; SSL3_RT_MAX_PLAIN_LENGTH]; + let mut buf = vec![0u8; SSL3_RT_MAX_PACKET_SIZE]; let written = conn .write_tls(&mut buf.as_mut_slice()) .map_err(|e| vm.new_os_error(format!("TLS write failed: {e}")))?; @@ -4347,7 +4372,7 @@ mod _ssl { } // BIO mode: read from incoming BIO - match self.sock_recv(SSL3_RT_MAX_PLAIN_LENGTH, vm) { + match self.sock_recv(SSL3_RT_MAX_PACKET_SIZE, vm) { Ok(bytes_obj) => { let bytes = ArgBytesLike::try_from_object(vm, bytes_obj)?; let data = bytes.borrow_buf(); @@ -4386,11 +4411,7 @@ mod _ssl { /// /// Equivalent to OpenSSL's `SSL_set_read_ahead(ssl, 0)` — rustls has no /// such knob, so we enforce record-level reads manually via peek. - fn try_read_close_notify_socket( - &self, - conn: &mut Connection, - vm: &VirtualMachine, - ) -> bool { + fn try_read_close_notify_socket(&self, conn: &mut Connection, vm: &VirtualMachine) -> bool { // Consume at most one TLS record from the socket match self.sock_recv_at_most_one_tls_record(vm) { Ok(data) => { diff --git a/crates/stdlib/src/ssl/cert.rs b/crates/stdlib/src/ssl/cert.rs index 835e3f37c6b..513d2929eb2 100644 --- a/crates/stdlib/src/ssl/cert.rs +++ b/crates/stdlib/src/ssl/cert.rs @@ -22,7 +22,7 @@ use rustpython_vm::{PyObjectRef, PyResult, VirtualMachine}; use std::collections::HashSet; use x509_parser::prelude::*; -use super::compat::{VERIFY_X509_PARTIAL_CHAIN, VERIFY_X509_STRICT}; +use super::_ssl::{VERIFY_X509_PARTIAL_CHAIN, VERIFY_X509_STRICT}; // Certificate Verification Constants diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index ef284991495..2d6bb369e9a 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -35,7 +35,7 @@ use std::io::Read; use std::sync::Once; // 
Import PySSLSocket from parent module -use super::_ssl::PySSLSocket; +use super::_ssl::{PySSLSocket, SSL3_RT_MAX_PACKET_SIZE, VERIFY_X509_STRICT}; // Import error types and helper functions from error module use super::error::{ @@ -43,16 +43,6 @@ use super::error::{ create_ssl_want_read_error, create_ssl_want_write_error, create_ssl_zero_return_error, }; -// SSL Verification Flags -/// VERIFY_X509_STRICT flag for RFC 5280 strict compliance -/// When set, performs additional validation including AKI extension checks -pub(super) const VERIFY_X509_STRICT: i32 = 0x20; - -/// VERIFY_X509_PARTIAL_CHAIN flag for partial chain validation -/// When set, accept certificates if any certificate in the chain is in the trust store -/// (not just root CAs). This matches OpenSSL's X509_V_FLAG_PARTIAL_CHAIN behavior. -pub(super) const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000; - // CryptoProvider Initialization: /// Ensure the default CryptoProvider is installed (thread-safe, runs once) @@ -72,10 +62,6 @@ fn ensure_default_provider() { // OpenSSL Constants: -// OpenSSL TLS record maximum plaintext size (ssl/ssl_local.h) -// #define SSL3_RT_MAX_PLAIN_LENGTH 16384 -const SSL3_RT_MAX_PLAIN_LENGTH: usize = 16384; - // OpenSSL error library codes (include/openssl/err.h) // #define ERR_LIB_SSL 20 const ERR_LIB_SSL: i32 = 20; @@ -1092,7 +1078,7 @@ fn handshake_read_data( // detect it. recv_at_most_one_tls_record(socket, vm)? 
} else { - match socket.sock_recv(SSL3_RT_MAX_PLAIN_LENGTH, vm) { + match socket.sock_recv(SSL3_RT_MAX_PACKET_SIZE, vm) { Ok(d) => d, Err(e) => { if is_blocking_io_error(&e, vm) { @@ -1312,7 +1298,7 @@ pub(super) fn ssl_do_handshake( if conn.wants_write() { // Write all pending TLS data to outgoing BIO loop { - let mut buf = vec![0u8; SSL3_RT_MAX_PLAIN_LENGTH]; + let mut buf = vec![0u8; SSL3_RT_MAX_PACKET_SIZE]; let n = match conn.write_tls(&mut buf.as_mut_slice()) { Ok(n) => n, Err(_) => break, @@ -1929,7 +1915,7 @@ fn ssl_ensure_data_available( let data = if !is_bio { recv_at_most_one_tls_record_for_data(conn, socket, vm)? } else { - match socket.sock_recv(2048, vm) { + match socket.sock_recv(SSL3_RT_MAX_PACKET_SIZE, vm) { Ok(data) => data, Err(e) => { if is_blocking_io_error(&e, vm) {