Skip to content

Commit 2a16360

Browse files
authored
Rustls integration improvements (RustPython#7946)
* Do not call `import socket` on each send()/recv() when using rustls Use method references cached during socket creation. * Implement reading of at most one TLS record from socket Previous algorithm didn't take into account that recv() may return less data than requested even for blocking sockets. * Remove special handling of rustls "buffer full" errors First of all, existing code does not really work and this leads to an infinite loop: RustPython#7891 Second, this should never happen when rustls used properly (wrt wants_read() and wants_write()) and thus all such errors are implementation bugs that must be properly fixed. * Replace own TlsConnection with rustls::Connection * Fix waiting on a socket 1) Ensure that socket_wait() called from TLS glue code allows threads 2) Ensure that socket_wait() called from TLS glue code properly handles EINTR on *nix 3) Ensure that select() or poll() error conditions are checked 4) Use poll() on *nix so socket descriptor values are not limited * Remove dead code from rustls glue * Do not present rustls errors as OSError(0, "Success") * Remove infinite loop "detection" from rustls glue TLS handshake cannot be infinite. Any infinite loop here is a serious bug in implementation and should be fixed properly. This code triggers in some cases (very short reads) with misleading `ssl_error.SSLWantReadError: The operation did not complete (read)`. * Add test for 1-byte max recv in TLS client * Add regression test for RustPython#7891 * Fix constants in rustls glue code * Deduplicate verify flags / record-size constants * Larger "max encrypted TLS record length"
1 parent 4eb9534 commit 2a16360

8 files changed

Lines changed: 737 additions & 584 deletions

File tree

crates/host_env/src/socket.rs

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,10 @@ use crate::os::CheckLibcResult;
33
#[cfg(unix)]
44
use core::ffi::CStr;
55
#[cfg(unix)]
6-
use core::time::Duration;
7-
#[cfg(unix)]
86
use std::os::fd::AsRawFd;
97
#[cfg(unix)]
108
use std::{io, os::fd::BorrowedFd};
119

12-
#[cfg(unix)]
13-
#[derive(Copy, Clone)]
14-
pub enum PollKind {
15-
Read,
16-
Write,
17-
Connect,
18-
}
19-
2010
#[cfg(all(unix, not(target_os = "redox")))]
2111
pub fn sethostname(hostname: &str) -> io::Result<()> {
2212
nix::unistd::sethostname(hostname).map_err(io::Error::from)
@@ -111,29 +101,6 @@ pub fn setsockopt_none(fd: libc::c_int, level: i32, name: i32, optlen: u32) -> i
111101
Ok(())
112102
}
113103

114-
#[cfg(unix)]
115-
pub fn poll_socket(
116-
fd: BorrowedFd<'_>,
117-
kind: PollKind,
118-
interval: Option<Duration>,
119-
) -> io::Result<bool> {
120-
use nix::poll::{PollFd, PollFlags, PollTimeout, poll};
121-
122-
let events = match kind {
123-
PollKind::Read => PollFlags::POLLIN,
124-
PollKind::Write => PollFlags::POLLOUT,
125-
PollKind::Connect => PollFlags::POLLOUT | PollFlags::POLLERR,
126-
};
127-
let mut pollfd = [PollFd::new(fd, events)];
128-
let timeout = match interval {
129-
Some(d) => d.try_into().unwrap_or(PollTimeout::MAX),
130-
None => PollTimeout::NONE,
131-
};
132-
poll(&mut pollfd, timeout)
133-
.map(|ret| ret == 0)
134-
.map_err(io::Error::from)
135-
}
136-
137104
#[cfg(any(
138105
target_os = "dragonfly",
139106
target_os = "freebsd",

crates/stdlib/src/socket.rs

Lines changed: 132 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
pub(crate) use _socket::module_def;
44

55
#[cfg(feature = "ssl")]
6-
pub(super) use _socket::{PySocket, SelectKind, sock_select, timeout_error_msg};
6+
pub(super) use _socket::{PySocket, SockWaitKind, sock_wait, timeout_error_msg};
77

88
#[pymodule]
99
mod _socket {
@@ -1060,20 +1060,20 @@ mod _socket {
10601060
fn sock_op<F, R>(
10611061
&self,
10621062
vm: &VirtualMachine,
1063-
select: SelectKind,
1063+
wait_kind: SockWaitKind,
10641064
f: F,
10651065
) -> Result<R, IoOrPyException>
10661066
where
10671067
F: FnMut() -> io::Result<R>,
10681068
{
10691069
let timeout = self.get_timeout().ok();
1070-
self.sock_op_timeout_err(vm, select, timeout, f)
1070+
self.sock_op_timeout_err(vm, wait_kind, timeout, f)
10711071
}
10721072

10731073
fn sock_op_timeout_err<F, R>(
10741074
&self,
10751075
vm: &VirtualMachine,
1076-
select: SelectKind,
1076+
wait_kind: SockWaitKind,
10771077
timeout: Option<Duration>,
10781078
mut f: F,
10791079
) -> Result<R, IoOrPyException>
@@ -1083,19 +1083,9 @@ mod _socket {
10831083
let deadline = timeout.map(Deadline::new);
10841084

10851085
loop {
1086-
if deadline.is_some() || matches!(select, SelectKind::Connect) {
1087-
let interval = deadline.as_ref().map(|d| d.time_until()).transpose()?;
1086+
if deadline.is_some() || matches!(wait_kind, SockWaitKind::Connect) {
10881087
let sock = self.sock()?;
1089-
let res = vm.allow_threads(|| sock_select(&sock, select, interval));
1090-
match res {
1091-
Ok(true) => return Err(IoOrPyException::Timeout),
1092-
Err(e) if e.kind() == io::ErrorKind::Interrupted => {
1093-
vm.check_signals()?;
1094-
continue;
1095-
}
1096-
Err(e) => return Err(e.into()),
1097-
Ok(false) => {} // no timeout, continue as normal
1098-
}
1088+
sock_wait_deadline(&sock, wait_kind, &deadline, vm)?;
10991089
}
11001090

11011091
let err = loop {
@@ -1339,16 +1329,11 @@ mod _socket {
13391329
};
13401330

13411331
if wait_connect {
1342-
// basically, connect() is async, and it registers an "error" on the socket when it's
1343-
// done connecting. SelectKind::Connect fills the errorfds fd_set, so if we wake up
1344-
// from poll and the error is EISCONN then we know that the connect is done
1345-
self.sock_op(vm, SelectKind::Connect, || {
1332+
self.sock_op(vm, SockWaitKind::Connect, || {
13461333
let sock = self.sock()?;
13471334
let err = sock.take_error()?;
13481335
match err {
1349-
Some(e) if e.posix_errno() == libc::EISCONN => Ok(()),
13501336
Some(e) => Err(e),
1351-
// TODO: is this accurate?
13521337
None => Ok(()),
13531338
}
13541339
})
@@ -1587,7 +1572,8 @@ mod _socket {
15871572
) -> Result<(RawSocket, PyObjectRef), IoOrPyException> {
15881573
// Use accept_raw() instead of accept() to avoid socket2's set_common_flags()
15891574
// which tries to set SO_NOSIGPIPE and fails with EINVAL on Unix domain sockets on macOS
1590-
let (sock, addr) = self.sock_op(vm, SelectKind::Read, || self.sock()?.accept_raw())?;
1575+
let (sock, addr) =
1576+
self.sock_op(vm, SockWaitKind::Read, || self.sock()?.accept_raw())?;
15911577
let fd = into_sock_fileno(sock);
15921578
Ok((fd, get_addr_tuple(&addr, vm)))
15931579
}
@@ -1602,7 +1588,7 @@ mod _socket {
16021588
let flags = flags.unwrap_or(0);
16031589
let mut buffer = Vec::with_capacity(bufsize);
16041590
let sock = self.sock()?;
1605-
let n = self.sock_op(vm, SelectKind::Read, || {
1591+
let n = self.sock_op(vm, SockWaitKind::Read, || {
16061592
sock.recv_with_flags(buffer.spare_capacity_mut(), flags)
16071593
})?;
16081594
unsafe { buffer.set_len(n) };
@@ -1633,7 +1619,7 @@ mod _socket {
16331619
};
16341620

16351621
let buf = &mut buf[..read_len];
1636-
self.sock_op(vm, SelectKind::Read, || {
1622+
self.sock_op(vm, SockWaitKind::Read, || {
16371623
sock.recv_with_flags(unsafe { slice_as_uninit(buf) }, flags)
16381624
})
16391625
}
@@ -1650,7 +1636,7 @@ mod _socket {
16501636
.to_usize()
16511637
.ok_or_else(|| vm.new_value_error("negative buffersize in recvfrom"))?;
16521638
let mut buffer = Vec::with_capacity(bufsize);
1653-
let (n, addr) = self.sock_op(vm, SelectKind::Read, || {
1639+
let (n, addr) = self.sock_op(vm, SockWaitKind::Read, || {
16541640
self.sock()?
16551641
.recv_from_with_flags(buffer.spare_capacity_mut(), flags)
16561642
})?;
@@ -1681,7 +1667,7 @@ mod _socket {
16811667
};
16821668
let flags = flags.unwrap_or(0);
16831669
let sock = self.sock()?;
1684-
let (n, addr) = self.sock_op(vm, SelectKind::Read, || {
1670+
let (n, addr) = self.sock_op(vm, SockWaitKind::Read, || {
16851671
sock.recv_from_with_flags(unsafe { slice_as_uninit(buf) }, flags)
16861672
})?;
16871673
Ok((n, get_addr_tuple(&addr, vm)))
@@ -1697,7 +1683,7 @@ mod _socket {
16971683
let flags = flags.unwrap_or(0);
16981684
let buf = bytes.borrow_buf();
16991685
let buf = &*buf;
1700-
self.sock_op(vm, SelectKind::Write, || {
1686+
self.sock_op(vm, SockWaitKind::Write, || {
17011687
self.sock()?.send_with_flags(buf, flags)
17021688
})
17031689
}
@@ -1721,7 +1707,7 @@ mod _socket {
17211707
// now we have like 3 layers of interrupt loop :)
17221708
while buf_offset < buf.len() {
17231709
let interval = deadline.as_ref().map(|d| d.time_until()).transpose()?;
1724-
self.sock_op_timeout_err(vm, SelectKind::Write, interval, || {
1710+
self.sock_op_timeout_err(vm, SockWaitKind::Write, interval, || {
17251711
let subbuf = &buf[buf_offset..];
17261712
buf_offset += self.sock()?.send_with_flags(subbuf, flags)?;
17271713
Ok(())
@@ -1754,7 +1740,7 @@ mod _socket {
17541740
let addr = self.extract_address(address, "sendto", vm)?;
17551741
let buf = bytes.borrow_buf();
17561742
let buf = &*buf;
1757-
self.sock_op(vm, SelectKind::Write, || {
1743+
self.sock_op(vm, SockWaitKind::Write, || {
17581744
self.sock()?.send_to_with_flags(buf, &addr, flags)
17591745
})
17601746
}
@@ -1812,7 +1798,7 @@ mod _socket {
18121798
}
18131799
}
18141800

1815-
self.sock_op(vm, SelectKind::Write, || {
1801+
self.sock_op(vm, SockWaitKind::Write, || {
18161802
let sock = self.sock()?;
18171803
sock.sendmsg(&msg, flags)
18181804
})
@@ -1848,7 +1834,7 @@ mod _socket {
18481834
.collect::<Vec<_>>();
18491835
let iv = iv.map(|iv| iv.borrow_buf().to_vec());
18501836

1851-
self.sock_op(vm, SelectKind::Write, || {
1837+
self.sock_op(vm, SockWaitKind::Write, || {
18521838
let sock = self.sock()?;
18531839
let fd = unsafe { BorrowedFd::borrow_raw(sock_fileno(&sock)) };
18541840
host_socket::sendmsg_afalg(fd, &buffers, op, iv.as_deref(), assoclen, flags)
@@ -1881,7 +1867,7 @@ mod _socket {
18811867
let flags = flags.unwrap_or(0);
18821868

18831869
let msg = self
1884-
.sock_op(vm, SelectKind::Read, || {
1870+
.sock_op(vm, SockWaitKind::Read, || {
18851871
let sock = self.sock()?;
18861872
let fd = unsafe { std::os::fd::BorrowedFd::borrow_raw(sock_fileno(&sock)) };
18871873
host_socket::recvmsg(fd, bufsize, ancbufsize, flags)
@@ -2436,61 +2422,135 @@ mod _socket {
24362422
}
24372423

24382424
#[derive(Copy, Clone)]
2439-
pub(crate) enum SelectKind {
2425+
pub(crate) enum SockWaitKind {
24402426
Read,
24412427
Write,
24422428
Connect,
24432429
}
24442430

2445-
/// returns true if timed out
2446-
pub(crate) fn sock_select(
2431+
/// returns Ok(true) on timeout
2432+
pub(crate) fn sock_wait(
2433+
sock: &Socket,
2434+
wait_kind: SockWaitKind,
2435+
timeout: Option<Duration>,
2436+
vm: &VirtualMachine,
2437+
) -> PyResult<bool> {
2438+
match sock_wait_deadline(sock, wait_kind, &timeout.map(Deadline::new), vm) {
2439+
Ok(()) => Ok(false),
2440+
Err(IoOrPyException::Timeout) => Ok(true),
2441+
Err(e) => Err(e.into_pyexception(vm)),
2442+
}
2443+
}
2444+
2445+
/// returns Err(IoOrPyException::Timeout) on timeout
2446+
fn sock_wait_deadline(
24472447
sock: &Socket,
2448-
kind: SelectKind,
2449-
interval: Option<Duration>,
2450-
) -> io::Result<bool> {
2448+
wait_kind: SockWaitKind,
2449+
deadline: &Option<Deadline>,
2450+
vm: &VirtualMachine,
2451+
) -> Result<(), IoOrPyException> {
24512452
#[cfg(unix)]
24522453
{
2453-
use std::os::fd::AsFd;
2454-
let kind = match kind {
2455-
SelectKind::Read => host_socket::PollKind::Read,
2456-
SelectKind::Write => host_socket::PollKind::Write,
2457-
SelectKind::Connect => host_socket::PollKind::Connect,
2458-
};
2459-
host_socket::poll_socket(sock.as_fd(), kind, interval)
2454+
use rustpython_host_env::select::{PollFd, poll_fds};
2455+
2456+
let mut events = 0;
2457+
if matches!(wait_kind, SockWaitKind::Read) {
2458+
events |= libc::POLLIN | libc::POLLPRI;
2459+
}
2460+
if matches!(wait_kind, SockWaitKind::Write | SockWaitKind::Connect) {
2461+
events |= libc::POLLOUT;
2462+
}
2463+
let mut fds = [PollFd {
2464+
fd: sock_fileno(sock),
2465+
events,
2466+
revents: 0,
2467+
}; 1];
2468+
2469+
loop {
2470+
let (timeout, is_capped) = deadline
2471+
.as_ref()
2472+
.map(|d| {
2473+
d.time_until().map(|t| {
2474+
let timeout_ms = t.as_millis();
2475+
let is_capped = timeout_ms > i32::MAX as u128;
2476+
let timeout = if is_capped {
2477+
i32::MAX
2478+
} else {
2479+
timeout_ms as i32
2480+
};
2481+
(timeout, is_capped)
2482+
})
2483+
})
2484+
.transpose()?
2485+
.unwrap_or((-1, false));
2486+
2487+
match vm.allow_threads(|| poll_fds(&mut fds, timeout)) {
2488+
Ok(0) => {
2489+
if is_capped {
2490+
continue;
2491+
}
2492+
break Err(IoOrPyException::Timeout);
2493+
}
2494+
2495+
Ok(_) => {
2496+
if fds[0].revents & libc::POLLNVAL != 0 {
2497+
break Err(io::Error::from_raw_os_error(libc::EBADF).into());
2498+
}
2499+
break Ok(());
2500+
}
2501+
2502+
Err(e) => {
2503+
if e.kind() == io::ErrorKind::Interrupted {
2504+
vm.check_signals()?;
2505+
continue;
2506+
}
2507+
break Err(e.into());
2508+
}
2509+
}
2510+
}
24602511
}
24612512
#[cfg(windows)]
24622513
{
2463-
use rustpython_host_env::select as host_select;
2514+
use rustpython_host_env::select::{FdSet, select, timeval};
24642515

2465-
let fd = sock_fileno(sock);
2516+
let fd = sock_fileno(sock) as usize;
24662517

2467-
let mut reads = host_select::FdSet::new();
2468-
let mut writes = host_select::FdSet::new();
2469-
let mut errs = host_select::FdSet::new();
2518+
let mut reads = FdSet::new();
2519+
let mut writes = FdSet::new();
2520+
let mut errs = FdSet::new();
24702521

2471-
let fd = fd as usize;
2472-
match kind {
2473-
SelectKind::Read => reads.insert(fd),
2474-
SelectKind::Write => writes.insert(fd),
2475-
SelectKind::Connect => {
2476-
writes.insert(fd);
2477-
errs.insert(fd);
2478-
}
2522+
if matches!(wait_kind, SockWaitKind::Read) {
2523+
reads.insert(fd);
2524+
errs.insert(fd);
2525+
}
2526+
if matches!(wait_kind, SockWaitKind::Write | SockWaitKind::Connect) {
2527+
writes.insert(fd);
2528+
errs.insert(fd);
24792529
}
24802530

2481-
let mut interval = interval.map(|dur| host_select::timeval {
2482-
tv_sec: dur.as_secs() as _,
2483-
tv_usec: dur.subsec_micros() as _,
2484-
});
2485-
2486-
host_select::select(
2487-
fd as i32 + 1,
2488-
&mut reads,
2489-
&mut writes,
2490-
&mut errs,
2491-
interval.as_mut(),
2492-
)
2493-
.map(|ret| ret == 0)
2531+
let mut timeout = deadline
2532+
.as_ref()
2533+
.map(|d| {
2534+
d.time_until().map(|dur| timeval {
2535+
tv_sec: dur.as_secs() as _,
2536+
tv_usec: dur.subsec_micros() as _,
2537+
})
2538+
})
2539+
.transpose()?;
2540+
2541+
match vm.allow_threads(|| {
2542+
select(
2543+
0, // nfds is ignored on windows
2544+
&mut reads,
2545+
&mut writes,
2546+
&mut errs,
2547+
timeout.as_mut(),
2548+
)
2549+
}) {
2550+
Ok(0) => Err(IoOrPyException::Timeout),
2551+
Ok(_) => Ok(()),
2552+
Err(e) => Err(e.into()),
2553+
}
24942554
}
24952555
}
24962556

0 commit comments

Comments
 (0)