diff --git a/crates/stdlib/src/overlapped.rs b/crates/stdlib/src/overlapped.rs index eb2a968c042..7520569c055 100644 --- a/crates/stdlib/src/overlapped.rs +++ b/crates/stdlib/src/overlapped.rs @@ -8,11 +8,12 @@ mod _overlapped { // straight-forward port of Modules/overlapped.c use crate::vm::{ - Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, - builtins::{PyBaseExceptionRef, PyBytesRef, PyType}, + AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, + builtins::{PyBaseExceptionRef, PyBytesRef, PyModule, PyStrRef, PyTupleRef, PyType}, common::lock::PyMutex, - convert::{ToPyException, ToPyObject}, + convert::ToPyObject, function::OptionalArg, + object::{Traverse, TraverseFn}, protocol::PyBuffer, types::{Constructor, Destructor}, }; @@ -22,6 +23,13 @@ mod _overlapped { System::IO::OVERLAPPED, }; + pub(crate) fn module_exec(vm: &VirtualMachine, module: &Py) -> PyResult<()> { + let _ = vm.import("_socket", 0)?; + initialize_winsock_extensions(vm)?; + __module_exec(vm, module); + Ok(()) + } + #[pyattr] use windows_sys::Win32::{ Foundation::{ @@ -49,8 +57,8 @@ mod _overlapped { fn initialize_winsock_extensions(vm: &VirtualMachine) -> PyResult<()> { use windows_sys::Win32::Networking::WinSock::{ - IPPROTO_TCP, SIO_GET_EXTENSION_FUNCTION_POINTER, SOCK_STREAM, SOCKET_ERROR, WSAIoctl, - closesocket, socket, + INVALID_SOCKET, IPPROTO_TCP, SIO_GET_EXTENSION_FUNCTION_POINTER, SOCK_STREAM, + SOCKET_ERROR, WSAGetLastError, WSAIoctl, closesocket, socket, }; // GUIDs for extension functions @@ -89,10 +97,9 @@ mod _overlapped { } let s = unsafe { socket(AF_INET as i32, SOCK_STREAM, IPPROTO_TCP) }; - if s == windows_sys::Win32::Networking::WinSock::INVALID_SOCKET { - return Err( - vm.new_os_error("Failed to create socket for WSA extension init".to_owned()) - ); + if s == INVALID_SOCKET { + let err = unsafe { WSAGetLastError() } as u32; + return Err(set_from_windows_err(err, vm)); } let mut dw_bytes: u32 = 0; @@ -114,8 +121,9 @@ mod _overlapped { ) }; if ret == SOCKET_ERROR { + let err = unsafe { WSAGetLastError() } as u32; unsafe { closesocket(s) }; - return Err(vm.new_os_error("Failed to get WSA extension function".to_owned())); + return Err(set_from_windows_err(err, vm)); } let _ = $lock.set(func_ptr); }}; @@ -131,7 +139,7 @@ mod _overlapped { } #[pyattr] - #[pyclass(name)] + #[pyclass(name, traverse)] #[derive(PyPayload)] struct Overlapped { inner: PyMutex, @@ -147,6 +155,35 @@ mod _overlapped { unsafe impl Sync for OverlappedInner {} unsafe impl Send for OverlappedInner {} + unsafe impl Traverse for OverlappedInner { + fn traverse(&self, traverse_fn: &mut TraverseFn<'_>) { + match &self.data { + OverlappedData::Read(buf) | OverlappedData::Accept(buf) => { + buf.traverse(traverse_fn); + } + OverlappedData::ReadInto(buf) | OverlappedData::Write(buf) => { + buf.traverse(traverse_fn); + } + OverlappedData::WriteTo(wt) => { + wt.buf.traverse(traverse_fn); + } + OverlappedData::ReadFrom(rf) => { + if let Some(result) = &rf.result { + result.traverse(traverse_fn); + } + rf.allocated_buffer.traverse(traverse_fn); + } + OverlappedData::ReadFromInto(rfi) => { + if let Some(result) = &rfi.result { + result.traverse(traverse_fn); + } + rfi.user_buffer.traverse(traverse_fn); + } + _ => {} + } + } + } + impl core::fmt::Debug for Overlapped { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let zelf = self.inner.lock(); @@ -182,6 +219,8 @@ mod _overlapped { } struct OverlappedReadFrom { + // A (buffer, (host, port)) tuple + result: Option, // The actual read buffer allocated_buffer: PyBytesRef, address: SOCKADDR_IN6, @@ -191,6 +230,7 @@ mod _overlapped { impl core::fmt::Debug for OverlappedReadFrom { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("OverlappedReadFrom") + .field("result", &self.result) .field("allocated_buffer", &self.allocated_buffer) .field("address_length", &self.address_length) .finish() @@ -198,6 +238,8 @@ mod _overlapped { } struct OverlappedReadFromInto { + // A (number of bytes read, (host, port)) tuple + result: Option, /* Buffer passed by the user */ user_buffer: PyBuffer, address: SOCKADDR_IN6, @@ -207,6 +249,7 @@ mod _overlapped { impl core::fmt::Debug for OverlappedReadFromInto { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("OverlappedReadFromInto") + .field("result", &self.result) .field("user_buffer", &self.user_buffer) .field("address_length", &self.address_length) .finish() @@ -234,9 +277,19 @@ mod _overlapped { } } - fn from_windows_err(err: u32, vm: &VirtualMachine) -> PyBaseExceptionRef { - debug_assert_ne!(err, 0, "call errno_err instead"); - std::io::Error::from_raw_os_error(err as i32).to_pyexception(vm) + fn set_from_windows_err(err: u32, vm: &VirtualMachine) -> PyBaseExceptionRef { + let err = if err == 0 { + unsafe { GetLastError() } + } else { + err + }; + let errno = crate::vm::common::os::winerror_to_errno(err as i32); + let message = std::io::Error::from_raw_os_error(err as i32).to_string(); + let exc = vm.new_errno_error(errno, message); + let _ = exc + .as_object() + .set_attr("winerror", err.to_pyobject(vm), vm); + exc.upcast() } fn HasOverlappedIoCompleted(overlapped: &OVERLAPPED) -> bool { @@ -244,135 +297,117 @@ mod _overlapped { } /// Parse a Python address tuple to SOCKADDR - fn parse_address(addr_obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<(Vec, i32)> { - use crate::vm::builtins::PyTuple; - use windows_sys::Win32::Networking::WinSock::WSAStringToAddressW; + fn parse_address(addr_obj: &PyTupleRef, vm: &VirtualMachine) -> PyResult<(Vec, i32)> { + use windows_sys::Win32::Networking::WinSock::{WSAGetLastError, WSAStringToAddressW}; - let tuple = addr_obj - .downcast_ref::() - .ok_or_else(|| vm.new_type_error("address must be a tuple".to_owned()))?; + match addr_obj.len() { + 2 => { + // IPv4: (host, port) + let host: PyStrRef = addr_obj[0].clone().try_into_value(vm)?; + let port: u16 = addr_obj[1].clone().try_to_value(vm)?; - let tuple_len = tuple.len(); + let mut addr: SOCKADDR_IN = unsafe { std::mem::zeroed() }; + addr.sin_family = AF_INET; - if tuple_len == 2 { - // IPv4: (host, port) - let host: String = tuple[0].try_to_value(vm)?; - let port: u16 = tuple[1].try_to_value(vm)?; + let host_wide: Vec = host.as_str().encode_utf16().chain([0]).collect(); + let mut addr_len = std::mem::size_of::() as i32; - let mut addr: SOCKADDR_IN = unsafe { std::mem::zeroed() }; - addr.sin_family = AF_INET; - addr.sin_port = port.to_be(); + let ret = unsafe { + WSAStringToAddressW( + host_wide.as_ptr(), + AF_INET as i32, + std::ptr::null(), + &mut addr as *mut _ as *mut SOCKADDR, + &mut addr_len, + ) + }; - // Convert host string to address - let host_wide: Vec = host.encode_utf16().chain(std::iter::once(0)).collect(); - let mut addr_len = std::mem::size_of::() as i32; + if ret < 0 { + let err = unsafe { WSAGetLastError() } as u32; + return Err(set_from_windows_err(err, vm)); + } - let ret = unsafe { - WSAStringToAddressW( - host_wide.as_ptr(), - AF_INET as i32, - std::ptr::null(), - &mut addr as *mut _ as *mut SOCKADDR, - &mut addr_len, - ) - }; + // Restore port (WSAStringToAddressW overwrites it) + addr.sin_port = port.to_be(); - if ret != 0 { - return Err(vm.new_os_error(format!("Invalid IPv4 address: {}", host))); + let bytes = unsafe { + std::slice::from_raw_parts( + &addr as *const _ as *const u8, + std::mem::size_of::(), + ) + }; + Ok((bytes.to_vec(), addr_len)) } + 4 => { + // IPv6: (host, port, flowinfo, scope_id) + let host: PyStrRef = addr_obj[0].clone().try_into_value(vm)?; + let port: u16 = addr_obj[1].clone().try_to_value(vm)?; + let flowinfo: u32 = addr_obj[2].clone().try_to_value(vm)?; + let scope_id: u32 = addr_obj[3].clone().try_to_value(vm)?; - // Restore port (WSAStringToAddressW overwrites it) - addr.sin_port = port.to_be(); + let mut addr: SOCKADDR_IN6 = unsafe { std::mem::zeroed() }; + addr.sin6_family = AF_INET6; - let bytes = unsafe { - std::slice::from_raw_parts( - &addr as *const _ as *const u8, - std::mem::size_of::(), - ) - }; - Ok((bytes.to_vec(), std::mem::size_of::() as i32)) - } else if tuple_len == 4 { - // IPv6: (host, port, flowinfo, scope_id) - let host: String = tuple[0].try_to_value(vm)?; - let port: u16 = tuple[1].try_to_value(vm)?; - let flowinfo: u32 = tuple[2].try_to_value(vm)?; - let scope_id: u32 = tuple[3].try_to_value(vm)?; + let host_wide: Vec = host.as_str().encode_utf16().chain([0]).collect(); + let mut addr_len = std::mem::size_of::() as i32; - let mut addr: SOCKADDR_IN6 = unsafe { std::mem::zeroed() }; - addr.sin6_family = AF_INET6; - addr.sin6_port = port.to_be(); - addr.sin6_flowinfo = flowinfo; - addr.Anonymous.sin6_scope_id = scope_id; + let ret = unsafe { + WSAStringToAddressW( + host_wide.as_ptr(), + AF_INET6 as i32, + std::ptr::null(), + &mut addr as *mut _ as *mut SOCKADDR, + &mut addr_len, + ) + }; - let host_wide: Vec = host.encode_utf16().chain(std::iter::once(0)).collect(); - let mut addr_len = std::mem::size_of::() as i32; + if ret < 0 { + let err = unsafe { WSAGetLastError() } as u32; + return Err(set_from_windows_err(err, vm)); + } - let ret = unsafe { - WSAStringToAddressW( - host_wide.as_ptr(), - AF_INET6 as i32, - std::ptr::null(), - &mut addr as *mut _ as *mut SOCKADDR, - &mut addr_len, - ) - }; + // Restore fields that WSAStringToAddressW might overwrite + addr.sin6_port = port.to_be(); + addr.sin6_flowinfo = flowinfo; + addr.Anonymous.sin6_scope_id = scope_id; - if ret != 0 { - return Err(vm.new_os_error(format!("Invalid IPv6 address: {}", host))); + let bytes = unsafe { + std::slice::from_raw_parts( + &addr as *const _ as *const u8, + std::mem::size_of::(), + ) + }; + Ok((bytes.to_vec(), addr_len)) } - - // Restore fields that WSAStringToAddressW might overwrite - addr.sin6_port = port.to_be(); - addr.sin6_flowinfo = flowinfo; - addr.Anonymous.sin6_scope_id = scope_id; - - let bytes = unsafe { - std::slice::from_raw_parts( - &addr as *const _ as *const u8, - std::mem::size_of::(), - ) - }; - Ok((bytes.to_vec(), std::mem::size_of::() as i32)) - } else { - Err(vm.new_value_error("address tuple must have 2 or 4 elements".to_owned())) + _ => Err(vm.new_value_error("illegal address_as_bytes argument".to_owned())), } } /// Parse a SOCKADDR_IN6 (which can also hold IPv4 addresses) to a Python address tuple - fn unparse_address(addr: &SOCKADDR_IN6, _addr_len: i32, vm: &VirtualMachine) -> PyObjectRef { + fn unparse_address(addr: &SOCKADDR_IN6, _addr_len: i32, vm: &VirtualMachine) -> PyResult { + use std::net::{Ipv4Addr, Ipv6Addr}; + unsafe { let family = addr.sin6_family; if family == AF_INET { // IPv4 address stored in SOCKADDR_IN6 structure let addr_in = &*(addr as *const SOCKADDR_IN6 as *const SOCKADDR_IN); let ip_bytes = addr_in.sin_addr.S_un.S_un_b; - let ip_str = format!( - "{}.{}.{}.{}", - ip_bytes.s_b1, ip_bytes.s_b2, ip_bytes.s_b3, ip_bytes.s_b4 - ); + let ip_str = + Ipv4Addr::new(ip_bytes.s_b1, ip_bytes.s_b2, ip_bytes.s_b3, ip_bytes.s_b4) + .to_string(); let port = u16::from_be(addr_in.sin_port); - (ip_str, port).to_pyobject(vm) + Ok((ip_str, port).to_pyobject(vm)) } else if family == AF_INET6 { // IPv6 address let ip_bytes = addr.sin6_addr.u.Byte; - let ip_str = format!( - "{:x}:{:x}:{:x}:{:x}:{:x}:{:x}:{:x}:{:x}", - u16::from_be_bytes([ip_bytes[0], ip_bytes[1]]), - u16::from_be_bytes([ip_bytes[2], ip_bytes[3]]), - u16::from_be_bytes([ip_bytes[4], ip_bytes[5]]), - u16::from_be_bytes([ip_bytes[6], ip_bytes[7]]), - u16::from_be_bytes([ip_bytes[8], ip_bytes[9]]), - u16::from_be_bytes([ip_bytes[10], ip_bytes[11]]), - u16::from_be_bytes([ip_bytes[12], ip_bytes[13]]), - u16::from_be_bytes([ip_bytes[14], ip_bytes[15]]), - ); + let ip_str = Ipv6Addr::from(ip_bytes).to_string(); let port = u16::from_be(addr.sin6_port); - let flowinfo = addr.sin6_flowinfo; + let flowinfo = u32::from_be(addr.sin6_flowinfo); let scope_id = addr.Anonymous.sin6_scope_id; - (ip_str, port, flowinfo, scope_id).to_pyobject(vm) + Ok((ip_str, port, flowinfo, scope_id).to_pyobject(vm)) } else { - // Unknown address family, return None - vm.ctx.none() + Err(vm.new_value_error("recvfrom returned unsupported address family".to_owned())) } } } @@ -422,7 +457,7 @@ mod _overlapped { }; // CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between if ret == 0 && unsafe { GetLastError() } != Foundation::ERROR_NOT_FOUND { - return Err(vm.new_last_os_error()); + return Err(set_from_windows_err(0, vm)); } Ok(()) } @@ -430,7 +465,7 @@ mod _overlapped { #[pymethod] fn getresult(zelf: &Py, wait: OptionalArg, vm: &VirtualMachine) -> PyResult { use windows_sys::Win32::Foundation::{ - ERROR_BROKEN_PIPE, ERROR_IO_PENDING, ERROR_MORE_DATA, ERROR_SUCCESS, + ERROR_BROKEN_PIPE, ERROR_MORE_DATA, ERROR_SUCCESS, }; let mut inner = zelf.inner.lock(); @@ -466,62 +501,65 @@ mod _overlapped { match err { ERROR_SUCCESS | ERROR_MORE_DATA => {} ERROR_BROKEN_PIPE => { - // For read operations, broken pipe is acceptable - match &inner.data { - OverlappedData::Read(_) | OverlappedData::ReadInto(_) => {} - OverlappedData::ReadFrom(_) => {} - OverlappedData::ReadFromInto(_) => {} - _ => return Err(from_windows_err(err, vm)), + let allow_broken_pipe = match &inner.data { + OverlappedData::Read(_) | OverlappedData::ReadInto(_) => true, + OverlappedData::ReadFrom(_) => true, + OverlappedData::ReadFromInto(rfi) => rfi.result.is_some(), + _ => false, + }; + if !allow_broken_pipe { + return Err(set_from_windows_err(err, vm)); } } - ERROR_IO_PENDING => { - return Err(from_windows_err(err, vm)); - } - _ => return Err(from_windows_err(err, vm)), + _ => return Err(set_from_windows_err(err, vm)), } // Return result based on operation type - match &inner.data { + match &mut inner.data { OverlappedData::Read(buf) => { - let bytes = buf.as_bytes(); - let result = if transferred as usize != bytes.len() { - vm.ctx.new_bytes(bytes[..transferred as usize].to_vec()) + let len = buf.as_bytes().len(); + let result = if transferred as usize != len { + let resized = vm + .ctx + .new_bytes(buf.as_bytes()[..transferred as usize].to_vec()); + *buf = resized.clone(); + resized } else { buf.clone() }; Ok(result.into()) } - OverlappedData::ReadInto(_) => Ok(vm.ctx.new_int(transferred).into()), - OverlappedData::Write(_) | OverlappedData::WriteTo(_) => { - Ok(vm.ctx.new_int(transferred).into()) - } - OverlappedData::Accept(_) => Ok(vm.ctx.none()), - OverlappedData::Connect(_) => Ok(vm.ctx.none()), - OverlappedData::Disconnect => Ok(vm.ctx.none()), - OverlappedData::ConnectNamedPipe => Ok(vm.ctx.none()), - OverlappedData::WaitNamedPipeAndConnect => Ok(vm.ctx.none()), - OverlappedData::TransmitFile => Ok(vm.ctx.none()), OverlappedData::ReadFrom(rf) => { - let bytes = rf.allocated_buffer.as_bytes(); - let resized_buf = if transferred as usize != bytes.len() { - vm.ctx.new_bytes(bytes[..transferred as usize].to_vec()) + let len = rf.allocated_buffer.as_bytes().len(); + let resized_buf = if transferred as usize != len { + let resized = vm.ctx.new_bytes( + rf.allocated_buffer.as_bytes()[..transferred as usize].to_vec(), + ); + rf.allocated_buffer = resized.clone(); + resized } else { rf.allocated_buffer.clone() }; - let addr_tuple = unparse_address(&rf.address, rf.address_length, vm); - Ok(vm - .ctx - .new_tuple(vec![resized_buf.into(), addr_tuple]) - .into()) + let addr_tuple = unparse_address(&rf.address, rf.address_length, vm)?; + if let Some(result) = &rf.result { + return Ok(result.clone()); + } + let result = vm.ctx.new_tuple(vec![resized_buf.into(), addr_tuple]); + rf.result = Some(result.clone().into()); + Ok(result.into()) } OverlappedData::ReadFromInto(rfi) => { - let addr_tuple = unparse_address(&rfi.address, rfi.address_length, vm); - Ok(vm + let addr_tuple = unparse_address(&rfi.address, rfi.address_length, vm)?; + if let Some(result) = &rfi.result { + return Ok(result.clone()); + } + let result = vm .ctx - .new_tuple(vec![vm.ctx.new_int(transferred).into(), addr_tuple]) - .into()) + .new_tuple(vec![vm.ctx.new_int(transferred).into(), addr_tuple]); + rfi.result = Some(result.clone().into()); + Ok(result.into()) } - _ => Ok(vm.ctx.none()), + _ => Ok(vm.ctx.new_int(transferred).into()), } } @@ -567,12 +605,12 @@ mod _overlapped { match err { ERROR_BROKEN_PIPE => { mark_as_completed(&mut inner.overlapped); - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } ERROR_SUCCESS | ERROR_MORE_DATA | ERROR_IO_PENDING => Ok(vm.ctx.none()), _ => { inner.data = OverlappedData::NotStarted; - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } } } @@ -597,6 +635,9 @@ mod _overlapped { inner.handle = handle as HANDLE; let buf_len = buf.desc.len; + if buf_len > u32::MAX as usize { + return Err(vm.new_value_error("buffer too large".to_owned())); + } // For async read, buffer must be contiguous - we can't use a temporary copy // because Windows writes data directly to the buffer after this call returns @@ -627,12 +668,12 @@ mod _overlapped { match err { ERROR_BROKEN_PIPE => { mark_as_completed(&mut inner.overlapped); - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } ERROR_SUCCESS | ERROR_MORE_DATA | ERROR_IO_PENDING => Ok(vm.ctx.none()), _ => { inner.data = OverlappedData::NotStarted; - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } } } @@ -694,12 +735,12 @@ mod _overlapped { match err { ERROR_BROKEN_PIPE => { mark_as_completed(&mut inner.overlapped); - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } ERROR_SUCCESS | ERROR_MORE_DATA | ERROR_IO_PENDING => Ok(vm.ctx.none()), _ => { inner.data = OverlappedData::NotStarted; - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } } } @@ -726,6 +767,9 @@ mod _overlapped { let mut flags = flags; inner.handle = handle as HANDLE; let buf_len = buf.desc.len; + if buf_len > u32::MAX as usize { + return Err(vm.new_value_error("buffer too large".to_owned())); + } let Some(contiguous) = buf.as_contiguous_mut() else { return Err(vm.new_buffer_error("buffer is not contiguous".to_owned())); @@ -761,12 +805,12 @@ mod _overlapped { match err { ERROR_BROKEN_PIPE => { mark_as_completed(&mut inner.overlapped); - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } ERROR_SUCCESS | ERROR_MORE_DATA | ERROR_IO_PENDING => Ok(vm.ctx.none()), _ => { inner.data = OverlappedData::NotStarted; - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } } } @@ -789,6 +833,9 @@ mod _overlapped { inner.handle = handle as HANDLE; let buf_len = buf.desc.len; + if buf_len > u32::MAX as usize { + return Err(vm.new_value_error("buffer too large".to_owned())); + } // For async write, buffer must be contiguous - we can't use a temporary copy // because Windows reads from the buffer after this call returns @@ -820,7 +867,7 @@ mod _overlapped { ERROR_SUCCESS | ERROR_IO_PENDING => Ok(vm.ctx.none()), _ => { inner.data = OverlappedData::NotStarted; - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } } } @@ -844,6 +891,9 @@ mod _overlapped { inner.handle = handle as HANDLE; let buf_len = buf.desc.len; + if buf_len > u32::MAX as usize { + return Err(vm.new_value_error("buffer too large".to_owned())); + } let Some(contiguous) = buf.as_contiguous() else { return Err(vm.new_buffer_error("buffer is not contiguous".to_owned())); @@ -880,7 +930,7 @@ mod _overlapped { ERROR_SUCCESS | ERROR_IO_PENDING => Ok(vm.ctx.none()), _ => { inner.data = OverlappedData::NotStarted; - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } } } @@ -896,8 +946,6 @@ mod _overlapped { use windows_sys::Win32::Foundation::{ERROR_IO_PENDING, ERROR_SUCCESS}; use windows_sys::Win32::Networking::WinSock::WSAGetLastError; - initialize_winsock_extensions(vm)?; - let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { return Err(vm.new_value_error("operation already attempted".to_owned())); @@ -950,7 +998,7 @@ mod _overlapped { ERROR_SUCCESS | ERROR_IO_PENDING => Ok(vm.ctx.none()), _ => { inner.data = OverlappedData::NotStarted; - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } } } @@ -960,14 +1008,12 @@ mod _overlapped { fn ConnectEx( zelf: &Py, socket: isize, - address: PyObjectRef, + address: PyTupleRef, vm: &VirtualMachine, ) -> PyResult { use windows_sys::Win32::Foundation::{ERROR_IO_PENDING, ERROR_SUCCESS}; use windows_sys::Win32::Networking::WinSock::WSAGetLastError; - initialize_winsock_extensions(vm)?; - let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { return Err(vm.new_value_error("operation already attempted".to_owned())); @@ -1021,7 +1067,7 @@ mod _overlapped { ERROR_SUCCESS | ERROR_IO_PENDING => Ok(vm.ctx.none()), _ => { inner.data = OverlappedData::NotStarted; - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } } } @@ -1037,8 +1083,6 @@ mod _overlapped { use windows_sys::Win32::Foundation::{ERROR_IO_PENDING, ERROR_SUCCESS}; use windows_sys::Win32::Networking::WinSock::WSAGetLastError; - initialize_winsock_extensions(vm)?; - let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { return Err(vm.new_value_error("operation already attempted".to_owned())); @@ -1070,7 +1114,7 @@ mod _overlapped { ERROR_SUCCESS | ERROR_IO_PENDING => Ok(vm.ctx.none()), _ => { inner.data = OverlappedData::NotStarted; - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } } } @@ -1092,8 +1136,6 @@ mod _overlapped { use windows_sys::Win32::Foundation::{ERROR_IO_PENDING, ERROR_SUCCESS}; use windows_sys::Win32::Networking::WinSock::WSAGetLastError; - initialize_winsock_extensions(vm)?; - let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { return Err(vm.new_value_error("operation already attempted".to_owned())); @@ -1140,7 +1182,7 @@ mod _overlapped { ERROR_SUCCESS | ERROR_IO_PENDING => Ok(vm.ctx.none()), _ => { inner.data = OverlappedData::NotStarted; - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } } } @@ -1178,7 +1220,7 @@ mod _overlapped { ERROR_SUCCESS | ERROR_IO_PENDING => Ok(false), _ => { inner.data = OverlappedData::NotStarted; - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } } } @@ -1190,7 +1232,7 @@ mod _overlapped { handle: isize, buf: PyBuffer, flags: u32, - address: PyObjectRef, + address: PyTupleRef, vm: &VirtualMachine, ) -> PyResult { use windows_sys::Win32::Foundation::{ERROR_IO_PENDING, ERROR_SUCCESS}; @@ -1205,6 +1247,9 @@ mod _overlapped { inner.handle = handle as HANDLE; let buf_len = buf.desc.len; + if buf_len > u32::MAX as usize { + return Err(vm.new_value_error("buffer too large".to_owned())); + } let Some(contiguous) = buf.as_contiguous() else { return Err(vm.new_buffer_error("buffer is not contiguous".to_owned())); @@ -1253,7 +1298,7 @@ mod _overlapped { ERROR_SUCCESS | ERROR_IO_PENDING => Ok(vm.ctx.none()), _ => { inner.data = OverlappedData::NotStarted; - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } } } @@ -1290,6 +1335,7 @@ mod _overlapped { let address_length = std::mem::size_of::() as i32; inner.data = OverlappedData::ReadFrom(OverlappedReadFrom { + result: None, allocated_buffer: buf.clone(), address, address_length, @@ -1334,12 +1380,12 @@ mod _overlapped { match err { ERROR_BROKEN_PIPE => { mark_as_completed(&mut inner.overlapped); - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } ERROR_SUCCESS | ERROR_MORE_DATA | ERROR_IO_PENDING => Ok(vm.ctx.none()), _ => { inner.data = OverlappedData::NotStarted; - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } } } @@ -1371,16 +1417,16 @@ mod _overlapped { return Err(vm.new_buffer_error("buffer is not contiguous".to_owned())); }; - // Validate size against buffer length to prevent buffer overflow - let buf_len = buf.desc.len as u32; - if size > buf_len { - return Err(vm.new_value_error("size exceeds buffer length".to_owned())); + let buf_len = buf.desc.len; + if buf_len > u32::MAX as usize { + return Err(vm.new_value_error("buffer too large".to_owned())); } let address: SOCKADDR_IN6 = unsafe { std::mem::zeroed() }; let address_length = std::mem::size_of::() as i32; inner.data = OverlappedData::ReadFromInto(OverlappedReadFromInto { + result: None, user_buffer: buf.clone(), address, address_length, @@ -1425,12 +1471,12 @@ mod _overlapped { match err { ERROR_BROKEN_PIPE => { mark_as_completed(&mut inner.overlapped); - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } ERROR_SUCCESS | ERROR_MORE_DATA | ERROR_IO_PENDING => Ok(vm.ctx.none()), _ => { inner.data = OverlappedData::NotStarted; - Err(from_windows_err(err, vm)) + Err(set_from_windows_err(err, vm)) } } } @@ -1444,7 +1490,7 @@ mod _overlapped { if event == INVALID_HANDLE_VALUE { event = unsafe { - windows_sys::Win32::System::Threading::CreateEventA( + windows_sys::Win32::System::Threading::CreateEventW( core::ptr::null(), Foundation::TRUE, Foundation::FALSE, @@ -1452,7 +1498,7 @@ mod _overlapped { ) as isize }; if event == NULL { - return Err(vm.new_last_os_error()); + return Err(set_from_windows_err(0, vm)); } } @@ -1473,7 +1519,10 @@ mod _overlapped { } impl Destructor for Overlapped { - fn del(zelf: &Py, _vm: &VirtualMachine) -> PyResult<()> { + fn del(zelf: &Py, vm: &VirtualMachine) -> PyResult<()> { + use windows_sys::Win32::Foundation::{ + ERROR_NOT_FOUND, ERROR_OPERATION_ABORTED, ERROR_SUCCESS, + }; use windows_sys::Win32::System::IO::{CancelIoEx, GetOverlappedResult}; let mut inner = zelf.inner.lock(); @@ -1481,24 +1530,39 @@ mod _overlapped { // Cancel pending I/O and wait for completion if !HasOverlappedIoCompleted(&inner.overlapped) - && !matches!( - inner.data, - OverlappedData::None | OverlappedData::NotStarted - ) + && !matches!(inner.data, OverlappedData::NotStarted) { let cancelled = unsafe { CancelIoEx(inner.handle, &inner.overlapped) } != 0; + let mut transferred: u32 = 0; + let ret = unsafe { + GetOverlappedResult( + inner.handle, + &inner.overlapped, + &mut transferred, + if cancelled { 1 } else { 0 }, + ) + }; - if cancelled { - // Wait for the cancellation to complete - let mut transferred: u32 = 0; - unsafe { - GetOverlappedResult( - inner.handle, - &inner.overlapped, - &mut transferred, - 1, // bWait = TRUE - ) - }; + let err = if ret != 0 { + ERROR_SUCCESS + } else { + unsafe { GetLastError() } + }; + match err { + ERROR_SUCCESS | ERROR_NOT_FOUND | ERROR_OPERATION_ABORTED => {} + _ => { + let msg = format!( + "{:?} still has pending operation at deallocation, the process may crash", + zelf + ); + let exc = vm.new_runtime_error(msg); + let err_msg = Some(format!( + "Exception ignored while deallocating overlapped operation {:?}", + zelf + )); + let obj: PyObjectRef = zelf.to_owned().into(); + vm.run_unraisable(exc, err_msg, obj); + } } } @@ -1519,8 +1583,9 @@ mod _overlapped { #[pyfunction] fn ConnectPipe(address: String, vm: &VirtualMachine) -> PyResult { + use windows_sys::Win32::Foundation::{GENERIC_READ, GENERIC_WRITE}; use windows_sys::Win32::Storage::FileSystem::{ - CreateFileW, FILE_FLAG_OVERLAPPED, FILE_GENERIC_READ, FILE_GENERIC_WRITE, OPEN_EXISTING, + CreateFileW, FILE_FLAG_OVERLAPPED, OPEN_EXISTING, }; let address_wide: Vec = address.encode_utf16().chain(std::iter::once(0)).collect(); @@ -1528,7 +1593,7 @@ mod _overlapped { let handle = unsafe { CreateFileW( address_wide.as_ptr(), - FILE_GENERIC_READ | FILE_GENERIC_WRITE, + GENERIC_READ | GENERIC_WRITE, 0, std::ptr::null(), OPEN_EXISTING, @@ -1538,7 +1603,7 @@ mod _overlapped { }; if handle == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE { - return Err(vm.new_last_os_error()); + return Err(set_from_windows_err(0, vm)); } Ok(handle as isize) @@ -1561,7 +1626,7 @@ mod _overlapped { ) as isize }; if r == 0 { - return Err(vm.new_last_os_error()); + return Err(set_from_windows_err(0, vm)); } Ok(r) } @@ -1589,7 +1654,7 @@ mod _overlapped { if err == Foundation::WAIT_TIMEOUT { return Ok(vm.ctx.none()); } else { - return Err(vm.new_last_os_error()); + return Err(set_from_windows_err(err, vm)); } } @@ -1619,7 +1684,7 @@ mod _overlapped { ) }; if ret == 0 { - return Err(vm.new_last_os_error()); + return Err(set_from_windows_err(0, vm)); } Ok(()) } @@ -1707,7 +1772,7 @@ mod _overlapped { unsafe { let _ = std::sync::Arc::from_raw(data_ptr); } - return Err(vm.new_last_os_error()); + return Err(set_from_windows_err(0, vm)); } // Store in registry for cleanup tracking @@ -1739,7 +1804,7 @@ mod _overlapped { // (callback may have already fired, or may never fire) cleanup_wait_callback_data(wait_handle); if ret == 0 { - return Err(vm.new_last_os_error()); + return Err(set_from_windows_err(0, vm)); } Ok(()) } @@ -1752,14 +1817,16 @@ mod _overlapped { // Cleanup callback data regardless of UnregisterWaitEx result cleanup_wait_callback_data(wait_handle); if ret == 0 { - return Err(vm.new_last_os_error()); + return Err(set_from_windows_err(0, vm)); } Ok(()) } #[pyfunction] fn BindLocal(socket: isize, family: i32, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::Networking::WinSock::{INADDR_ANY, SOCKET_ERROR, bind}; + use windows_sys::Win32::Networking::WinSock::{ + INADDR_ANY, SOCKET_ERROR, WSAGetLastError, bind, + }; let ret = if family == AF_INET as i32 { let mut addr: SOCKADDR_IN = unsafe { std::mem::zeroed() }; @@ -1786,11 +1853,12 @@ mod _overlapped { ) } } else { - return Err(vm.new_value_error("family must be AF_INET or AF_INET6".to_owned())); + return Err(vm.new_value_error("expected tuple of length 2 or 4".to_owned())); }; if ret == SOCKET_ERROR { - return Err(vm.new_last_os_error()); + let err = unsafe { WSAGetLastError() } as u32; + return Err(set_from_windows_err(err, vm)); } Ok(()) } @@ -1824,6 +1892,9 @@ mod _overlapped { }; if len == 0 || buffer.is_null() { + if !buffer.is_null() { + unsafe { LocalFree(buffer as *mut _) }; + } return Ok(format!("unknown error code {}", error_code)); } @@ -1837,8 +1908,8 @@ mod _overlapped { } #[pyfunction] - fn WSAConnect(socket: isize, address: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - use windows_sys::Win32::Networking::WinSock::{SOCKET_ERROR, WSAConnect}; + fn WSAConnect(socket: isize, address: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> { + use windows_sys::Win32::Networking::WinSock::{SOCKET_ERROR, WSAConnect, WSAGetLastError}; let (addr_bytes, addr_len) = parse_address(&address, vm)?; @@ -1855,7 +1926,8 @@ mod _overlapped { }; if ret == SOCKET_ERROR { - return Err(vm.new_last_os_error()); + let err = unsafe { WSAGetLastError() } as u32; + return Err(set_from_windows_err(err, vm)); } Ok(()) } @@ -1888,7 +1960,7 @@ mod _overlapped { ) as isize }; if event == NULL { - return Err(vm.new_last_os_error()); + return Err(set_from_windows_err(0, vm)); } Ok(event) } @@ -1897,7 +1969,7 @@ mod _overlapped { fn SetEvent(handle: isize, vm: &VirtualMachine) -> PyResult<()> { let ret = unsafe { windows_sys::Win32::System::Threading::SetEvent(handle as HANDLE) }; if ret == 0 { - return Err(vm.new_last_os_error()); + return Err(set_from_windows_err(0, vm)); } Ok(()) } @@ -1906,7 +1978,7 @@ mod _overlapped { fn ResetEvent(handle: isize, vm: &VirtualMachine) -> PyResult<()> { let ret = unsafe { windows_sys::Win32::System::Threading::ResetEvent(handle as HANDLE) }; if ret == 0 { - return Err(vm.new_last_os_error()); + return Err(set_from_windows_err(0, vm)); } Ok(()) }