Skip to content
Merged
33 changes: 0 additions & 33 deletions crates/host_env/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,10 @@ use crate::os::CheckLibcResult;
#[cfg(unix)]
use core::ffi::CStr;
#[cfg(unix)]
use core::time::Duration;
#[cfg(unix)]
use std::os::fd::AsRawFd;
#[cfg(unix)]
use std::{io, os::fd::BorrowedFd};

#[cfg(unix)]
#[derive(Copy, Clone)]
pub enum PollKind {
Read,
Write,
Connect,
}

#[cfg(all(unix, not(target_os = "redox")))]
pub fn sethostname(hostname: &str) -> io::Result<()> {
nix::unistd::sethostname(hostname).map_err(io::Error::from)
Expand Down Expand Up @@ -111,29 +101,6 @@ pub fn setsockopt_none(fd: libc::c_int, level: i32, name: i32, optlen: u32) -> i
Ok(())
}

#[cfg(unix)]
pub fn poll_socket(
fd: BorrowedFd<'_>,
kind: PollKind,
interval: Option<Duration>,
) -> io::Result<bool> {
use nix::poll::{PollFd, PollFlags, PollTimeout, poll};

let events = match kind {
PollKind::Read => PollFlags::POLLIN,
PollKind::Write => PollFlags::POLLOUT,
PollKind::Connect => PollFlags::POLLOUT | PollFlags::POLLERR,
};
let mut pollfd = [PollFd::new(fd, events)];
let timeout = match interval {
Some(d) => d.try_into().unwrap_or(PollTimeout::MAX),
None => PollTimeout::NONE,
};
poll(&mut pollfd, timeout)
.map(|ret| ret == 0)
.map_err(io::Error::from)
}

#[cfg(any(
target_os = "dragonfly",
target_os = "freebsd",
Expand Down
204 changes: 132 additions & 72 deletions crates/stdlib/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
pub(crate) use _socket::module_def;

#[cfg(feature = "ssl")]
pub(super) use _socket::{PySocket, SelectKind, sock_select, timeout_error_msg};
pub(super) use _socket::{PySocket, SockWaitKind, sock_wait, timeout_error_msg};

#[pymodule]
mod _socket {
Expand Down Expand Up @@ -1060,20 +1060,20 @@ mod _socket {
fn sock_op<F, R>(
&self,
vm: &VirtualMachine,
select: SelectKind,
wait_kind: SockWaitKind,
f: F,
) -> Result<R, IoOrPyException>
where
F: FnMut() -> io::Result<R>,
{
let timeout = self.get_timeout().ok();
self.sock_op_timeout_err(vm, select, timeout, f)
self.sock_op_timeout_err(vm, wait_kind, timeout, f)
}

fn sock_op_timeout_err<F, R>(
&self,
vm: &VirtualMachine,
select: SelectKind,
wait_kind: SockWaitKind,
timeout: Option<Duration>,
mut f: F,
) -> Result<R, IoOrPyException>
Expand All @@ -1083,19 +1083,9 @@ mod _socket {
let deadline = timeout.map(Deadline::new);

loop {
if deadline.is_some() || matches!(select, SelectKind::Connect) {
let interval = deadline.as_ref().map(|d| d.time_until()).transpose()?;
if deadline.is_some() || matches!(wait_kind, SockWaitKind::Connect) {
let sock = self.sock()?;
let res = vm.allow_threads(|| sock_select(&sock, select, interval));
match res {
Ok(true) => return Err(IoOrPyException::Timeout),
Err(e) if e.kind() == io::ErrorKind::Interrupted => {
vm.check_signals()?;
continue;
}
Err(e) => return Err(e.into()),
Ok(false) => {} // no timeout, continue as normal
}
sock_wait_deadline(&sock, wait_kind, &deadline, vm)?;
}

let err = loop {
Expand Down Expand Up @@ -1339,16 +1329,11 @@ mod _socket {
};

if wait_connect {
// basically, connect() is async, and it registers an "error" on the socket when it's
// done connecting. SelectKind::Connect fills the errorfds fd_set, so if we wake up
// from poll and the error is EISCONN then we know that the connect is done
self.sock_op(vm, SelectKind::Connect, || {
self.sock_op(vm, SockWaitKind::Connect, || {
let sock = self.sock()?;
let err = sock.take_error()?;
match err {
Some(e) if e.posix_errno() == libc::EISCONN => Ok(()),
Some(e) => Err(e),
// TODO: is this accurate?
None => Ok(()),
}
})
Expand Down Expand Up @@ -1587,7 +1572,8 @@ mod _socket {
) -> Result<(RawSocket, PyObjectRef), IoOrPyException> {
// Use accept_raw() instead of accept() to avoid socket2's set_common_flags()
// which tries to set SO_NOSIGPIPE and fails with EINVAL on Unix domain sockets on macOS
let (sock, addr) = self.sock_op(vm, SelectKind::Read, || self.sock()?.accept_raw())?;
let (sock, addr) =
self.sock_op(vm, SockWaitKind::Read, || self.sock()?.accept_raw())?;
let fd = into_sock_fileno(sock);
Ok((fd, get_addr_tuple(&addr, vm)))
}
Expand All @@ -1602,7 +1588,7 @@ mod _socket {
let flags = flags.unwrap_or(0);
let mut buffer = Vec::with_capacity(bufsize);
let sock = self.sock()?;
let n = self.sock_op(vm, SelectKind::Read, || {
let n = self.sock_op(vm, SockWaitKind::Read, || {
sock.recv_with_flags(buffer.spare_capacity_mut(), flags)
})?;
unsafe { buffer.set_len(n) };
Expand Down Expand Up @@ -1633,7 +1619,7 @@ mod _socket {
};

let buf = &mut buf[..read_len];
self.sock_op(vm, SelectKind::Read, || {
self.sock_op(vm, SockWaitKind::Read, || {
sock.recv_with_flags(unsafe { slice_as_uninit(buf) }, flags)
})
}
Expand All @@ -1650,7 +1636,7 @@ mod _socket {
.to_usize()
.ok_or_else(|| vm.new_value_error("negative buffersize in recvfrom"))?;
let mut buffer = Vec::with_capacity(bufsize);
let (n, addr) = self.sock_op(vm, SelectKind::Read, || {
let (n, addr) = self.sock_op(vm, SockWaitKind::Read, || {
self.sock()?
.recv_from_with_flags(buffer.spare_capacity_mut(), flags)
})?;
Expand Down Expand Up @@ -1681,7 +1667,7 @@ mod _socket {
};
let flags = flags.unwrap_or(0);
let sock = self.sock()?;
let (n, addr) = self.sock_op(vm, SelectKind::Read, || {
let (n, addr) = self.sock_op(vm, SockWaitKind::Read, || {
sock.recv_from_with_flags(unsafe { slice_as_uninit(buf) }, flags)
})?;
Ok((n, get_addr_tuple(&addr, vm)))
Expand All @@ -1697,7 +1683,7 @@ mod _socket {
let flags = flags.unwrap_or(0);
let buf = bytes.borrow_buf();
let buf = &*buf;
self.sock_op(vm, SelectKind::Write, || {
self.sock_op(vm, SockWaitKind::Write, || {
self.sock()?.send_with_flags(buf, flags)
})
}
Expand All @@ -1721,7 +1707,7 @@ mod _socket {
// now we have like 3 layers of interrupt loop :)
while buf_offset < buf.len() {
let interval = deadline.as_ref().map(|d| d.time_until()).transpose()?;
self.sock_op_timeout_err(vm, SelectKind::Write, interval, || {
self.sock_op_timeout_err(vm, SockWaitKind::Write, interval, || {
let subbuf = &buf[buf_offset..];
buf_offset += self.sock()?.send_with_flags(subbuf, flags)?;
Ok(())
Expand Down Expand Up @@ -1754,7 +1740,7 @@ mod _socket {
let addr = self.extract_address(address, "sendto", vm)?;
let buf = bytes.borrow_buf();
let buf = &*buf;
self.sock_op(vm, SelectKind::Write, || {
self.sock_op(vm, SockWaitKind::Write, || {
self.sock()?.send_to_with_flags(buf, &addr, flags)
})
}
Expand Down Expand Up @@ -1812,7 +1798,7 @@ mod _socket {
}
}

self.sock_op(vm, SelectKind::Write, || {
self.sock_op(vm, SockWaitKind::Write, || {
let sock = self.sock()?;
sock.sendmsg(&msg, flags)
})
Expand Down Expand Up @@ -1848,7 +1834,7 @@ mod _socket {
.collect::<Vec<_>>();
let iv = iv.map(|iv| iv.borrow_buf().to_vec());

self.sock_op(vm, SelectKind::Write, || {
self.sock_op(vm, SockWaitKind::Write, || {
let sock = self.sock()?;
let fd = unsafe { BorrowedFd::borrow_raw(sock_fileno(&sock)) };
host_socket::sendmsg_afalg(fd, &buffers, op, iv.as_deref(), assoclen, flags)
Expand Down Expand Up @@ -1881,7 +1867,7 @@ mod _socket {
let flags = flags.unwrap_or(0);

let msg = self
.sock_op(vm, SelectKind::Read, || {
.sock_op(vm, SockWaitKind::Read, || {
let sock = self.sock()?;
let fd = unsafe { std::os::fd::BorrowedFd::borrow_raw(sock_fileno(&sock)) };
host_socket::recvmsg(fd, bufsize, ancbufsize, flags)
Expand Down Expand Up @@ -2436,61 +2422,135 @@ mod _socket {
}

#[derive(Copy, Clone)]
pub(crate) enum SelectKind {
pub(crate) enum SockWaitKind {
Read,
Write,
Connect,
}

/// returns true if timed out
pub(crate) fn sock_select(
/// returns Ok(true) on timeout
pub(crate) fn sock_wait(
sock: &Socket,
wait_kind: SockWaitKind,
timeout: Option<Duration>,
vm: &VirtualMachine,
) -> PyResult<bool> {
match sock_wait_deadline(sock, wait_kind, &timeout.map(Deadline::new), vm) {
Ok(()) => Ok(false),
Err(IoOrPyException::Timeout) => Ok(true),
Err(e) => Err(e.into_pyexception(vm)),
}
}

/// returns Err(IoOrPyException::Timeout) on timeout
fn sock_wait_deadline(
sock: &Socket,
kind: SelectKind,
interval: Option<Duration>,
) -> io::Result<bool> {
wait_kind: SockWaitKind,
deadline: &Option<Deadline>,
vm: &VirtualMachine,
) -> Result<(), IoOrPyException> {
#[cfg(unix)]
{
use std::os::fd::AsFd;
let kind = match kind {
SelectKind::Read => host_socket::PollKind::Read,
SelectKind::Write => host_socket::PollKind::Write,
SelectKind::Connect => host_socket::PollKind::Connect,
};
host_socket::poll_socket(sock.as_fd(), kind, interval)
use rustpython_host_env::select::{PollFd, poll_fds};

let mut events = 0;
if matches!(wait_kind, SockWaitKind::Read) {
events |= libc::POLLIN | libc::POLLPRI;
}
if matches!(wait_kind, SockWaitKind::Write | SockWaitKind::Connect) {
events |= libc::POLLOUT;
}
let mut fds = [PollFd {
fd: sock_fileno(sock),
events,
revents: 0,
}; 1];

loop {
let (timeout, is_capped) = deadline
.as_ref()
.map(|d| {
d.time_until().map(|t| {
let timeout_ms = t.as_millis();
let is_capped = timeout_ms > i32::MAX as u128;
let timeout = if is_capped {
i32::MAX
} else {
timeout_ms as i32
};
(timeout, is_capped)
})
})
.transpose()?
.unwrap_or((-1, false));

match vm.allow_threads(|| poll_fds(&mut fds, timeout)) {
Ok(0) => {
if is_capped {
continue;
}
break Err(IoOrPyException::Timeout);
}

Ok(_) => {
if fds[0].revents & libc::POLLNVAL != 0 {
break Err(io::Error::from_raw_os_error(libc::EBADF).into());
}
break Ok(());
}

Err(e) => {
if e.kind() == io::ErrorKind::Interrupted {
vm.check_signals()?;
continue;
}
break Err(e.into());
}
}
}
}
#[cfg(windows)]
{
use rustpython_host_env::select as host_select;
use rustpython_host_env::select::{FdSet, select, timeval};

let fd = sock_fileno(sock);
let fd = sock_fileno(sock) as usize;

let mut reads = host_select::FdSet::new();
let mut writes = host_select::FdSet::new();
let mut errs = host_select::FdSet::new();
let mut reads = FdSet::new();
let mut writes = FdSet::new();
let mut errs = FdSet::new();

let fd = fd as usize;
match kind {
SelectKind::Read => reads.insert(fd),
SelectKind::Write => writes.insert(fd),
SelectKind::Connect => {
writes.insert(fd);
errs.insert(fd);
}
if matches!(wait_kind, SockWaitKind::Read) {
reads.insert(fd);
errs.insert(fd);
}
if matches!(wait_kind, SockWaitKind::Write | SockWaitKind::Connect) {
writes.insert(fd);
errs.insert(fd);
}

let mut interval = interval.map(|dur| host_select::timeval {
tv_sec: dur.as_secs() as _,
tv_usec: dur.subsec_micros() as _,
});

host_select::select(
fd as i32 + 1,
&mut reads,
&mut writes,
&mut errs,
interval.as_mut(),
)
.map(|ret| ret == 0)
let mut timeout = deadline
.as_ref()
.map(|d| {
d.time_until().map(|dur| timeval {
tv_sec: dur.as_secs() as _,
tv_usec: dur.subsec_micros() as _,
})
})
.transpose()?;

match vm.allow_threads(|| {
select(
0, // nfds is ignored on windows
&mut reads,
&mut writes,
&mut errs,
timeout.as_mut(),
)
}) {
Ok(0) => Err(IoOrPyException::Timeout),
Ok(_) => Ok(()),
Err(e) => Err(e.into()),
}
}
}

Expand Down
Loading
Loading