@@ -50,7 +50,7 @@ mod _ssl {
5050
5151 // Import error types used in this module (others are exposed via pymodule(with(...)))
5252 use super :: error:: {
53- PySSLEOFError , PySSLError , create_ssl_want_read_error, create_ssl_want_write_error,
53+ PySSLError , create_ssl_eof_error , create_ssl_want_read_error, create_ssl_want_write_error,
5454 } ;
5555 use alloc:: sync:: Arc ;
5656 use core:: {
@@ -1903,6 +1903,7 @@ mod _ssl {
19031903 client_hello_buffer : PyMutex :: new ( None ) ,
19041904 shutdown_state : PyMutex :: new ( ShutdownState :: NotStarted ) ,
19051905 pending_tls_output : PyMutex :: new ( Vec :: new ( ) ) ,
1906+ write_buffered_len : PyMutex :: new ( 0 ) ,
19061907 deferred_cert_error : Arc :: new ( ParkingRwLock :: new ( None ) ) ,
19071908 } ;
19081909
@@ -1974,6 +1975,7 @@ mod _ssl {
19741975 client_hello_buffer : PyMutex :: new ( None ) ,
19751976 shutdown_state : PyMutex :: new ( ShutdownState :: NotStarted ) ,
19761977 pending_tls_output : PyMutex :: new ( Vec :: new ( ) ) ,
1978+ write_buffered_len : PyMutex :: new ( 0 ) ,
19771979 deferred_cert_error : Arc :: new ( ParkingRwLock :: new ( None ) ) ,
19781980 } ;
19791981
@@ -2345,6 +2347,10 @@ mod _ssl {
23452347 // but the socket cannot accept all the data immediately
23462348 #[ pytraverse( skip) ]
23472349 pub ( crate ) pending_tls_output : PyMutex < Vec < u8 > > ,
2350+ // Tracks bytes already buffered in rustls for the current write operation
2351+ // Prevents duplicate writes when retrying after WantWrite/WantRead
2352+ #[ pytraverse( skip) ]
2353+ pub ( crate ) write_buffered_len : PyMutex < usize > ,
23482354 // Deferred client certificate verification error (for TLS 1.3)
23492355 // Stores error message if client cert verification failed during handshake
23502356 // Error is raised on first I/O operation after handshake
@@ -2604,6 +2610,36 @@ mod _ssl {
26042610 Ok ( timed_out)
26052611 }
26062612
2613+ // Internal implementation with explicit timeout override
2614+ pub ( crate ) fn sock_wait_for_io_with_timeout (
2615+ & self ,
2616+ kind : SelectKind ,
2617+ timeout : Option < std:: time:: Duration > ,
2618+ vm : & VirtualMachine ,
2619+ ) -> PyResult < bool > {
2620+ if self . is_bio_mode ( ) {
2621+ // BIO mode doesn't use select
2622+ return Ok ( false ) ;
2623+ }
2624+
2625+ if let Some ( t) = timeout
2626+ && t. is_zero ( )
2627+ {
2628+ // Non-blocking mode - don't use select
2629+ return Ok ( false ) ;
2630+ }
2631+
2632+ let py_socket: PyRef < PySocket > = self . sock . clone ( ) . try_into_value ( vm) ?;
2633+ let socket = py_socket
2634+ . sock ( )
2635+ . map_err ( |e| vm. new_os_error ( format ! ( "Failed to get socket: {e}" ) ) ) ?;
2636+
2637+ let timed_out = sock_select ( & socket, kind, timeout)
2638+ . map_err ( |e| vm. new_os_error ( format ! ( "select failed: {e}" ) ) ) ?;
2639+
2640+ Ok ( timed_out)
2641+ }
2642+
26072643 // SNI (Server Name Indication) Helper Methods:
26082644 // These methods support the server-side handshake SNI callback mechanism
26092645
@@ -2783,6 +2819,7 @@ mod _ssl {
27832819 let is_non_blocking = socket_timeout. map ( |t| t. is_zero ( ) ) . unwrap_or ( false ) ;
27842820
27852821 let mut sent_total = 0 ;
2822+
27862823 while sent_total < pending. len ( ) {
27872824 // Calculate timeout: use deadline if provided, otherwise use socket timeout
27882825 let timeout_to_use = if let Some ( dl) = deadline {
@@ -2810,6 +2847,9 @@ mod _ssl {
28102847 if timed_out {
28112848 // Keep unsent data in pending buffer
28122849 * pending = pending[ sent_total..] . to_vec ( ) ;
2850+ if is_non_blocking {
2851+ return Err ( create_ssl_want_write_error ( vm) . upcast ( ) ) ;
2852+ }
28132853 return Err (
28142854 timeout_error_msg ( vm, "The write operation timed out" . to_string ( ) ) . upcast ( ) ,
28152855 ) ;
@@ -2824,6 +2864,7 @@ mod _ssl {
28242864 * pending = pending[ sent_total..] . to_vec ( ) ;
28252865 return Err ( create_ssl_want_write_error ( vm) . upcast ( ) ) ;
28262866 }
2867+ // Socket said ready but sent 0 bytes - retry
28272868 continue ;
28282869 }
28292870 sent_total += sent;
@@ -2916,6 +2957,9 @@ mod _ssl {
29162957 pub ( crate ) fn blocking_flush_all_pending ( & self , vm : & VirtualMachine ) -> PyResult < ( ) > {
29172958 // Get socket timeout to respect during flush
29182959 let timeout = self . get_socket_timeout ( vm) ?;
2960+ if timeout. map ( |t| t. is_zero ( ) ) . unwrap_or ( false ) {
2961+ return self . flush_pending_tls_output ( vm, None ) ;
2962+ }
29192963
29202964 loop {
29212965 let pending_data = {
@@ -2948,8 +2992,7 @@ mod _ssl {
29482992 let mut pending = self . pending_tls_output . lock ( ) ;
29492993 pending. drain ( ..sent) ;
29502994 }
2951- // If sent == 0, socket wasn't ready despite select() saying so
2952- // Continue loop to retry - this avoids infinite loops
2995+ // If sent == 0, loop will retry with sock_select
29532996 }
29542997 Err ( e) => {
29552998 if is_blocking_io_error ( & e, vm) {
@@ -3515,16 +3558,60 @@ mod _ssl {
35153558 return_data ( buf, & buffer, vm)
35163559 }
35173560 Err ( crate :: ssl:: compat:: SslError :: Eof ) => {
3561+ // If plaintext is still buffered, return it before EOF.
3562+ let pending = {
3563+ let mut conn_guard = self . connection . lock ( ) ;
3564+ let conn = match conn_guard. as_mut ( ) {
3565+ Some ( conn) => conn,
3566+ None => return Err ( create_ssl_eof_error ( vm) . upcast ( ) ) ,
3567+ } ;
3568+ use std:: io:: BufRead ;
3569+ let mut reader = conn. reader ( ) ;
3570+ reader. fill_buf ( ) . map ( |buf| buf. len ( ) ) . unwrap_or ( 0 )
3571+ } ;
3572+ if pending > 0 {
3573+ let mut buf = vec ! [ 0u8 ; pending. min( len) ] ;
3574+ let read_retry = {
3575+ let mut conn_guard = self . connection . lock ( ) ;
3576+ let conn = conn_guard
3577+ . as_mut ( )
3578+ . ok_or_else ( || vm. new_value_error ( "Connection not established" ) ) ?;
3579+ crate :: ssl:: compat:: ssl_read ( conn, & mut buf, self , vm)
3580+ } ;
3581+ if let Ok ( n) = read_retry {
3582+ buf. truncate ( n) ;
3583+ return return_data ( buf, & buffer, vm) ;
3584+ }
3585+ }
35183586 // EOF occurred in violation of protocol (unexpected closure)
3519- Err ( vm
3520- . new_os_subtype_error (
3521- PySSLEOFError :: class ( & vm. ctx ) . to_owned ( ) ,
3522- None ,
3523- "EOF occurred in violation of protocol" ,
3524- )
3525- . upcast ( ) )
3587+ Err ( create_ssl_eof_error ( vm) . upcast ( ) )
35263588 }
35273589 Err ( crate :: ssl:: compat:: SslError :: ZeroReturn ) => {
3590+ // If plaintext is still buffered, return it before clean EOF.
3591+ let pending = {
3592+ let mut conn_guard = self . connection . lock ( ) ;
3593+ let conn = match conn_guard. as_mut ( ) {
3594+ Some ( conn) => conn,
3595+ None => return return_data ( vec ! [ ] , & buffer, vm) ,
3596+ } ;
3597+ use std:: io:: BufRead ;
3598+ let mut reader = conn. reader ( ) ;
3599+ reader. fill_buf ( ) . map ( |buf| buf. len ( ) ) . unwrap_or ( 0 )
3600+ } ;
3601+ if pending > 0 {
3602+ let mut buf = vec ! [ 0u8 ; pending. min( len) ] ;
3603+ let read_retry = {
3604+ let mut conn_guard = self . connection . lock ( ) ;
3605+ let conn = conn_guard
3606+ . as_mut ( )
3607+ . ok_or_else ( || vm. new_value_error ( "Connection not established" ) ) ?;
3608+ crate :: ssl:: compat:: ssl_read ( conn, & mut buf, self , vm)
3609+ } ;
3610+ if let Ok ( n) = read_retry {
3611+ buf. truncate ( n) ;
3612+ return return_data ( buf, & buffer, vm) ;
3613+ }
3614+ }
35283615 // Clean closure with close_notify - return empty data
35293616 return_data ( vec ! [ ] , & buffer, vm)
35303617 }
@@ -3580,21 +3667,17 @@ mod _ssl {
35803667 let data_bytes = data. borrow_buf ( ) ;
35813668 let data_len = data_bytes. len ( ) ;
35823669
3583- // return 0 immediately for empty write
35843670 if data_len == 0 {
35853671 return Ok ( 0 ) ;
35863672 }
35873673
3588- // Ensure handshake is done - if not, complete it first
3589- // This matches OpenSSL behavior where SSL_write() auto-completes handshake
3674+ // Ensure handshake is done (SSL_write auto-completes handshake)
35903675 if !* self . handshake_done . lock ( ) {
35913676 self . do_handshake ( vm) ?;
35923677 }
35933678
3594- // Check if connection has been shut down
3595- // After unwrap()/shutdown(), write operations should fail with SSLError
3596- let shutdown_state = * self . shutdown_state . lock ( ) ;
3597- if shutdown_state != ShutdownState :: NotStarted {
3679+ // Check shutdown state
3680+ if * self . shutdown_state . lock ( ) != ShutdownState :: NotStarted {
35983681 return Err ( vm
35993682 . new_os_subtype_error (
36003683 PySSLError :: class ( & vm. ctx ) . to_owned ( ) ,
@@ -3604,76 +3687,32 @@ mod _ssl {
36043687 . upcast ( ) ) ;
36053688 }
36063689
3607- {
3690+ // Call ssl_write (matches CPython's SSL_write_ex loop)
3691+ let result = {
36083692 let mut conn_guard = self . connection . lock ( ) ;
36093693 let conn = conn_guard
36103694 . as_mut ( )
36113695 . ok_or_else ( || vm. new_value_error ( "Connection not established" ) ) ?;
36123696
3613- let is_bio = self . is_bio_mode ( ) ;
3614- let data : & [ u8 ] = data_bytes . as_ref ( ) ;
3697+ crate :: ssl :: compat :: ssl_write ( conn , data_bytes . as_ref ( ) , self , vm )
3698+ } ;
36153699
3616- // CRITICAL: Flush any pending TLS data before writing new data
3617- // This ensures TLS 1.3 Finished message reaches server before application data
3618- // Without this, server may not be ready to process our data
3619- if !is_bio {
3620- self . flush_pending_tls_output ( vm, None ) ?;
3700+ match result {
3701+ Ok ( n) => {
3702+ self . check_deferred_cert_error ( vm) ?;
3703+ Ok ( n)
36213704 }
3622-
3623- // Write data in chunks to avoid filling the internal TLS buffer
3624- // rustls has a limited internal buffer, so we need to flush periodically
3625- const CHUNK_SIZE : usize = 16384 ; // 16KB chunks (typical TLS record size)
3626- let mut written = 0 ;
3627-
3628- while written < data. len ( ) {
3629- let chunk_end = core:: cmp:: min ( written + CHUNK_SIZE , data. len ( ) ) ;
3630- let chunk = & data[ written..chunk_end] ;
3631-
3632- // Write chunk to TLS layer
3633- {
3634- let mut writer = conn. writer ( ) ;
3635- use std:: io:: Write ;
3636- writer
3637- . write_all ( chunk)
3638- . map_err ( |e| vm. new_os_error ( format ! ( "Write failed: {e}" ) ) ) ?;
3639- // Flush to ensure data is converted to TLS records
3640- writer
3641- . flush ( )
3642- . map_err ( |e| vm. new_os_error ( format ! ( "Flush failed: {e}" ) ) ) ?;
3643- }
3644-
3645- written = chunk_end;
3646-
3647- // Flush TLS data to socket after each chunk
3648- if conn. wants_write ( ) {
3649- if is_bio {
3650- self . write_pending_tls ( conn, vm) ?;
3651- } else {
3652- // Socket mode: flush all pending TLS data
3653- // First, try to send any previously pending data
3654- self . flush_pending_tls_output ( vm, None ) ?;
3655-
3656- while conn. wants_write ( ) {
3657- let mut buf = Vec :: new ( ) ;
3658- conn. write_tls ( & mut buf) . map_err ( |e| {
3659- vm. new_os_error ( format ! ( "TLS write failed: {e}" ) )
3660- } ) ?;
3661-
3662- if !buf. is_empty ( ) {
3663- // Try to send TLS data, saving unsent bytes to pending buffer
3664- self . send_tls_output ( buf, vm) ?;
3665- }
3666- }
3667- }
3668- }
3705+ Err ( crate :: ssl:: compat:: SslError :: WantRead ) => {
3706+ Err ( create_ssl_want_read_error ( vm) . upcast ( ) )
3707+ }
3708+ Err ( crate :: ssl:: compat:: SslError :: WantWrite ) => {
3709+ Err ( create_ssl_want_write_error ( vm) . upcast ( ) )
3710+ }
3711+ Err ( crate :: ssl:: compat:: SslError :: Timeout ( msg) ) => {
3712+ Err ( timeout_error_msg ( vm, msg) . upcast ( ) )
36693713 }
3714+ Err ( e) => Err ( e. into_py_err ( vm) ) ,
36703715 }
3671-
3672- // Check for deferred certificate verification errors (TLS 1.3)
3673- // Must be checked AFTER write completes, as the error may be set during I/O
3674- self . check_deferred_cert_error ( vm) ?;
3675-
3676- Ok ( data_len)
36773716 }
36783717
36793718 #[ pymethod]
@@ -4013,6 +4052,10 @@ mod _ssl {
40134052
40144053 // Write close_notify to outgoing buffer/BIO
40154054 self . write_pending_tls ( conn, vm) ?;
4055+ // Ensure close_notify and any pending TLS data are flushed
4056+ if !is_bio {
4057+ self . flush_pending_tls_output ( vm, None ) ?;
4058+ }
40164059
40174060 // Update state
40184061 * self . shutdown_state . lock ( ) = ShutdownState :: SentCloseNotify ;
0 commit comments