Skip to content

Commit 114fc35

Browse files
committed
deadline
1 parent 92b79a1 commit 114fc35

File tree

2 files changed

+77
-24
lines changed

2 files changed

+77
-24
lines changed

crates/stdlib/src/ssl.rs

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2724,6 +2724,8 @@ mod _ssl {
27242724
recv_method.call((self.sock.clone(), vm.ctx.new_int(size)), vm)
27252725
}
27262726

2727+
/// Socket send - just sends data, caller must handle pending flush
2728+
/// Use flush_pending_tls_output before this if ordering is important
27272729
pub(crate) fn sock_send(&self, data: &[u8], vm: &VirtualMachine) -> PyResult<PyObjectRef> {
27282730
// In BIO mode, write to outgoing BIO
27292731
if let Some(ref bio) = self.outgoing_bio {
@@ -2742,19 +2744,45 @@ mod _ssl {
27422744
}
27432745

27442746
/// Flush any pending TLS output data to the socket
2745-
/// This should be called before generating new TLS output
2746-
pub(crate) fn flush_pending_tls_output(&self, vm: &VirtualMachine) -> PyResult<()> {
2747+
/// Optional deadline parameter allows respecting a read deadline during flush
2748+
pub(crate) fn flush_pending_tls_output(
2749+
&self,
2750+
vm: &VirtualMachine,
2751+
deadline: Option<std::time::Instant>,
2752+
) -> PyResult<()> {
27472753
let mut pending = self.pending_tls_output.lock();
27482754
if pending.is_empty() {
27492755
return Ok(());
27502756
}
27512757

2752-
let timeout = self.get_socket_timeout(vm)?;
2753-
let is_non_blocking = timeout.map(|t| t.is_zero()).unwrap_or(false);
2758+
let socket_timeout = self.get_socket_timeout(vm)?;
2759+
let is_non_blocking = socket_timeout.map(|t| t.is_zero()).unwrap_or(false);
27542760

27552761
let mut sent_total = 0;
27562762
while sent_total < pending.len() {
2757-
let timed_out = self.sock_wait_for_io_impl(SelectKind::Write, vm)?;
2763+
// Calculate timeout: use deadline if provided, otherwise use socket timeout
2764+
let timeout_to_use = if let Some(dl) = deadline {
2765+
let now = std::time::Instant::now();
2766+
if now >= dl {
2767+
// Deadline already passed
2768+
*pending = pending[sent_total..].to_vec();
2769+
return Err(
2770+
timeout_error_msg(vm, "The operation timed out".to_string()).upcast()
2771+
);
2772+
}
2773+
Some(dl - now)
2774+
} else {
2775+
socket_timeout
2776+
};
2777+
2778+
// Use sock_select directly with calculated timeout
2779+
let py_socket: PyRef<PySocket> = self.sock.clone().try_into_value(vm)?;
2780+
let socket = py_socket
2781+
.sock()
2782+
.map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?;
2783+
let timed_out = sock_select(&socket, SelectKind::Write, timeout_to_use)
2784+
.map_err(|e| vm.new_os_error(format!("select failed: {e}")))?;
2785+
27582786
if timed_out {
27592787
// Keep unsent data in pending buffer
27602788
*pending = pending[sent_total..].to_vec();
@@ -2888,7 +2916,7 @@ mod _ssl {
28882916
);
28892917
}
28902918

2891-
// Try to send pending data
2919+
// Try to send pending data (use raw to avoid recursion)
28922920
match self.sock_send(&pending_data, vm) {
28932921
Ok(result) => {
28942922
let sent: usize = result.try_to_value::<isize>(vm)?.try_into().unwrap_or(0);
@@ -3565,7 +3593,7 @@ mod _ssl {
35653593
// This ensures TLS 1.3 Finished message reaches server before application data
35663594
// Without this, server may not be ready to process our data
35673595
if !is_bio {
3568-
self.flush_pending_tls_output(vm)?;
3596+
self.flush_pending_tls_output(vm, None)?;
35693597
}
35703598

35713599
// Write data in chunks to avoid filling the internal TLS buffer
@@ -3599,7 +3627,7 @@ mod _ssl {
35993627
} else {
36003628
// Socket mode: flush all pending TLS data
36013629
// First, try to send any previously pending data
3602-
self.flush_pending_tls_output(vm)?;
3630+
self.flush_pending_tls_output(vm, None)?;
36033631

36043632
while conn.wants_write() {
36053633
let mut buf = Vec::new();
@@ -3954,7 +3982,7 @@ mod _ssl {
39543982
self.blocking_flush_all_pending(vm)?;
39553983
} else {
39563984
// BIO mode: non-blocking flush (caller handles pending data)
3957-
let _ = self.flush_pending_tls_output(vm);
3985+
let _ = self.flush_pending_tls_output(vm, None);
39583986
}
39593987

39603988
conn.send_close_notify();
@@ -4030,12 +4058,12 @@ mod _ssl {
40304058
Some(0.0) => {
40314059
// Non-blocking: best-effort flush, ignore errors
40324060
// to avoid deadlock with asyncore-based servers
4033-
let _ = self.flush_pending_tls_output(vm);
4061+
let _ = self.flush_pending_tls_output(vm, None);
40344062
}
40354063
Some(_t) => {
40364064
// Timeout mode: use flush with socket's timeout
40374065
// Errors (including timeout) are propagated to caller
4038-
self.flush_pending_tls_output(vm)?;
4066+
self.flush_pending_tls_output(vm, None)?;
40394067
}
40404068
None => {
40414069
// Blocking mode: wait until all pending data is sent
@@ -4075,7 +4103,7 @@ mod _ssl {
40754103
fn write_pending_tls(&self, conn: &mut TlsConnection, vm: &VirtualMachine) -> PyResult<()> {
40764104
// First, flush any previously pending TLS output
40774105
// Must succeed before sending new data to maintain order
4078-
self.flush_pending_tls_output(vm)?;
4106+
self.flush_pending_tls_output(vm, None)?;
40794107

40804108
loop {
40814109
if !conn.wants_write() {

crates/stdlib/src/ssl/compat.rs

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,17 +1010,35 @@ pub(super) fn is_blocking_io_error(err: &Py<PyBaseException>, vm: &VirtualMachin
10101010
/// Loops until all bytes are sent. For blocking sockets, this will wait
10111011
/// until all data is sent. For non-blocking sockets, returns WantWrite
10121012
/// if no progress can be made.
1013-
fn send_all_bytes(socket: &PySSLSocket, buf: Vec<u8>, vm: &VirtualMachine) -> SslResult<()> {
1014-
// First, flush any previously pending TLS data
1015-
// Must succeed before sending new data to maintain order
1016-
socket.flush_pending_tls_output(vm).map_err(SslError::Py)?;
1013+
/// Optional deadline parameter allows respecting a read deadline during flush.
1014+
fn send_all_bytes(
1015+
socket: &PySSLSocket,
1016+
buf: Vec<u8>,
1017+
vm: &VirtualMachine,
1018+
deadline: Option<std::time::Instant>,
1019+
) -> SslResult<()> {
1020+
// First, flush any previously pending TLS data with deadline
1021+
socket
1022+
.flush_pending_tls_output(vm, deadline)
1023+
.map_err(SslError::Py)?;
10171024

10181025
if buf.is_empty() {
10191026
return Ok(());
10201027
}
10211028

10221029
let mut sent_total = 0;
10231030
while sent_total < buf.len() {
1031+
// Check deadline before each send attempt
1032+
if let Some(dl) = deadline {
1033+
if std::time::Instant::now() >= dl {
1034+
socket
1035+
.pending_tls_output
1036+
.lock()
1037+
.extend_from_slice(&buf[sent_total..]);
1038+
return Err(SslError::Timeout("The operation timed out".to_string()));
1039+
}
1040+
}
1041+
10241042
match socket.sock_send(&buf[sent_total..], vm) {
10251043
Ok(result) => {
10261044
let sent: usize = result
@@ -1075,7 +1093,9 @@ fn handshake_write_loop(
10751093

10761094
// Flush any previously pending TLS data before generating new output
10771095
// Must succeed before sending new data to maintain order
1078-
socket.flush_pending_tls_output(vm).map_err(SslError::Py)?;
1096+
socket
1097+
.flush_pending_tls_output(vm, None)
1098+
.map_err(SslError::Py)?;
10791099

10801100
while conn.wants_write() || force_initial_write {
10811101
if force_initial_write && !conn.wants_write() {
@@ -1090,7 +1110,7 @@ fn handshake_write_loop(
10901110

10911111
if written > 0 && !buf.is_empty() {
10921112
// Send all bytes to socket, handling partial sends
1093-
send_all_bytes(socket, buf, vm)?;
1113+
send_all_bytes(socket, buf, vm, None)?;
10941114
made_progress = true;
10951115
} else if written == 0 {
10961116
// No data written but wants_write is true - should not happen normally
@@ -1209,7 +1229,7 @@ fn handle_handshake_complete(
12091229
// Do NOT loop on wants_write() - avoid infinite loop/deadlock
12101230
let tls_data = ssl_write_tls_records(conn)?;
12111231
if !tls_data.is_empty() {
1212-
send_all_bytes(socket, tls_data, vm)?;
1232+
send_all_bytes(socket, tls_data, vm, None)?;
12131233
}
12141234

12151235
// IMPORTANT: Don't check wants_write() again!
@@ -1224,7 +1244,7 @@ fn handle_handshake_complete(
12241244
if tls_data.is_empty() {
12251245
break;
12261246
}
1227-
match send_all_bytes(socket, tls_data, vm) {
1247+
match send_all_bytes(socket, tls_data, vm, None) {
12281248
Ok(()) => {}
12291249
Err(SslError::WantWrite) => break,
12301250
Err(e) => return Err(e),
@@ -1314,13 +1334,13 @@ pub(super) fn ssl_do_handshake(
13141334
if !is_bio {
13151335
conn.send_close_notify();
13161336
// Flush any pending TLS data before sending close_notify
1317-
let _ = socket.flush_pending_tls_output(vm);
1337+
let _ = socket.flush_pending_tls_output(vm, None);
13181338
// Actually send the close_notify alert using send_all_bytes
13191339
// for proper partial send handling and retry logic
13201340
if let Ok(alert_data) = ssl_write_tls_records(conn)
13211341
&& !alert_data.is_empty()
13221342
{
1323-
let _ = send_all_bytes(socket, alert_data, vm);
1343+
let _ = send_all_bytes(socket, alert_data, vm, None);
13241344
}
13251345
}
13261346

@@ -1371,7 +1391,7 @@ pub(super) fn ssl_do_handshake(
13711391
break;
13721392
}
13731393
// Send to outgoing BIO
1374-
send_all_bytes(socket, buf[..n].to_vec(), vm)?;
1394+
send_all_bytes(socket, buf[..n].to_vec(), vm, None)?;
13751395
// Check if there's more to write
13761396
if !conn.wants_write() {
13771397
break;
@@ -1496,15 +1516,20 @@ pub(super) fn ssl_read(
14961516
}
14971517

14981518
// Flush pending TLS data before continuing
1519+
// CRITICAL: Pass deadline so flush respects read timeout
14991520
let tls_data = ssl_write_tls_records(conn)?;
15001521
if !tls_data.is_empty() {
15011522
// Use best-effort send - don't fail READ just because WRITE couldn't complete
1502-
match send_all_bytes(socket, tls_data, vm) {
1523+
match send_all_bytes(socket, tls_data, vm, deadline) {
15031524
Ok(()) => {}
15041525
Err(SslError::WantWrite) => {
15051526
// Socket buffer full - acceptable during READ operation
15061527
// Pending data will be sent on next write/read call
15071528
}
1529+
Err(SslError::Timeout(_)) => {
1530+
// Timeout during flush is acceptable during READ
1531+
// Pending data stays buffered for next operation
1532+
}
15081533
Err(e) => return Err(e),
15091534
}
15101535
}

0 commit comments

Comments
 (0)