From ab5bc43359478860076811c0496c0cbee63afab8 Mon Sep 17 00:00:00 2001 From: Bas Schoenmaeckers <7943856+bschoenmaeckers@users.noreply.github.com> Date: Fri, 22 May 2026 13:00:16 +0200 Subject: [PATCH 01/18] Add float support to c-api (#7943) --- crates/capi/src/floatobject.rs | 41 ++++++++++++++++++++++++++++++++++ crates/capi/src/lib.rs | 1 + 2 files changed, 42 insertions(+) create mode 100644 crates/capi/src/floatobject.rs diff --git a/crates/capi/src/floatobject.rs b/crates/capi/src/floatobject.rs new file mode 100644 index 0000000000..f1bb078106 --- /dev/null +++ b/crates/capi/src/floatobject.rs @@ -0,0 +1,41 @@ +use crate::object::define_py_check; +use crate::{PyObject, pystate::with_vm}; +use core::ffi::c_double; +use rustpython_vm::builtins::PyFloat; + +define_py_check!(fn PyFloat_Check, types.float_type); +define_py_check!(exact fn PyFloat_CheckExact, types.float_type); + +#[unsafe(no_mangle)] +pub extern "C" fn PyFloat_FromDouble(value: c_double) -> *mut PyObject { + with_vm(|vm| vm.ctx.new_float(value)) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyFloat_AsDouble(obj: *mut PyObject) -> c_double { + with_vm(|vm| { + let obj_ref = unsafe { &*obj }; + let float_obj = obj_ref + .to_owned() + .try_downcast::(vm) + .or_else(|_| obj_ref.try_float(vm))?; + + Ok(float_obj.to_f64()) + }) +} + +#[cfg(false)] +mod tests { + use core::f64::consts::PI; + use pyo3::prelude::*; + use pyo3::types::PyFloat; + + #[test] + fn test_py_float() { + Python::attach(|py| { + let pi = PyFloat::new(py, PI); + assert!(pi.is_instance_of::()); + assert_eq!(pi.extract::().unwrap(), PI); + }) + } +} diff --git a/crates/capi/src/lib.rs b/crates/capi/src/lib.rs index baa4ddc6b7..6b2745d153 100644 --- a/crates/capi/src/lib.rs +++ b/crates/capi/src/lib.rs @@ -13,6 +13,7 @@ pub mod boolobject; pub mod bytesobject; pub mod ceval; pub mod dictobject; +pub mod floatobject; pub mod import; pub mod longobject; pub mod object; From 4eb9534646774b723be4a042734dc893a2e414a8 Mon Sep 17 
00:00:00 2001 From: Bas Schoenmaeckers <7943856+bschoenmaeckers@users.noreply.github.com> Date: Fri, 22 May 2026 13:00:46 +0200 Subject: [PATCH 02/18] Add complex number support to c-api (#7945) --- Cargo.lock | 1 + crates/capi/Cargo.toml | 1 + crates/capi/src/complexobject.rs | 52 ++++++++++++++++++++++++++++++++ crates/capi/src/lib.rs | 1 + 4 files changed, 55 insertions(+) create mode 100644 crates/capi/src/complexobject.rs diff --git a/Cargo.lock b/Cargo.lock index b1fe77d955..eab7707b5b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3231,6 +3231,7 @@ dependencies = [ name = "rustpython-capi" version = "0.5.0" dependencies = [ + "num-complex", "pyo3", "rustpython-stdlib", "rustpython-vm", diff --git a/crates/capi/Cargo.toml b/crates/capi/Cargo.toml index 85f42934a9..8c54a8b26a 100644 --- a/crates/capi/Cargo.toml +++ b/crates/capi/Cargo.toml @@ -12,6 +12,7 @@ license.workspace = true crate-type = ["cdylib", "rlib"] [dependencies] +num-complex = { workspace = true } rustpython-vm = { workspace = true, features = ["threading", "compiler"] } rustpython-stdlib = {workspace = true, features = ["threading"] } diff --git a/crates/capi/src/complexobject.rs b/crates/capi/src/complexobject.rs new file mode 100644 index 0000000000..a6b2bb731a --- /dev/null +++ b/crates/capi/src/complexobject.rs @@ -0,0 +1,52 @@ +use crate::object::define_py_check; +use crate::{PyObject, pystate::with_vm}; +use core::ffi::c_double; +use num_complex::{Complex, Complex64}; +use rustpython_vm::builtins::PyComplex; +use rustpython_vm::{PyResult, VirtualMachine}; + +define_py_check!(fn PyComplex_Check, types.complex_type); +define_py_check!(exact fn PyComplex_CheckExact, types.complex_type); + +#[unsafe(no_mangle)] +pub extern "C" fn PyComplex_FromDoubles(real: c_double, imag: c_double) -> *mut PyObject { + with_vm(|vm| vm.ctx.new_complex(Complex::new(real, imag))) +} + +fn try_to_complex(vm: &VirtualMachine, obj: &PyObject) -> PyResult { + obj.try_downcast_ref::(vm).map_or_else( + |type_err| { + 
if let Some((complex, _)) = obj.to_owned().try_complex(vm)? { + Ok(complex) + } else { + Err(type_err) + } + }, + |complex| Ok(complex.to_complex()), + ) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyComplex_RealAsDouble(obj: *mut PyObject) -> c_double { + with_vm(|vm| try_to_complex(vm, unsafe { &*obj }).map(|complex| complex.re)) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyComplex_ImagAsDouble(obj: *mut PyObject) -> c_double { + with_vm(|vm| try_to_complex(vm, unsafe { &*obj }).map(|complex| complex.im)) +} + +#[cfg(false)] +mod tests { + use pyo3::prelude::*; + use pyo3::types::PyComplex; + + #[test] + fn test_py_int() { + Python::attach(|py| { + let number = PyComplex::from_doubles(py, 1.0, 2.0); + assert_eq!(number.real(), 1.0); + assert_eq!(number.imag(), 2.0); + }) + } +} diff --git a/crates/capi/src/lib.rs b/crates/capi/src/lib.rs index 6b2745d153..7983a36a7e 100644 --- a/crates/capi/src/lib.rs +++ b/crates/capi/src/lib.rs @@ -12,6 +12,7 @@ pub mod abstract_; pub mod boolobject; pub mod bytesobject; pub mod ceval; +pub mod complexobject; pub mod dictobject; pub mod floatobject; pub mod import; From 2a163609cd54b77ded95392ddb61d6ebffd2f1ca Mon Sep 17 00:00:00 2001 From: Ivan Mironov Date: Fri, 22 May 2026 11:04:15 +0000 Subject: [PATCH 03/18] Rustls integration improvements (#7946) * Do not call `import socket` on each send()/recv() when using rustls Use method references cached during socket creation. * Implement reading of at most one TLS record from socket Previous algorithm didn't take into account that recv() may return less data than requested even for blocking sockets. * Remove special handling of rustls "buffer full" errors First of all, existing code does not really work and this leads to an infinite loop: https://github.com/RustPython/RustPython/issues/7891 Second, this should never happen when rustls is used properly (wrt wants_read() and wants_write()) and thus all such errors are implementation bugs that must be properly fixed. 
* Replace own TlsConnection with rustls::Connection * Fix waiting on a socket 1) Ensure that socket_wait() called from TLS glue code allows threads 2) Ensure that socket_wait() called from TLS glue code properly handles EINTR on *nix 3) Ensure that select() or poll() error conditions are checked 4) Use poll() on *nix so socket descriptor values are not limited * Remove dead code from rustls glue * Do not present rustls errors as OSError(0, "Success") * Remove infinite loop "detection" from rustls glue TLS handshake cannot be infinite. Any infinite loop here is a serious bug in implementation and should be fixed properly. This code triggers in some cases (very short reads) with misleading `ssl_error.SSLWantReadError: The operation did not complete (read)`. * Add test for 1-byte max recv in TLS client * Add regression test for https://github.com/RustPython/RustPython/issues/7891 * Fix constants in rustls glue code * Deduplicate verify flags / record-size constants * Larger "max encrypted TLS record length" --- crates/host_env/src/socket.rs | 33 -- crates/stdlib/src/socket.rs | 204 +++++---- crates/stdlib/src/ssl.rs | 342 ++++++++------- crates/stdlib/src/ssl/cert.rs | 2 +- crates/stdlib/src/ssl/compat.rs | 404 ++++-------------- crates/stdlib/src/ssl/error.rs | 2 - extra_tests/snippets/stdlib_ssl_short_recv.py | 88 ++++ .../stdlib_urllib_https_misaligned_recv.py | 246 +++++++++++ 8 files changed, 737 insertions(+), 584 deletions(-) create mode 100644 extra_tests/snippets/stdlib_ssl_short_recv.py create mode 100644 extra_tests/snippets/stdlib_urllib_https_misaligned_recv.py diff --git a/crates/host_env/src/socket.rs b/crates/host_env/src/socket.rs index d6f1e078c7..c409132a72 100644 --- a/crates/host_env/src/socket.rs +++ b/crates/host_env/src/socket.rs @@ -3,20 +3,10 @@ use crate::os::CheckLibcResult; #[cfg(unix)] use core::ffi::CStr; #[cfg(unix)] -use core::time::Duration; -#[cfg(unix)] use std::os::fd::AsRawFd; #[cfg(unix)] use std::{io, os::fd::BorrowedFd}; 
-#[cfg(unix)] -#[derive(Copy, Clone)] -pub enum PollKind { - Read, - Write, - Connect, -} - #[cfg(all(unix, not(target_os = "redox")))] pub fn sethostname(hostname: &str) -> io::Result<()> { nix::unistd::sethostname(hostname).map_err(io::Error::from) @@ -111,29 +101,6 @@ pub fn setsockopt_none(fd: libc::c_int, level: i32, name: i32, optlen: u32) -> i Ok(()) } -#[cfg(unix)] -pub fn poll_socket( - fd: BorrowedFd<'_>, - kind: PollKind, - interval: Option, -) -> io::Result { - use nix::poll::{PollFd, PollFlags, PollTimeout, poll}; - - let events = match kind { - PollKind::Read => PollFlags::POLLIN, - PollKind::Write => PollFlags::POLLOUT, - PollKind::Connect => PollFlags::POLLOUT | PollFlags::POLLERR, - }; - let mut pollfd = [PollFd::new(fd, events)]; - let timeout = match interval { - Some(d) => d.try_into().unwrap_or(PollTimeout::MAX), - None => PollTimeout::NONE, - }; - poll(&mut pollfd, timeout) - .map(|ret| ret == 0) - .map_err(io::Error::from) -} - #[cfg(any( target_os = "dragonfly", target_os = "freebsd", diff --git a/crates/stdlib/src/socket.rs b/crates/stdlib/src/socket.rs index 55aae79df6..6aff5d452c 100644 --- a/crates/stdlib/src/socket.rs +++ b/crates/stdlib/src/socket.rs @@ -3,7 +3,7 @@ pub(crate) use _socket::module_def; #[cfg(feature = "ssl")] -pub(super) use _socket::{PySocket, SelectKind, sock_select, timeout_error_msg}; +pub(super) use _socket::{PySocket, SockWaitKind, sock_wait, timeout_error_msg}; #[pymodule] mod _socket { @@ -1060,20 +1060,20 @@ mod _socket { fn sock_op( &self, vm: &VirtualMachine, - select: SelectKind, + wait_kind: SockWaitKind, f: F, ) -> Result where F: FnMut() -> io::Result, { let timeout = self.get_timeout().ok(); - self.sock_op_timeout_err(vm, select, timeout, f) + self.sock_op_timeout_err(vm, wait_kind, timeout, f) } fn sock_op_timeout_err( &self, vm: &VirtualMachine, - select: SelectKind, + wait_kind: SockWaitKind, timeout: Option, mut f: F, ) -> Result @@ -1083,19 +1083,9 @@ mod _socket { let deadline = 
timeout.map(Deadline::new); loop { - if deadline.is_some() || matches!(select, SelectKind::Connect) { - let interval = deadline.as_ref().map(|d| d.time_until()).transpose()?; + if deadline.is_some() || matches!(wait_kind, SockWaitKind::Connect) { let sock = self.sock()?; - let res = vm.allow_threads(|| sock_select(&sock, select, interval)); - match res { - Ok(true) => return Err(IoOrPyException::Timeout), - Err(e) if e.kind() == io::ErrorKind::Interrupted => { - vm.check_signals()?; - continue; - } - Err(e) => return Err(e.into()), - Ok(false) => {} // no timeout, continue as normal - } + sock_wait_deadline(&sock, wait_kind, &deadline, vm)?; } let err = loop { @@ -1339,16 +1329,11 @@ mod _socket { }; if wait_connect { - // basically, connect() is async, and it registers an "error" on the socket when it's - // done connecting. SelectKind::Connect fills the errorfds fd_set, so if we wake up - // from poll and the error is EISCONN then we know that the connect is done - self.sock_op(vm, SelectKind::Connect, || { + self.sock_op(vm, SockWaitKind::Connect, || { let sock = self.sock()?; let err = sock.take_error()?; match err { - Some(e) if e.posix_errno() == libc::EISCONN => Ok(()), Some(e) => Err(e), - // TODO: is this accurate? 
None => Ok(()), } }) @@ -1587,7 +1572,8 @@ mod _socket { ) -> Result<(RawSocket, PyObjectRef), IoOrPyException> { // Use accept_raw() instead of accept() to avoid socket2's set_common_flags() // which tries to set SO_NOSIGPIPE and fails with EINVAL on Unix domain sockets on macOS - let (sock, addr) = self.sock_op(vm, SelectKind::Read, || self.sock()?.accept_raw())?; + let (sock, addr) = + self.sock_op(vm, SockWaitKind::Read, || self.sock()?.accept_raw())?; let fd = into_sock_fileno(sock); Ok((fd, get_addr_tuple(&addr, vm))) } @@ -1602,7 +1588,7 @@ mod _socket { let flags = flags.unwrap_or(0); let mut buffer = Vec::with_capacity(bufsize); let sock = self.sock()?; - let n = self.sock_op(vm, SelectKind::Read, || { + let n = self.sock_op(vm, SockWaitKind::Read, || { sock.recv_with_flags(buffer.spare_capacity_mut(), flags) })?; unsafe { buffer.set_len(n) }; @@ -1633,7 +1619,7 @@ mod _socket { }; let buf = &mut buf[..read_len]; - self.sock_op(vm, SelectKind::Read, || { + self.sock_op(vm, SockWaitKind::Read, || { sock.recv_with_flags(unsafe { slice_as_uninit(buf) }, flags) }) } @@ -1650,7 +1636,7 @@ mod _socket { .to_usize() .ok_or_else(|| vm.new_value_error("negative buffersize in recvfrom"))?; let mut buffer = Vec::with_capacity(bufsize); - let (n, addr) = self.sock_op(vm, SelectKind::Read, || { + let (n, addr) = self.sock_op(vm, SockWaitKind::Read, || { self.sock()? 
.recv_from_with_flags(buffer.spare_capacity_mut(), flags) })?; @@ -1681,7 +1667,7 @@ mod _socket { }; let flags = flags.unwrap_or(0); let sock = self.sock()?; - let (n, addr) = self.sock_op(vm, SelectKind::Read, || { + let (n, addr) = self.sock_op(vm, SockWaitKind::Read, || { sock.recv_from_with_flags(unsafe { slice_as_uninit(buf) }, flags) })?; Ok((n, get_addr_tuple(&addr, vm))) @@ -1697,7 +1683,7 @@ mod _socket { let flags = flags.unwrap_or(0); let buf = bytes.borrow_buf(); let buf = &*buf; - self.sock_op(vm, SelectKind::Write, || { + self.sock_op(vm, SockWaitKind::Write, || { self.sock()?.send_with_flags(buf, flags) }) } @@ -1721,7 +1707,7 @@ mod _socket { // now we have like 3 layers of interrupt loop :) while buf_offset < buf.len() { let interval = deadline.as_ref().map(|d| d.time_until()).transpose()?; - self.sock_op_timeout_err(vm, SelectKind::Write, interval, || { + self.sock_op_timeout_err(vm, SockWaitKind::Write, interval, || { let subbuf = &buf[buf_offset..]; buf_offset += self.sock()?.send_with_flags(subbuf, flags)?; Ok(()) @@ -1754,7 +1740,7 @@ mod _socket { let addr = self.extract_address(address, "sendto", vm)?; let buf = bytes.borrow_buf(); let buf = &*buf; - self.sock_op(vm, SelectKind::Write, || { + self.sock_op(vm, SockWaitKind::Write, || { self.sock()?.send_to_with_flags(buf, &addr, flags) }) } @@ -1812,7 +1798,7 @@ mod _socket { } } - self.sock_op(vm, SelectKind::Write, || { + self.sock_op(vm, SockWaitKind::Write, || { let sock = self.sock()?; sock.sendmsg(&msg, flags) }) @@ -1848,7 +1834,7 @@ mod _socket { .collect::>(); let iv = iv.map(|iv| iv.borrow_buf().to_vec()); - self.sock_op(vm, SelectKind::Write, || { + self.sock_op(vm, SockWaitKind::Write, || { let sock = self.sock()?; let fd = unsafe { BorrowedFd::borrow_raw(sock_fileno(&sock)) }; host_socket::sendmsg_afalg(fd, &buffers, op, iv.as_deref(), assoclen, flags) @@ -1881,7 +1867,7 @@ mod _socket { let flags = flags.unwrap_or(0); let msg = self - .sock_op(vm, SelectKind::Read, || { + 
.sock_op(vm, SockWaitKind::Read, || { let sock = self.sock()?; let fd = unsafe { std::os::fd::BorrowedFd::borrow_raw(sock_fileno(&sock)) }; host_socket::recvmsg(fd, bufsize, ancbufsize, flags) @@ -2436,61 +2422,135 @@ mod _socket { } #[derive(Copy, Clone)] - pub(crate) enum SelectKind { + pub(crate) enum SockWaitKind { Read, Write, Connect, } - /// returns true if timed out - pub(crate) fn sock_select( + /// returns Ok(true) on timeout + pub(crate) fn sock_wait( + sock: &Socket, + wait_kind: SockWaitKind, + timeout: Option, + vm: &VirtualMachine, + ) -> PyResult { + match sock_wait_deadline(sock, wait_kind, &timeout.map(Deadline::new), vm) { + Ok(()) => Ok(false), + Err(IoOrPyException::Timeout) => Ok(true), + Err(e) => Err(e.into_pyexception(vm)), + } + } + + /// returns Err(IoOrPyException::Timeout) on timeout + fn sock_wait_deadline( sock: &Socket, - kind: SelectKind, - interval: Option, - ) -> io::Result { + wait_kind: SockWaitKind, + deadline: &Option, + vm: &VirtualMachine, + ) -> Result<(), IoOrPyException> { #[cfg(unix)] { - use std::os::fd::AsFd; - let kind = match kind { - SelectKind::Read => host_socket::PollKind::Read, - SelectKind::Write => host_socket::PollKind::Write, - SelectKind::Connect => host_socket::PollKind::Connect, - }; - host_socket::poll_socket(sock.as_fd(), kind, interval) + use rustpython_host_env::select::{PollFd, poll_fds}; + + let mut events = 0; + if matches!(wait_kind, SockWaitKind::Read) { + events |= libc::POLLIN | libc::POLLPRI; + } + if matches!(wait_kind, SockWaitKind::Write | SockWaitKind::Connect) { + events |= libc::POLLOUT; + } + let mut fds = [PollFd { + fd: sock_fileno(sock), + events, + revents: 0, + }; 1]; + + loop { + let (timeout, is_capped) = deadline + .as_ref() + .map(|d| { + d.time_until().map(|t| { + let timeout_ms = t.as_millis(); + let is_capped = timeout_ms > i32::MAX as u128; + let timeout = if is_capped { + i32::MAX + } else { + timeout_ms as i32 + }; + (timeout, is_capped) + }) + }) + .transpose()? 
+ .unwrap_or((-1, false)); + + match vm.allow_threads(|| poll_fds(&mut fds, timeout)) { + Ok(0) => { + if is_capped { + continue; + } + break Err(IoOrPyException::Timeout); + } + + Ok(_) => { + if fds[0].revents & libc::POLLNVAL != 0 { + break Err(io::Error::from_raw_os_error(libc::EBADF).into()); + } + break Ok(()); + } + + Err(e) => { + if e.kind() == io::ErrorKind::Interrupted { + vm.check_signals()?; + continue; + } + break Err(e.into()); + } + } + } } #[cfg(windows)] { - use rustpython_host_env::select as host_select; + use rustpython_host_env::select::{FdSet, select, timeval}; - let fd = sock_fileno(sock); + let fd = sock_fileno(sock) as usize; - let mut reads = host_select::FdSet::new(); - let mut writes = host_select::FdSet::new(); - let mut errs = host_select::FdSet::new(); + let mut reads = FdSet::new(); + let mut writes = FdSet::new(); + let mut errs = FdSet::new(); - let fd = fd as usize; - match kind { - SelectKind::Read => reads.insert(fd), - SelectKind::Write => writes.insert(fd), - SelectKind::Connect => { - writes.insert(fd); - errs.insert(fd); - } + if matches!(wait_kind, SockWaitKind::Read) { + reads.insert(fd); + errs.insert(fd); + } + if matches!(wait_kind, SockWaitKind::Write | SockWaitKind::Connect) { + writes.insert(fd); + errs.insert(fd); } - let mut interval = interval.map(|dur| host_select::timeval { - tv_sec: dur.as_secs() as _, - tv_usec: dur.subsec_micros() as _, - }); - - host_select::select( - fd as i32 + 1, - &mut reads, - &mut writes, - &mut errs, - interval.as_mut(), - ) - .map(|ret| ret == 0) + let mut timeout = deadline + .as_ref() + .map(|d| { + d.time_until().map(|dur| timeval { + tv_sec: dur.as_secs() as _, + tv_usec: dur.subsec_micros() as _, + }) + }) + .transpose()?; + + match vm.allow_threads(|| { + select( + 0, // nfds is ignored on windows + &mut reads, + &mut writes, + &mut errs, + timeout.as_mut(), + ) + }) { + Ok(0) => Err(IoOrPyException::Timeout), + Ok(_) => Ok(()), + Err(e) => Err(e.into()), + } } } diff --git 
a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index 6e06e4e9ef..8548220380 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -36,13 +36,13 @@ mod _ssl { hash::PyHash, lock::{PyMutex, PyRwLock}, }, - socket::{PySocket, SelectKind, sock_select, timeout_error_msg}, + socket::{PySocket, SockWaitKind, sock_wait, timeout_error_msg}, vm::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::{ - PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, PyTypeRef, - PyUtf8StrRef, + PyBaseExceptionRef, PyByteArray, PyBytesRef, PyListRef, PyStrRef, PyType, + PyTypeRef, PyUtf8StrRef, }, convert::IntoPyException, function::{ @@ -75,7 +75,7 @@ mod _ssl { use parking_lot::{Mutex as ParkingMutex, RwLock as ParkingRwLock}; use pem_rfc7468::{LineEnding, encode_string}; use rustls::{ - ClientConfig, ClientConnection, RootCertStore, ServerConfig, ServerConnection, + ClientConnection, Connection, HandshakeKind, RootCertStore, ServerConfig, ServerConnection, client::{ClientSessionMemoryCache, ClientSessionStore}, crypto::SupportedKxGroup, pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName}, @@ -94,9 +94,8 @@ mod _ssl { // Import compat module (OpenSSL compatibility layer) use super::compat::{ ClientConfigOptions, MultiCertResolver, ProtocolSettings, ServerConfigOptions, SslError, - TlsConnection, create_client_config, create_server_config, curve_name_to_kx_group, - extract_cipher_info, get_cipher_encryption_desc, is_blocking_io_error, - normalize_cipher_name, ssl_do_handshake, + create_client_config, create_server_config, curve_name_to_kx_group, extract_cipher_info, + get_cipher_encryption_desc, is_blocking_io_error, normalize_cipher_name, ssl_do_handshake, }; // Type aliases for better readability @@ -154,8 +153,28 @@ mod _ssl { // Buffer sizes and limits (OpenSSL/CPython compatibility) const PEM_BUFSIZE: usize = 1024; + // OpenSSL: ssl/ssl_local.h + const 
SSL3_RT_HEADER_LENGTH: usize = 5; + // This is the maximum MAC (digest) size used by the SSL library. Currently + // maximum of 20 is used by SHA1, but we reserve for future extension for + // 512-bit hashes. + const SSL3_RT_MAX_MD_SIZE: usize = 64; + // Maximum plaintext length: defined by SSL/TLS standards const SSL3_RT_MAX_PLAIN_LENGTH: usize = 16384; + // Maximum compression overhead: defined by SSL/TLS standards + const SSL3_RT_MAX_COMPRESSED_OVERHEAD: usize = 1024; + // The standards give a maximum encryption overhead of 1024 bytes. In + // practice the value is lower than this. The overhead is the maximum number + // of padding bytes (256) plus the mac size. + const SSL3_RT_MAX_ENCRYPTED_OVERHEAD: usize = 256 + SSL3_RT_MAX_MD_SIZE; + const SSL3_RT_MAX_COMPRESSED_LENGTH: usize = + SSL3_RT_MAX_PLAIN_LENGTH + SSL3_RT_MAX_COMPRESSED_OVERHEAD; + const SSL3_RT_MAX_ENCRYPTED_LENGTH: usize = + SSL3_RT_MAX_ENCRYPTED_OVERHEAD + SSL3_RT_MAX_COMPRESSED_LENGTH; + pub(crate) const SSL3_RT_MAX_PACKET_SIZE: usize = + SSL3_RT_MAX_ENCRYPTED_LENGTH + SSL3_RT_HEADER_LENGTH; + // SSL session cache size (common practice, similar to OpenSSL defaults) const SSL_SESSION_CACHE_SIZE: usize = 256; @@ -167,21 +186,26 @@ mod _ssl { #[pyattr] const CERT_REQUIRED: i32 = 2; - // Certificate requirements + // SSL Verification Flags / Certificate requirements #[pyattr] const VERIFY_DEFAULT: i32 = 0; #[pyattr] const VERIFY_CRL_CHECK_LEAF: i32 = 4; #[pyattr] const VERIFY_CRL_CHECK_CHAIN: i32 = 12; + /// VERIFY_X509_STRICT flag for RFC 5280 strict compliance + /// When set, performs additional validation including AKI extension checks #[pyattr] - const VERIFY_X509_STRICT: i32 = 32; + pub(crate) const VERIFY_X509_STRICT: i32 = 32; #[pyattr] const VERIFY_ALLOW_PROXY_CERTS: i32 = 64; #[pyattr] const VERIFY_X509_TRUSTED_FIRST: i32 = 32768; + /// VERIFY_X509_PARTIAL_CHAIN flag for partial chain validation + /// When set, accept certificates if any certificate in the chain is in the trust store + /// 
(not just root CAs). This matches OpenSSL's X509_V_FLAG_PARTIAL_CHAIN behavior. #[pyattr] - const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000; + pub(crate) const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000; // Options (OpenSSL-compatible flags, mostly no-op in rustls) #[pyattr] @@ -398,8 +422,7 @@ mod _ssl { // Session data structure for tracking TLS sessions #[derive(Debug, Clone)] struct SessionData { - #[allow(dead_code)] - server_name: String, + _server_name: String, session_id: Vec, creation_time: SystemTime, lifetime: u64, @@ -477,7 +500,7 @@ mod _ssl { let creation_time = SystemTime::now(); let server_name_str = server_name.to_str(); let session_data = SessionData { - server_name: server_name_str.as_ref().to_string(), + _server_name: server_name_str.as_ref().to_string(), session_id: generate_session_id_from_metadata( server_name_str.as_ref(), &creation_time, @@ -521,7 +544,7 @@ mod _ssl { let creation_time = SystemTime::now(); let server_name_str = server_name.to_str(); let session_data = SessionData { - server_name: server_name_str.to_string(), + _server_name: server_name_str.to_string(), session_id: generate_session_id_from_metadata( server_name_str.as_ref(), &creation_time, @@ -720,10 +743,6 @@ mod _ssl { #[pytraverse(skip)] verify_flags: PyRwLock, // Rustls configuration (built lazily) - #[allow(dead_code)] - #[pytraverse(skip)] - client_config: PyRwLock>>, - #[allow(dead_code)] #[pytraverse(skip)] server_config: PyRwLock>>, // Certificate store @@ -747,19 +766,11 @@ mod _ssl { #[pytraverse(skip)] cert_keys: PyRwLock>, // Options - #[allow(dead_code)] #[pytraverse(skip)] options: PyRwLock, // ALPN protocols - #[allow(dead_code)] #[pytraverse(skip)] alpn_protocols: PyRwLock>>, - // ALPN strict matching flag - // When false (default), mimics OpenSSL behavior: no ALPN negotiation failure - // When true, requires ALPN match (Rustls default behavior) - #[allow(dead_code)] - #[pytraverse(skip)] - require_alpn_match: PyRwLock, // TLS 1.3 features #[pytraverse(skip)] 
post_handshake_auth: PyRwLock, @@ -1825,6 +1836,9 @@ mod _ssl { args: WrapSocketArgs, vm: &VirtualMachine, ) -> PyResult> { + let socket_mod = vm.import("socket", 0)?; + let socket_class = socket_mod.get_attr("socket", vm)?; + // Convert server_hostname to Option // Handle both missing argument and None value let hostname = match args.server_hostname.into_option().flatten() { @@ -1877,6 +1891,12 @@ mod _ssl { // Create _SSLSocket instance let ssl_socket = PySSLSocket { sock: args.sock.clone(), + sock_send_method: socket_class.get_attr("send", vm)?, + sock_recv_method: socket_class.get_attr("recv", vm)?, + tls_record_header_buf: vm + .ctx + .new_bytearray(Vec::with_capacity(TLS_RECORD_HEADER_SIZE)) + .into(), context: PyRwLock::new(zelf), server_side: args.server_side, server_hostname: PyRwLock::new(hostname), @@ -1886,12 +1906,12 @@ mod _ssl { owner: PyRwLock::new(args.owner.into_option()), // Filter out Python None objects - only store actual SSLSession objects session: PyRwLock::new(args.session.into_option().filter(|s| !vm.is_none(s))), - verified_chain: PyRwLock::new(None), incoming_bio: None, outgoing_bio: None, sni_state: PyRwLock::new(None), pending_context: PyRwLock::new(None), client_hello_buffer: PyMutex::new(None), + sni_callback_processed: PyMutex::new(false), shutdown_state: PyMutex::new(ShutdownState::NotStarted), pending_tls_output: PyMutex::new(Vec::new()), write_buffered_len: PyMutex::new(0), @@ -1948,7 +1968,12 @@ mod _ssl { // Create _SSLSocket instance with BIO mode let ssl_socket = PySSLSocket { - sock: vm.ctx.none(), // No socket in BIO mode + // No socket in BIO mode + sock: vm.ctx.none(), + sock_send_method: vm.ctx.none(), + sock_recv_method: vm.ctx.none(), + + tls_record_header_buf: vm.ctx.none(), context: PyRwLock::new(zelf), server_side, server_hostname: PyRwLock::new(hostname), @@ -1958,12 +1983,12 @@ mod _ssl { owner: PyRwLock::new(args.owner.into_option()), // Filter out Python None objects - only store actual SSLSession objects 
session: PyRwLock::new(args.session.into_option().filter(|s| !vm.is_none(s))), - verified_chain: PyRwLock::new(None), incoming_bio: Some(args.incoming), outgoing_bio: Some(args.outgoing), sni_state: PyRwLock::new(None), pending_context: PyRwLock::new(None), client_hello_buffer: PyMutex::new(None), + sni_callback_processed: PyMutex::new(false), shutdown_state: PyMutex::new(ShutdownState::NotStarted), pending_tls_output: PyMutex::new(Vec::new()), write_buffered_len: PyMutex::new(0), @@ -2261,7 +2286,6 @@ mod _ssl { check_hostname: PyRwLock::new(protocol == PROTOCOL_TLS_CLIENT), verify_mode: PyRwLock::new(default_verify_mode), verify_flags: PyRwLock::new(default_verify_flags), - client_config: PyRwLock::new(None), server_config: PyRwLock::new(None), root_certs: PyRwLock::new(RootCertStore::empty()), ca_certs_der: PyRwLock::new(Vec::new()), @@ -2270,7 +2294,6 @@ mod _ssl { cert_keys: PyRwLock::new(Vec::new()), options: PyRwLock::new(default_options), alpn_protocols: PyRwLock::new(Vec::new()), - require_alpn_match: PyRwLock::new(false), post_handshake_auth: PyRwLock::new(false), num_tickets: PyRwLock::new(2), // TLS 1.3 default minimum_version: PyRwLock::new(min_version), @@ -2302,6 +2325,15 @@ mod _ssl { pub(crate) struct PySSLSocket { // Underlying socket sock: PyObjectRef, + // Cached socket.socket.send + #[pytraverse(skip)] + sock_send_method: PyObjectRef, + // Cached socket.socket.recv + #[pytraverse(skip)] + sock_recv_method: PyObjectRef, + // Header of currently read TLS record. 
+ #[pytraverse(skip)] + tls_record_header_buf: PyObjectRef, // SSL context context: PyRwLock>, // Server-side or client-side @@ -2312,7 +2344,7 @@ mod _ssl { server_hostname: PyRwLock>, // TLS connection state #[pytraverse(skip)] - connection: PyMutex>, + connection: PyMutex>, // Handshake completed flag #[pytraverse(skip)] handshake_done: PyMutex, @@ -2323,10 +2355,6 @@ mod _ssl { owner: PyRwLock>, // Session for resumption session: PyRwLock>, - // Verified certificate chain (built during verification) - #[allow(dead_code)] - #[pytraverse(skip)] - verified_chain: PyRwLock>>>, // MemoryBIO mode (optional) incoming_bio: Option>, outgoing_bio: Option>, @@ -2338,6 +2366,9 @@ mod _ssl { // Buffer to store ClientHello for connection recreation #[pytraverse(skip)] client_hello_buffer: PyMutex>>, + // Whether the Python SNI callback has already been run for this handshake + #[pytraverse(skip)] + sni_callback_processed: PyMutex, // Shutdown state for tracking close-notify exchange #[pytraverse(skip)] shutdown_state: PyMutex, @@ -2367,6 +2398,9 @@ mod _ssl { Completed, // unwrap() completed successfully } + /// TLS record header size (content_type + version + length). 
+ const TLS_RECORD_HEADER_SIZE: usize = 5; + #[pyclass(with(Constructor, Representable), flags(BASETYPE))] impl PySSLSocket { // Check if this is BIO mode @@ -2547,7 +2581,7 @@ mod _ssl { .connection .lock() .as_ref() - .is_some_and(|conn| conn.is_session_resumed()); + .is_some_and(|conn| conn.handshake_kind() == Some(HandshakeKind::Resumed)); *self.session_was_reused.lock() = was_resumed; @@ -2575,7 +2609,7 @@ mod _ssl { // Internal implementation with timeout control pub(crate) fn sock_wait_for_io_impl( &self, - kind: SelectKind, + wait_kind: SockWaitKind, vm: &VirtualMachine, ) -> PyResult { if self.is_bio_mode() { @@ -2600,16 +2634,13 @@ mod _ssl { .sock() .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; - let timed_out = sock_select(&socket, kind, timeout) - .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?; - - Ok(timed_out) + sock_wait(&socket, wait_kind, timeout, vm) } // Internal implementation with explicit timeout override pub(crate) fn sock_wait_for_io_with_timeout( &self, - kind: SelectKind, + wait_kind: SockWaitKind, timeout: Option, vm: &VirtualMachine, ) -> PyResult { @@ -2630,19 +2661,16 @@ mod _ssl { .sock() .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; - let timed_out = sock_select(&socket, kind, timeout) - .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?; - - Ok(timed_out) + sock_wait(&socket, wait_kind, timeout, vm).map_err(|e| e.into_pyexception(vm)) } // SNI (Server Name Indication) Helper Methods: // These methods support the server-side handshake SNI callback mechanism /// Check if this is the first read during handshake (for SNI callback) - /// Returns true if we haven't processed ClientHello yet, regardless of SNI presence + /// Returns true until the SNI callback has been processed. 
pub(crate) fn is_first_sni_read(&self) -> bool { - self.client_hello_buffer.lock().is_none() + !*self.sni_callback_processed.lock() } /// Check if SNI callback is configured @@ -2651,9 +2679,13 @@ mod _ssl { self.context.read().sni_callback.read().is_some() } - /// Save ClientHello data from PyObjectRef for potential connection recreation + /// Save ClientHello data for potential connection recreation. pub(crate) fn save_client_hello_from_bytes(&self, bytes_data: &[u8]) { - *self.client_hello_buffer.lock() = Some(bytes_data.to_vec()); + let mut buffer = self.client_hello_buffer.lock(); + match buffer.as_mut() { + Some(existing) => existing.extend_from_slice(bytes_data), + None => *buffer = Some(bytes_data.to_vec()), + } } /// Get the extracted SNI name from resolver @@ -2771,23 +2803,85 @@ mod _ssl { return read_method.call((vm.ctx.new_int(size),), vm); } - // Normal socket mode - let socket_mod = vm.import("socket", 0)?; - let socket_class = socket_mod.get_attr("socket", vm)?; - - // Call socket.socket.recv(self.sock, size, flags) - let recv_method = socket_class.get_attr("recv", vm)?; - recv_method.call((self.sock.clone(), vm.ctx.new_int(size)), vm) + self.sock_recv_method + .call((self.sock.clone(), vm.ctx.new_int(size)), vm) } - /// Peek at socket data without consuming it (MSG_PEEK). - /// Used during TLS shutdown to avoid consuming post-TLS cleartext data. - pub(crate) fn sock_peek(&self, size: usize, vm: &VirtualMachine) -> PyResult { - let socket_mod = vm.import("socket", 0)?; - let socket_class = socket_mod.get_attr("socket", vm)?; - let recv_method = socket_class.get_attr("recv", vm)?; - let msg_peek = socket_mod.get_attr("MSG_PEEK", vm)?; - recv_method.call((self.sock.clone(), vm.ctx.new_int(size), msg_peek), vm) + // Helper to receive data for at most one TLS record. + // May return incomplete data, but never returns data past the end of + // the TLS record currently being read. 
+ pub(crate) fn sock_recv_at_most_one_tls_record( + &self, + vm: &VirtualMachine, + ) -> PyResult { + let obj_to_bytes = |bytes_obj| { + PyBytesRef::try_from_object(vm, bytes_obj) + .map_err(|_| vm.new_os_error("Expected bytes from recv".to_string())) + }; + + let tls_record_header_buf = self + .tls_record_header_buf + .clone() + .downcast::() + .expect("BUG: tls_record_header_buf is not PyByteArray"); + + let buf_len = tls_record_header_buf.borrow_buf().len(); + + let (mut with_header, mut remaining_record_body_len) = + if buf_len < TLS_RECORD_HEADER_SIZE { + // We do not have a full TLS record header, start receiving one. + let bytes_obj = self.sock_recv(TLS_RECORD_HEADER_SIZE - buf_len, vm)?; + let bytes = obj_to_bytes(bytes_obj)?; + + let mut buf = tls_record_header_buf.borrow_buf_mut(); + buf.extend_from_slice(bytes.as_bytes()); + + if buf.len() < TLS_RECORD_HEADER_SIZE { + return Ok(bytes); + } + + // Parse the remaining length. + let record_body_len = u16::from_be_bytes([buf[3], buf[4]]); + // Validity of length value will be checked by rustls. + + // Zero-length TLS record. + if record_body_len == 0 { + buf.clear(); + return Ok(bytes); + } + + let mut bytes_vec = bytes.as_bytes().to_vec(); + bytes_vec.reserve(record_body_len as usize); + (Some(bytes_vec), record_body_len) + } else { + let buf = tls_record_header_buf.borrow_buf(); + let remaining_record_body_len = u16::from_be_bytes([buf[3], buf[4]]); + (None, remaining_record_body_len) + }; + + // We have full record header and are in a process of receiving a record. 
+ let bytes_obj = self.sock_recv(remaining_record_body_len as usize, vm)?; + let bytes = obj_to_bytes(bytes_obj)?; + + if let Some(with_header) = with_header.as_mut() { + with_header.extend_from_slice(bytes.as_bytes()); + } + + let mut buf = tls_record_header_buf.borrow_buf_mut(); + remaining_record_body_len -= bytes.len() as u16; + if remaining_record_body_len == 0 { + // Record received completely, need to start a new one beginning with its header. + buf.clear(); + } else { + // Update remaining length in the header. + buf.as_mut_slice()[3..5].copy_from_slice(&remaining_record_body_len.to_be_bytes()); + } + + if let Some(with_header) = with_header { + Ok(vm.ctx.new_bytes(with_header)) + } else { + Ok(bytes) + } } /// Socket send - just sends data, caller must handle pending flush @@ -2800,13 +2894,8 @@ mod _ssl { return write_method.call((vm.ctx.new_bytes(data.to_vec()),), vm); } - // Normal socket mode - let socket_mod = vm.import("socket", 0)?; - let socket_class = socket_mod.get_attr("socket", vm)?; - - // Call socket.socket.send(self.sock, data) - let send_method = socket_class.get_attr("send", vm)?; - send_method.call((self.sock.clone(), vm.ctx.new_bytes(data.to_vec())), vm) + self.sock_send_method + .call((self.sock.clone(), vm.ctx.new_bytes(data.to_vec())), vm) } /// Flush any pending TLS output data to the socket @@ -2842,13 +2931,12 @@ mod _ssl { socket_timeout }; - // Use sock_select directly with calculated timeout + // Use sock_wait directly with calculated timeout let py_socket: PyRef = self.sock.clone().try_into_value(vm)?; let socket = py_socket .sock() .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; - let timed_out = sock_select(&socket, SelectKind::Write, timeout_to_use) - .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?; + let timed_out = sock_wait(&socket, SockWaitKind::Write, timeout_to_use, vm)?; if timed_out { // Keep unsent data in pending buffer @@ -2909,7 +2997,7 @@ mod _ssl { let mut sent_total = 0; 
while sent_total < buf.len() { - let timed_out = self.sock_wait_for_io_impl(SelectKind::Write, vm)?; + let timed_out = self.sock_wait_for_io_impl(SockWaitKind::Write, vm)?; if timed_out { // Save unsent data to pending buffer self.pending_tls_output @@ -2981,8 +3069,7 @@ mod _ssl { let socket = py_socket .sock() .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; - let timed_out = sock_select(&socket, SelectKind::Write, timeout) - .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?; + let timed_out = sock_wait(&socket, SockWaitKind::Write, timeout, vm)?; if timed_out { return Err( @@ -2998,7 +3085,7 @@ mod _ssl { let mut pending = self.pending_tls_output.lock(); pending.drain(..sent); } - // If sent == 0, loop will retry with sock_select + // If sent == 0, loop will retry with sock_wait } Err(e) => { if is_blocking_io_error(&e, vm) { @@ -3112,7 +3199,7 @@ mod _ssl { /// Returns the configured ServerConnection. fn initialize_server_connection( &self, - conn_guard: &mut Option, + conn_guard: &mut Option, vm: &VirtualMachine, ) -> PyResult<()> { let ctx = self.context.read(); @@ -3282,11 +3369,11 @@ mod _ssl { vm.new_value_error(format!("Failed to create server connection: {e}")) })?; - *conn_guard = Some(TlsConnection::Server(conn)); + *conn_guard = Some(Connection::Server(conn)); // If ClientHello buffer exists (from SNI callback), re-inject it if let Some(ref hello_data) = *self.client_hello_buffer.lock() - && let Some(TlsConnection::Server(ref mut server)) = *conn_guard + && let Some(Connection::Server(ref mut server)) = *conn_guard { let mut cursor = std::io::Cursor::new(hello_data.as_slice()); let _ = server.read_tls(&mut cursor); @@ -3409,14 +3496,14 @@ mod _ssl { vm.new_value_error(format!("Failed to create client connection: {e}")) })?; - *conn_guard = Some(TlsConnection::Client(conn)); + *conn_guard = Some(Connection::Client(conn)); } } // Perform the actual handshake by exchanging data with the socket/BIO let conn = 
conn_guard.as_mut().expect("unreachable"); - let is_client = matches!(conn, TlsConnection::Client(_)); + let is_client = matches!(conn, Connection::Client(_)); let handshake_result = ssl_do_handshake(conn, self, vm); drop(conn_guard); @@ -3442,6 +3529,7 @@ mod _ssl { // Now safe to call Python callback (no locks held) self.invoke_sni_callback(sni_name.as_deref(), vm)?; + *self.sni_callback_processed.lock() = true; // Clear connection to trigger recreation *self.connection.lock() = None; @@ -4171,7 +4259,7 @@ mod _ssl { // Wait for socket to be readable let timed_out = self.sock_wait_for_io_with_timeout( - SelectKind::Read, + SockWaitKind::Read, remaining_timeout, vm, )?; @@ -4242,7 +4330,7 @@ mod _ssl { } // Helper: Write all pending TLS data (including close_notify) to outgoing buffer/BIO - fn write_pending_tls(&self, conn: &mut TlsConnection, vm: &VirtualMachine) -> PyResult<()> { + fn write_pending_tls(&self, conn: &mut Connection, vm: &VirtualMachine) -> PyResult<()> { // First, flush any previously pending TLS output // Must succeed before sending new data to maintain order self.flush_pending_tls_output(vm, None)?; @@ -4252,7 +4340,7 @@ mod _ssl { break; } - let mut buf = vec![0u8; SSL3_RT_MAX_PLAIN_LENGTH]; + let mut buf = vec![0u8; SSL3_RT_MAX_PACKET_SIZE]; let written = conn .write_tls(&mut buf.as_mut_slice()) .map_err(|e| vm.new_os_error(format!("TLS write failed: {e}")))?; @@ -4272,7 +4360,7 @@ mod _ssl { // Returns true if peer closed connection (with or without close_notify) fn try_read_close_notify( &self, - conn: &mut TlsConnection, + conn: &mut Connection, vm: &VirtualMachine, ) -> PyResult { // In socket mode, peek first to avoid consuming post-TLS cleartext @@ -4280,11 +4368,11 @@ mod _ssl { // transitions to cleartext. Without peeking, sock_recv may consume // cleartext data meant for the application after unwrap(). 
if self.incoming_bio.is_none() { - return self.try_read_close_notify_socket(conn, vm); + return Ok(self.try_read_close_notify_socket(conn, vm)); } // BIO mode: read from incoming BIO - match self.sock_recv(SSL3_RT_MAX_PLAIN_LENGTH, vm) { + match self.sock_recv(SSL3_RT_MAX_PACKET_SIZE, vm) { Ok(bytes_obj) => { let bytes = ArgBytesLike::try_from_object(vm, bytes_obj)?; let data = bytes.borrow_buf(); @@ -4323,81 +4411,31 @@ mod _ssl { /// /// Equivalent to OpenSSL's `SSL_set_read_ahead(ssl, 0)` — rustls has no /// such knob, so we enforce record-level reads manually via peek. - fn try_read_close_notify_socket( - &self, - conn: &mut TlsConnection, - vm: &VirtualMachine, - ) -> PyResult { - // Peek at the first 5 bytes (TLS record header size) - let peeked_obj = match self.sock_peek(5, vm) { - Ok(obj) => obj, - Err(e) => { - if is_blocking_io_error(&e, vm) { - return Ok(false); - } - return Ok(true); - } - }; - - let peeked = ArgBytesLike::try_from_object(vm, peeked_obj)?; - let peek_data = peeked.borrow_buf(); - - if peek_data.is_empty() { - return Ok(true); // EOF - } - - // TLS record content types: ChangeCipherSpec(20), Alert(21), - // Handshake(22), ApplicationData(23) - let content_type = peek_data[0]; - if !(20..=23).contains(&content_type) { - // Not a TLS record - post-TLS cleartext data. - // Peer has completed TLS shutdown; don't consume this data. 
- return Ok(true); - } - - // Determine how many bytes to read for exactly one TLS record - let recv_size = if peek_data.len() >= 5 { - let record_length = u16::from_be_bytes([peek_data[3], peek_data[4]]) as usize; - 5 + record_length - } else { - // Partial header available - read just these bytes for now - peek_data.len() - }; - - drop(peek_data); - drop(peeked); - - // Now consume exactly one TLS record from the socket - match self.sock_recv(recv_size, vm) { - Ok(bytes_obj) => { - let bytes = ArgBytesLike::try_from_object(vm, bytes_obj)?; - let data = bytes.borrow_buf(); - + fn try_read_close_notify_socket(&self, conn: &mut Connection, vm: &VirtualMachine) -> bool { + // Consume at most one TLS record from the socket + match self.sock_recv_at_most_one_tls_record(vm) { + Ok(data) => { if data.is_empty() { - return Ok(true); + return true; } let data_slice: &[u8] = data.as_ref(); let mut cursor = std::io::Cursor::new(data_slice); let _ = conn.read_tls(&mut cursor); let _ = conn.process_new_packets(); - Ok(false) + false } Err(e) => { if is_blocking_io_error(&e, vm) { - return Ok(false); + return false; } - Ok(true) + true } } } // Helper: Check if peer has sent close_notify - fn check_peer_closed( - &self, - conn: &mut TlsConnection, - vm: &VirtualMachine, - ) -> PyResult { + fn check_peer_closed(&self, conn: &mut Connection, vm: &VirtualMachine) -> PyResult { // Process any remaining packets and check peer_has_closed let io_state = conn .process_new_packets() @@ -4459,12 +4497,12 @@ mod _ssl { let conn_guard = self.connection.lock(); if let Some(conn) = conn_guard.as_ref() { let version = match conn { - TlsConnection::Client(_) => { + Connection::Client(_) => { return Err(vm.new_value_error( "Post-handshake authentication requires server socket", )); } - TlsConnection::Server(server) => server.protocol_version(), + Connection::Server(server) => server.protocol_version(), }; // Post-handshake auth is only available in TLS 1.3 diff --git 
a/crates/stdlib/src/ssl/cert.rs b/crates/stdlib/src/ssl/cert.rs index 835e3f37c6..513d2929eb 100644 --- a/crates/stdlib/src/ssl/cert.rs +++ b/crates/stdlib/src/ssl/cert.rs @@ -22,7 +22,7 @@ use rustpython_vm::{PyObjectRef, PyResult, VirtualMachine}; use std::collections::HashSet; use x509_parser::prelude::*; -use super::compat::{VERIFY_X509_PARTIAL_CHAIN, VERIFY_X509_STRICT}; +use super::_ssl::{VERIFY_X509_PARTIAL_CHAIN, VERIFY_X509_STRICT}; // Certificate Verification Constants diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index ed3880940b..2d6bb369e9 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -15,18 +15,17 @@ #[path = "../openssl/ssl_data_31.rs"] mod ssl_data; -use crate::socket::{SelectKind, timeout_error_msg}; +use crate::socket::{SockWaitKind, timeout_error_msg}; use crate::vm::VirtualMachine; use alloc::sync::Arc; use parking_lot::RwLock as ParkingRwLock; +use rustls::Connection; use rustls::RootCertStore; use rustls::client::ClientConfig; -use rustls::client::ClientConnection; use rustls::crypto::SupportedKxGroup; use rustls::pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer}; use rustls::server::ResolvesServerCert; use rustls::server::ServerConfig; -use rustls::server::ServerConnection; use rustls::sign::CertifiedKey; use rustpython_vm::builtins::{PyBaseException, PyBaseExceptionRef}; use rustpython_vm::convert::IntoPyException; @@ -36,7 +35,7 @@ use std::io::Read; use std::sync::Once; // Import PySSLSocket from parent module -use super::_ssl::PySSLSocket; +use super::_ssl::{PySSLSocket, SSL3_RT_MAX_PACKET_SIZE, VERIFY_X509_STRICT}; // Import error types and helper functions from error module use super::error::{ @@ -44,16 +43,6 @@ use super::error::{ create_ssl_want_read_error, create_ssl_want_write_error, create_ssl_zero_return_error, }; -// SSL Verification Flags -/// VERIFY_X509_STRICT flag for RFC 5280 strict compliance -/// When set, performs additional 
validation including AKI extension checks -pub(super) const VERIFY_X509_STRICT: i32 = 0x20; - -/// VERIFY_X509_PARTIAL_CHAIN flag for partial chain validation -/// When set, accept certificates if any certificate in the chain is in the trust store -/// (not just root CAs). This matches OpenSSL's X509_V_FLAG_PARTIAL_CHAIN behavior. -pub(super) const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000; - // CryptoProvider Initialization: /// Ensure the default CryptoProvider is installed (thread-safe, runs once) @@ -73,10 +62,6 @@ fn ensure_default_provider() { // OpenSSL Constants: -// OpenSSL TLS record maximum plaintext size (ssl/ssl_local.h) -// #define SSL3_RT_MAX_PLAIN_LENGTH 16384 -const SSL3_RT_MAX_PLAIN_LENGTH: usize = 16384; - // OpenSSL error library codes (include/openssl/err.h) // #define ERR_LIB_SSL 20 const ERR_LIB_SSL: i32 = 20; @@ -95,74 +80,15 @@ const X509_V_FLAG_CRL_CHECK: i32 = 4; // verification. They are used to map rustls certificate errors to OpenSSL // error codes for compatibility. 
-pub(super) use x509::{ - X509_V_ERR_CERT_HAS_EXPIRED, X509_V_ERR_CERT_NOT_YET_VALID, X509_V_ERR_CERT_REVOKED, - X509_V_ERR_HOSTNAME_MISMATCH, X509_V_ERR_INVALID_PURPOSE, X509_V_ERR_IP_ADDRESS_MISMATCH, - X509_V_ERR_UNABLE_TO_GET_CRL, X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY, - X509_V_ERR_UNSPECIFIED, -}; - -#[allow(dead_code)] -mod x509 { - pub(super) const X509_V_OK: i32 = 0; - pub(crate) const X509_V_ERR_UNSPECIFIED: i32 = 1; - pub(super) const X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT: i32 = 2; - pub(crate) const X509_V_ERR_UNABLE_TO_GET_CRL: i32 = 3; - pub(super) const X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE: i32 = 4; - pub(super) const X509_V_ERR_UNABLE_TO_DECRYPT_CRL_SIGNATURE: i32 = 5; - pub(super) const X509_V_ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY: i32 = 6; - pub(super) const X509_V_ERR_CERT_SIGNATURE_FAILURE: i32 = 7; - pub(super) const X509_V_ERR_CRL_SIGNATURE_FAILURE: i32 = 8; - pub(crate) const X509_V_ERR_CERT_NOT_YET_VALID: i32 = 9; - pub(crate) const X509_V_ERR_CERT_HAS_EXPIRED: i32 = 10; - pub(super) const X509_V_ERR_CRL_NOT_YET_VALID: i32 = 11; - pub(super) const X509_V_ERR_CRL_HAS_EXPIRED: i32 = 12; - pub(super) const X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD: i32 = 13; - pub(super) const X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD: i32 = 14; - pub(super) const X509_V_ERR_ERROR_IN_CRL_LAST_UPDATE_FIELD: i32 = 15; - pub(super) const X509_V_ERR_ERROR_IN_CRL_NEXT_UPDATE_FIELD: i32 = 16; - pub(super) const X509_V_ERR_OUT_OF_MEM: i32 = 17; - pub(super) const X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT: i32 = 18; - pub(super) const X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN: i32 = 19; - pub(crate) const X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY: i32 = 20; - pub(super) const X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE: i32 = 21; - pub(super) const X509_V_ERR_CERT_CHAIN_TOO_LONG: i32 = 22; - pub(crate) const X509_V_ERR_CERT_REVOKED: i32 = 23; - pub(super) const X509_V_ERR_INVALID_CA: i32 = 24; - pub(super) const X509_V_ERR_PATH_LENGTH_EXCEEDED: i32 = 25; - 
pub(crate) const X509_V_ERR_INVALID_PURPOSE: i32 = 26; - pub(super) const X509_V_ERR_CERT_UNTRUSTED: i32 = 27; - pub(super) const X509_V_ERR_CERT_REJECTED: i32 = 28; - pub(super) const X509_V_ERR_SUBJECT_ISSUER_MISMATCH: i32 = 29; - pub(super) const X509_V_ERR_AKID_SKID_MISMATCH: i32 = 30; - pub(super) const X509_V_ERR_AKID_ISSUER_SERIAL_MISMATCH: i32 = 31; - pub(super) const X509_V_ERR_KEYUSAGE_NO_CERTSIGN: i32 = 32; - pub(super) const X509_V_ERR_UNABLE_TO_GET_CRL_ISSUER: i32 = 33; - pub(super) const X509_V_ERR_UNHANDLED_CRITICAL_EXTENSION: i32 = 34; - pub(super) const X509_V_ERR_KEYUSAGE_NO_CRL_SIGN: i32 = 35; - pub(super) const X509_V_ERR_UNHANDLED_CRITICAL_CRL_EXTENSION: i32 = 36; - pub(super) const X509_V_ERR_INVALID_NON_CA: i32 = 37; - pub(super) const X509_V_ERR_PROXY_PATH_LENGTH_EXCEEDED: i32 = 38; - pub(super) const X509_V_ERR_KEYUSAGE_NO_DIGITAL_SIGNATURE: i32 = 39; - pub(super) const X509_V_ERR_PROXY_CERTIFICATES_NOT_ALLOWED: i32 = 40; - pub(super) const X509_V_ERR_INVALID_EXTENSION: i32 = 41; - pub(super) const X509_V_ERR_INVALID_POLICY_EXTENSION: i32 = 42; - pub(super) const X509_V_ERR_NO_EXPLICIT_POLICY: i32 = 43; - pub(super) const X509_V_ERR_DIFFERENT_CRL_SCOPE: i32 = 44; - pub(super) const X509_V_ERR_UNSUPPORTED_EXTENSION_FEATURE: i32 = 45; - pub(super) const X509_V_ERR_UNNESTED_RESOURCE: i32 = 46; - pub(super) const X509_V_ERR_PERMITTED_VIOLATION: i32 = 47; - pub(super) const X509_V_ERR_EXCLUDED_VIOLATION: i32 = 48; - pub(super) const X509_V_ERR_SUBTREE_MINMAX: i32 = 49; - pub(super) const X509_V_ERR_APPLICATION_VERIFICATION: i32 = 50; - pub(super) const X509_V_ERR_UNSUPPORTED_CONSTRAINT_TYPE: i32 = 51; - pub(super) const X509_V_ERR_UNSUPPORTED_CONSTRAINT_SYNTAX: i32 = 52; - pub(super) const X509_V_ERR_UNSUPPORTED_NAME_SYNTAX: i32 = 53; - pub(super) const X509_V_ERR_CRL_PATH_VALIDATION_ERROR: i32 = 54; - pub(crate) const X509_V_ERR_HOSTNAME_MISMATCH: i32 = 62; - pub(super) const X509_V_ERR_EMAIL_MISMATCH: i32 = 63; - pub(crate) const 
X509_V_ERR_IP_ADDRESS_MISMATCH: i32 = 64; -} +pub(super) const X509_V_ERR_UNSPECIFIED: i32 = 1; +pub(super) const X509_V_ERR_UNABLE_TO_GET_CRL: i32 = 3; +pub(super) const X509_V_ERR_CERT_NOT_YET_VALID: i32 = 9; +pub(super) const X509_V_ERR_CERT_HAS_EXPIRED: i32 = 10; +pub(super) const X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY: i32 = 20; +pub(super) const X509_V_ERR_CERT_REVOKED: i32 = 23; +pub(super) const X509_V_ERR_INVALID_PURPOSE: i32 = 26; +pub(super) const X509_V_ERR_HOSTNAME_MISMATCH: i32 = 62; +pub(super) const X509_V_ERR_IP_ADDRESS_MISMATCH: i32 = 64; // Certificate Error Conversion Functions: @@ -263,126 +189,6 @@ pub(super) fn create_ssl_cert_verification_error( Ok(exc.upcast()) } -/// Unified TLS connection type (client or server) -#[derive(Debug)] -pub(super) enum TlsConnection { - Client(ClientConnection), - Server(ServerConnection), -} - -impl TlsConnection { - /// Check if handshake is in progress - pub(super) fn is_handshaking(&self) -> bool { - match self { - Self::Client(conn) => conn.is_handshaking(), - Self::Server(conn) => conn.is_handshaking(), - } - } - - /// Check if connection wants to read data - pub(super) fn wants_read(&self) -> bool { - match self { - Self::Client(conn) => conn.wants_read(), - Self::Server(conn) => conn.wants_read(), - } - } - - /// Check if connection wants to write data - pub(super) fn wants_write(&self) -> bool { - match self { - Self::Client(conn) => conn.wants_write(), - Self::Server(conn) => conn.wants_write(), - } - } - - /// Read TLS data from socket - pub(super) fn read_tls(&mut self, reader: &mut dyn std::io::Read) -> std::io::Result { - match self { - Self::Client(conn) => conn.read_tls(reader), - Self::Server(conn) => conn.read_tls(reader), - } - } - - /// Write TLS data to socket - pub(super) fn write_tls(&mut self, writer: &mut dyn std::io::Write) -> std::io::Result { - match self { - Self::Client(conn) => conn.write_tls(writer), - Self::Server(conn) => conn.write_tls(writer), - } - } - - /// Process 
new TLS packets - pub(super) fn process_new_packets(&mut self) -> Result { - match self { - Self::Client(conn) => conn.process_new_packets(), - Self::Server(conn) => conn.process_new_packets(), - } - } - - /// Get reader for plaintext data (rustls native type) - pub(super) fn reader(&mut self) -> rustls::Reader<'_> { - match self { - Self::Client(conn) => conn.reader(), - Self::Server(conn) => conn.reader(), - } - } - - /// Get writer for plaintext data (rustls native type) - pub(super) fn writer(&mut self) -> rustls::Writer<'_> { - match self { - Self::Client(conn) => conn.writer(), - Self::Server(conn) => conn.writer(), - } - } - - /// Check if session was resumed - pub(super) fn is_session_resumed(&self) -> bool { - use rustls::HandshakeKind; - match self { - Self::Client(conn) => { - matches!(conn.handshake_kind(), Some(HandshakeKind::Resumed)) - } - Self::Server(conn) => { - matches!(conn.handshake_kind(), Some(HandshakeKind::Resumed)) - } - } - } - - /// Send close_notify alert - pub(super) fn send_close_notify(&mut self) { - match self { - Self::Client(conn) => conn.send_close_notify(), - Self::Server(conn) => conn.send_close_notify(), - } - } - - /// Get negotiated ALPN protocol - pub(super) fn alpn_protocol(&self) -> Option<&[u8]> { - match self { - Self::Client(conn) => conn.alpn_protocol(), - Self::Server(conn) => conn.alpn_protocol(), - } - } - - /// Get negotiated cipher suite - pub(super) fn negotiated_cipher_suite(&self) -> Option { - match self { - Self::Client(conn) => conn.negotiated_cipher_suite(), - Self::Server(conn) => conn.negotiated_cipher_suite(), - } - } - - /// Get peer certificates - pub(super) fn peer_certificates( - &self, - ) -> Option<&[rustls::pki_types::CertificateDer<'static>]> { - match self { - Self::Client(conn) => conn.peer_certificates(), - Self::Server(conn) => conn.peer_certificates(), - } - } -} - /// Error types matching OpenSSL error codes #[derive(Debug)] pub(super) enum SslError { @@ -584,6 +390,16 @@ impl SslError { 
// Use the proper cert verification error creator create_ssl_cert_verification_error(vm, &cert_err).expect("unlikely to happen") } + Self::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => { + create_ssl_eof_error(vm).upcast() + } + Self::Io(err) if err.raw_os_error().is_none() => vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + format!("SSL error: {err}"), + ) + .upcast(), Self::Io(err) => err.into_pyexception(vm), Self::SniCallbackRestart => { // This should be handled at PySSLSocket level @@ -1058,11 +874,11 @@ fn send_all_bytes( )); } socket - .sock_wait_for_io_with_timeout(SelectKind::Write, Some(dl - now), vm) + .sock_wait_for_io_with_timeout(SockWaitKind::Write, Some(dl - now), vm) .map_err(SslError::Py)? } else { socket - .sock_wait_for_io_impl(SelectKind::Write, vm) + .sock_wait_for_io_impl(SockWaitKind::Write, vm) .map_err(SslError::Py)? }; if timed_out { @@ -1120,7 +936,7 @@ fn send_all_bytes( /// Drains all pending TLS data from rustls and sends it to the peer. /// Returns whether any progress was made. fn handshake_write_loop( - conn: &mut TlsConnection, + conn: &mut Connection, socket: &PySSLSocket, force_initial_write: bool, vm: &VirtualMachine, @@ -1163,14 +979,10 @@ fn handshake_write_loop( Ok(made_progress) } -/// Read TLS handshake data from socket/BIO +/// Read at most one TLS record from the TCP socket. /// -/// Waits for and reads TLS records from the peer, handling SNI callback setup. -/// Returns (made_progress, is_first_sni_read). -/// TLS record header size (content_type + version + length). -const TLS_RECORD_HEADER_SIZE: usize = 5; - -/// Read exactly one TLS record from the TCP socket. +/// May return incomplete data but never returns more when completes a +/// previously incomplete TLS record. /// /// OpenSSL reads one TLS record at a time (no read-ahead by default). /// Rustls, however, consumes all available TCP data when fed via read_tls(). 
@@ -1183,77 +995,32 @@ const TLS_RECORD_HEADER_SIZE: usize = 5; /// Fix: peek at the TCP buffer to find the first complete TLS record boundary /// and recv() only that many bytes. Any remaining data stays in the kernel /// buffer and remains visible to select(). -fn recv_one_tls_record(socket: &PySSLSocket, vm: &VirtualMachine) -> SslResult { - // Peek at what is available without consuming it. - let peeked_obj = match socket.sock_peek(SSL3_RT_MAX_PLAIN_LENGTH, vm) { - Ok(d) => d, - Err(e) => { - if is_blocking_io_error(&e, vm) { - return Err(SslError::WantRead); - } - return Err(SslError::Py(e)); - } - }; - - let peeked = ArgBytesLike::try_from_object(vm, peeked_obj) - .map_err(|_| SslError::Syscall("Expected bytes-like object from peek".to_string()))?; - let peeked_bytes = peeked.borrow_buf(); - - if peeked_bytes.is_empty() { - // Empty peek means the peer has closed the TCP connection (FIN). - // Non-blocking sockets would have returned EAGAIN/EWOULDBLOCK - // (caught above as WantRead), so empty bytes here always means EOF. - return Err(SslError::Eof); - } - - if peeked_bytes.len() < TLS_RECORD_HEADER_SIZE { - // Not enough data for a TLS record header yet. - // Read all available bytes so rustls can buffer the partial header; - // this avoids busy-waiting because the kernel buffer is now empty - // and select() will only wake us when new data arrives. - return socket.sock_recv(peeked_bytes.len(), vm).map_err(|e| { - if is_blocking_io_error(&e, vm) { - SslError::WantRead - } else { - SslError::Py(e) - } - }); - } - - // Parse the TLS record length from the header. - let record_body_len = u16::from_be_bytes([peeked_bytes[3], peeked_bytes[4]]) as usize; - let total_record_size = TLS_RECORD_HEADER_SIZE + record_body_len; - - let recv_size = if peeked_bytes.len() >= total_record_size { - // Complete record available — consume exactly one record. 
- total_record_size - } else { - // Incomplete record — consume everything so the kernel buffer is - // drained and select() will block until more data arrives. - peeked_bytes.len() - }; - - // Must drop the borrow before calling sock_recv (which re-enters Python). - drop(peeked_bytes); - drop(peeked); - - socket.sock_recv(recv_size, vm).map_err(|e| { +fn recv_at_most_one_tls_record( + socket: &PySSLSocket, + vm: &VirtualMachine, +) -> SslResult { + let bytes = socket.sock_recv_at_most_one_tls_record(vm).map_err(|e| { if is_blocking_io_error(&e, vm) { SslError::WantRead } else { SslError::Py(e) } - }) + })?; + if bytes.is_empty() { + Err(SslError::Eof) + } else { + Ok(bytes.into()) + } } -/// Read a single TLS record for post-handshake I/O while preserving the +/// Read up to a single TLS record for post-handshake I/O while preserving the /// SSL-vs-socket error precedence from the old sock_recv() path. -fn recv_one_tls_record_for_data( - conn: &mut TlsConnection, +fn recv_at_most_one_tls_record_for_data( + conn: &mut Connection, socket: &PySSLSocket, vm: &VirtualMachine, ) -> SslResult { - match recv_one_tls_record(socket, vm) { + match recv_at_most_one_tls_record(socket, vm) { Ok(data) => Ok(data), Err(SslError::Eof) => { if let Err(rustls_err) = conn.process_new_packets() { @@ -1275,7 +1042,7 @@ fn recv_one_tls_record_for_data( } fn handshake_read_data( - conn: &mut TlsConnection, + conn: &mut Connection, socket: &PySSLSocket, is_bio: bool, is_server: bool, @@ -1285,14 +1052,15 @@ fn handshake_read_data( return Ok((false, false)); } - // SERVER-SPECIFIC: Check if this is the first read (for SNI callback) - // Must check BEFORE reading data, so we can detect first time - let is_first_sni_read = is_server && socket.is_first_sni_read(); + // SERVER-SPECIFIC: Check if this is before the SNI callback. + // sock_recv() may return only part of a TLS record, so keep capturing + // ClientHello fragments until process_new_packets() has produced a response. 
+ let is_first_sni_read = is_server && socket.has_sni_callback() && socket.is_first_sni_read(); // Wait for data in socket mode if !is_bio { let timed_out = socket - .sock_wait_for_io_impl(SelectKind::Read, vm) + .sock_wait_for_io_impl(SockWaitKind::Read, vm) .map_err(SslError::Py)?; if timed_out { @@ -1308,9 +1076,9 @@ fn handshake_read_data( // record. This matches OpenSSL's default (no read-ahead) behaviour // and keeps remaining data in the kernel buffer where select() can // detect it. - recv_one_tls_record(socket, vm)? + recv_at_most_one_tls_record(socket, vm)? } else { - match socket.sock_recv(SSL3_RT_MAX_PLAIN_LENGTH, vm) { + match socket.sock_recv(SSL3_RT_MAX_PACKET_SIZE, vm) { Ok(d) => d, Err(e) => { if is_blocking_io_error(&e, vm) { @@ -1324,7 +1092,7 @@ fn handshake_read_data( } }; - // SERVER-SPECIFIC: Save ClientHello on first read for potential connection recreation + // SERVER-SPECIFIC: Save ClientHello fragments for potential connection recreation. if is_first_sni_read { // Extract bytes from PyObjectRef use rustpython_vm::builtins::PyBytes; @@ -1344,7 +1112,7 @@ fn handshake_read_data( /// Tries to send NewSessionTicket in non-blocking mode to avoid deadlocks. /// Returns true if handshake is complete and we should exit. fn handle_handshake_complete( - conn: &mut TlsConnection, + conn: &mut Connection, socket: &PySSLSocket, _is_server: bool, vm: &VirtualMachine, @@ -1419,7 +1187,7 @@ fn handle_handshake_complete( /// /// Returns Ok(Some(n)) if n bytes were read, Ok(None) if would block, /// or Err on real errors. 
-fn try_read_plaintext(conn: &mut TlsConnection, buf: &mut [u8]) -> SslResult> { +fn try_read_plaintext(conn: &mut Connection, buf: &mut [u8]) -> SslResult> { let mut reader = conn.reader(); match reader.read(buf) { Ok(0) => { @@ -1448,7 +1216,7 @@ fn try_read_plaintext(conn: &mut TlsConnection, buf: &mut [u8]) -> SslResult SslResult<()> { @@ -1458,12 +1226,9 @@ pub(super) fn ssl_do_handshake( } let is_bio = socket.is_bio_mode(); - let is_server = matches!(conn, TlsConnection::Server(_)); + let is_server = matches!(conn, Connection::Server(_)); let mut first_iteration = true; // Track if this is the first loop iteration - let mut iteration_count = 0; - loop { - iteration_count += 1; let mut made_progress = false; // IMPORTANT: In BIO mode, force initial write even if wants_write() is false @@ -1506,10 +1271,10 @@ pub(super) fn ssl_do_handshake( return Err(SslError::from_rustls(e)); } - // SERVER-SPECIFIC: Check SNI callback after processing packets - // SNI name is extracted during process_new_packets() - // Invoke callback on FIRST read if callback is configured, regardless of SNI presence - if is_server && is_first_sni_read && socket.has_sni_callback() { + // SERVER-SPECIFIC: Check SNI callback after processing packets. + // A partial TLS record can be read without producing any handshake + // response. Wait until rustls has processed a complete ClientHello. + if is_server && is_first_sni_read && socket.has_sni_callback() && conn.wants_write() { // IMPORTANT: Do NOT call the callback here! // The connection lock is still held, which would cause deadlock. 
// Return SniCallbackRestart to signal do_handshake to: @@ -1533,7 +1298,7 @@ pub(super) fn ssl_do_handshake( if conn.wants_write() { // Write all pending TLS data to outgoing BIO loop { - let mut buf = vec![0u8; SSL3_RT_MAX_PLAIN_LENGTH]; + let mut buf = vec![0u8; SSL3_RT_MAX_PACKET_SIZE]; let n = match conn.write_tls(&mut buf.as_mut_slice()) { Ok(n) => n, Err(_) => break, @@ -1581,11 +1346,6 @@ pub(super) fn ssl_do_handshake( if !should_continue { break; } - - // Safety check: prevent truly infinite loops (should never happen) - if iteration_count > 1000 { - break; - } } // If we exit the loop without completing handshake, return appropriate error @@ -1599,9 +1359,9 @@ pub(super) fn ssl_do_handshake( return Err(SslError::WantRead); } // Neither wants_read nor wants_write - this is a real error - Err(SslError::Syscall(format!( - "SSL handshake failed: incomplete after {iteration_count} iterations", - ))) + Err(SslError::Syscall( + "SSL handshake failed: incomplete handshake".to_string(), + )) } else { // Handshake completed successfully (shouldn't reach here normally) Ok(()) @@ -1615,7 +1375,7 @@ pub(super) fn ssl_do_handshake( /// /// = SSL_read_ex() pub(super) fn ssl_read( - conn: &mut TlsConnection, + conn: &mut Connection, buf: &mut [u8], socket: &PySSLSocket, vm: &VirtualMachine, @@ -1753,7 +1513,7 @@ pub(super) fn ssl_read( // Blocking socket or socket with timeout: try to read more data from socket. // Even though rustls says it doesn't want to read, more TLS records may arrive. // Use single-record reading to avoid consuming close_notify alongside data. 
- let data = recv_one_tls_record_for_data(conn, socket, vm)?; + let data = recv_at_most_one_tls_record_for_data(conn, socket, vm)?; let bytes_read = data .clone() @@ -1788,11 +1548,6 @@ pub(super) fn ssl_read( // Successfully read and processed TLS data // Continue loop to try reading plaintext } - Err(SslError::Io(ref io_err)) if io_err.to_string().contains("message buffer full") => { - // This case should be rare now that ssl_read_tls_records handles buffer full - // Just continue loop to try again - continue; - } Err(e) => { // Other errors - check for buffered plaintext before propagating match try_read_plaintext(conn, buf)? { @@ -1817,7 +1572,7 @@ pub(super) fn ssl_read( /// /// = SSL_write_ex() pub(super) fn ssl_write( - conn: &mut TlsConnection, + conn: &mut Connection, data: &[u8], socket: &PySSLSocket, vm: &VirtualMachine, @@ -1944,7 +1699,7 @@ pub(super) fn ssl_write( return Err(SslError::WantRead); } // For socket mode, try to read TLS data - let recv_result = socket.sock_recv(4096, vm).map_err(SslError::Py)?; + let recv_result = recv_at_most_one_tls_record_for_data(conn, socket, vm)?; ssl_read_tls_records(conn, recv_result, false, vm)?; conn.process_new_packets().map_err(SslError::from_rustls)?; // Continue loop @@ -1994,7 +1749,7 @@ pub(super) fn ssl_write( // Helper functions (private-ish, used by public SSL functions) /// Write TLS records from rustls to socket -fn ssl_write_tls_records(conn: &mut TlsConnection) -> SslResult> { +fn ssl_write_tls_records(conn: &mut Connection) -> SslResult> { let mut buf = Vec::new(); let n = conn .write_tls(&mut buf as &mut dyn std::io::Write) @@ -2005,7 +1760,7 @@ fn ssl_write_tls_records(conn: &mut TlsConnection) -> SslResult> { /// Read TLS records from socket to rustls fn ssl_read_tls_records( - conn: &mut TlsConnection, + conn: &mut Connection, data: PyObjectRef, is_bio: bool, vm: &VirtualMachine, @@ -2068,6 +1823,9 @@ fn ssl_read_tls_records( } Ok(n) => { offset += n; + if offset < bytes_data.len() { + 
conn.process_new_packets().map_err(SslError::from_rustls)?; + } } Err(e) => { return Err(SslError::Io(e)); @@ -2075,14 +1833,12 @@ fn ssl_read_tls_records( } } else { offset += read_bytes; + if offset < bytes_data.len() { + conn.process_new_packets().map_err(SslError::from_rustls)?; + } } } Err(e) => { - // Check if it's a buffer full error (unlikely but handle it) - if e.to_string().contains("buffer full") { - conn.process_new_packets().map_err(SslError::from_rustls)?; - continue; - } // Real error - propagate it return Err(SslError::Io(e)); } @@ -2118,7 +1874,7 @@ fn is_connection_closed_error(exc: &Py, vm: &VirtualMachine) -> /// Ensure TLS data is available for reading /// Returns the number of bytes read from the socket fn ssl_ensure_data_available( - conn: &mut TlsConnection, + conn: &mut Connection, socket: &PySSLSocket, vm: &VirtualMachine, ) -> SslResult { @@ -2140,7 +1896,7 @@ fn ssl_ensure_data_available( { // Socket has timeout - use select to enforce it let timed_out = socket - .sock_wait_for_io_impl(SelectKind::Read, vm) + .sock_wait_for_io_impl(SockWaitKind::Read, vm) .map_err(SslError::Py)?; if timed_out { // Socket not ready within timeout - raise socket.timeout @@ -2157,9 +1913,9 @@ fn ssl_ensure_data_available( // consuming a close_notify that arrives alongside application data, // keeping it in the kernel buffer where select() can detect it. let data = if !is_bio { - recv_one_tls_record_for_data(conn, socket, vm)? + recv_at_most_one_tls_record_for_data(conn, socket, vm)? 
} else { - match socket.sock_recv(2048, vm) { + match socket.sock_recv(SSL3_RT_MAX_PACKET_SIZE, vm) { Ok(data) => data, Err(e) => { if is_blocking_io_error(&e, vm) { diff --git a/crates/stdlib/src/ssl/error.rs b/crates/stdlib/src/ssl/error.rs index d12cd834d1..07ff448869 100644 --- a/crates/stdlib/src/ssl/error.rs +++ b/crates/stdlib/src/ssl/error.rs @@ -125,7 +125,6 @@ pub(crate) mod ssl_error { ) } - #[allow(dead_code, reason = "This seems like a false positive")] pub(crate) fn create_ssl_zero_return_error(vm: &VirtualMachine) -> PyRef { vm.new_os_subtype_error( PySSLZeroReturnError::class(&vm.ctx).to_owned(), @@ -134,7 +133,6 @@ pub(crate) mod ssl_error { ) } - #[allow(dead_code, reason = "This seems like a false positive")] pub(crate) fn create_ssl_syscall_error( vm: &VirtualMachine, msg: impl Into, diff --git a/extra_tests/snippets/stdlib_ssl_short_recv.py b/extra_tests/snippets/stdlib_ssl_short_recv.py new file mode 100644 index 0000000000..4ec36e5b7e --- /dev/null +++ b/extra_tests/snippets/stdlib_ssl_short_recv.py @@ -0,0 +1,88 @@ +import os +import socket +import ssl +import sys +import threading + +if sys.implementation.name.lower() != "rustpython": + print("Ignored: stdlib_ssl_short_recv (RustPython only)") + raise SystemExit + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +CERTFILE = os.path.join(ROOT_DIR, "Lib/test/certdata/keycert.pem") +DATA = b"x" * 128 + +orig_recv = socket.socket.recv +client_sockname = None +recv_n = {} + + +def new_recv(sock, bufsize, flags=0): + sockname = sock.getsockname() + if sockname not in recv_n: + recv_n[sockname] = 0 + + bufsize = 1 + + if flags & socket.MSG_PEEK == 0: + recv_n[sockname] += 1 + return orig_recv(sock, bufsize, flags) + + +socket.socket.recv = new_recv + +listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +listener.bind(("127.0.0.1", 0)) +listener.listen(1) +addr, port = 
listener.getsockname() +server_errors = [] + + +def server(): + try: + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_context.load_cert_chain(CERTFILE) + + sock, _ = listener.accept() + sock.settimeout(5.0) + + ssock = server_context.wrap_socket(sock, server_side=True) + try: + ssock.sendall(DATA) + finally: + ssock.close() + except BaseException as exc: + server_errors.append(exc) + finally: + listener.close() + + +thread = threading.Thread(target=server) +thread.start() + +raw = socket.create_connection((addr, port), timeout=5.0) +client_sockname = raw.getsockname() +raw.settimeout(5.0) + +client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +client_context.check_hostname = False +client_context.verify_mode = ssl.CERT_NONE + +client = client_context.wrap_socket(raw, server_hostname=None) +try: + chunks = [] + while sum(len(chunk) for chunk in chunks) < len(DATA): + chunk = client.recv(20000) + if not chunk: + break + chunks.append(chunk) +finally: + client.close() + +thread.join(10.0) +assert not thread.is_alive(), "server thread did not stop" +assert not server_errors, server_errors +assert b"".join(chunks) == DATA +assert len(recv_n) == 2 +assert all(n > 100 for n in recv_n.values()) diff --git a/extra_tests/snippets/stdlib_urllib_https_misaligned_recv.py b/extra_tests/snippets/stdlib_urllib_https_misaligned_recv.py new file mode 100644 index 0000000000..18ac9ab010 --- /dev/null +++ b/extra_tests/snippets/stdlib_urllib_https_misaligned_recv.py @@ -0,0 +1,246 @@ +import os +import socket +import ssl +import sys +import threading +import time +import urllib.request + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +CERTFILE = os.path.join(ROOT_DIR, "Lib/test/certdata/keycert.pem") +BODY = b"x" * 407_676 + +# TLS record body sizes observed from https://crates.io/api/v1/crates/tokio. 
+TLS_RECORD_BODY_SIZES = [ + 2855, + 281, + 53, + 218, + 1095, + 1395, + 1395, + 483, + 1395, + 1395, + 1395, + 1395, + 48, + 1360, + 1354, + 1395, + 1395, + 1395, + 1367, + 1395, + 1395, + 1395, + 1395, + 1326, + 1395, + 1395, + 1395, + 47, + 1395, + 1395, + 1395, + 1395, + 95, + 1395, + 1332, + 1287, + 1388, + 1395, + 1395, + 1374, + 1395, + 1380, + 794, + 791, + 1395, + 1381, + 1395, + 1395, + 1395, + 1333, + 1395, + 1395, + 1395, + 1395, + 1395, + 1395, + 965, + 16401, + 3914, + 2526, + 1041, + 8209, + 9233, + 16401, + 11650, + 10262, + 7486, + 3468, + 692, + 1041, + 16401, + 12242, + 9466, + 1041, + 8209, + 9233, + 8209, + 9233, + 16401, + 1041, + 8209, + 9233, + 6161, + 2065, + 9233, + 16401, + 16358, + 10806, + 1041, + 8209, + 16401, + 3914, + 16401, + 16401, + 3089, + 9233, + 4642, + 478, + 8209, + 3140, + 1752, + 9233, + 8209, + 8209, + 16401, + 16064, + 14676, + 13288, + 2065, + 16401, + 1041, + 8209, + 16401, + 1041, + 6374, + 1007, +] + +server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) +server_context.load_cert_chain(CERTFILE) +listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +listener.bind(("127.0.0.1", 0)) +listener.listen(1) +addr, port = listener.getsockname() +server_errors = [] +finished = False + + +def guard_timeout(): + time.sleep(20) + if not finished: + print( + "stdlib_urllib_https_misaligned_recv.py timed out", + file=sys.stderr, + flush=True, + ) + os.abort() + + +threading.Thread(target=guard_timeout, daemon=True).start() + + +def drain_outgoing(outgoing, conn): + while True: + try: + data = outgoing.read() + except ssl.SSLWantReadError: + return + if not data: + return + conn.sendall(data) + + +def run_server(): + try: + conn, _ = listener.accept() + conn.settimeout(5.0) + conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + tls = server_context.wrap_bio(incoming, outgoing, 
server_side=True) + + while True: + try: + tls.do_handshake() + break + except ssl.SSLWantReadError: + drain_outgoing(outgoing, conn) + incoming.write(conn.recv(65536)) + except ssl.SSLWantWriteError: + pass + drain_outgoing(outgoing, conn) + + request = b"" + while b"\r\n\r\n" not in request: + try: + request += tls.read(65536) + except ssl.SSLWantReadError: + drain_outgoing(outgoing, conn) + incoming.write(conn.recv(65536)) + drain_outgoing(outgoing, conn) + + response = ( + b"HTTP/1.1 200 OK\r\n" + b"Connection: close\r\n" + + b"Content-Length: " + + str(len(BODY)).encode() + + b"\r\n" + + b"Content-Type: application/json\r\n" + + b"\r\n" + + BODY + ) + plaintext_sizes = [max(1, n - 17) for n in TLS_RECORD_BODY_SIZES] + pos = 0 + while pos < len(response): + size = plaintext_sizes.pop(0) if plaintext_sizes else 16384 + end = min(len(response), pos + size) + while pos < end: + try: + pos += tls.write(response[pos:end]) + except ssl.SSLWantWriteError: + pass + drain_outgoing(outgoing, conn) + conn.close() + except BaseException as exc: + server_errors.append(exc) + finally: + listener.close() + + +thread = threading.Thread(target=run_server) +thread.start() + +client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +client_context.check_hostname = False +client_context.verify_mode = ssl.CERT_NONE +opener = urllib.request.build_opener( + urllib.request.ProxyHandler({}), + urllib.request.HTTPSHandler(context=client_context), +) +try: + with opener.open(f"https://{addr}:{port}/", timeout=5.0) as response: + body = response.read() + + thread.join(10.0) + assert not thread.is_alive(), "server thread did not stop" + assert not server_errors, server_errors + assert body == BODY +finally: + finished = True From cf3b6397b2daff46532dcf2be0210c5a0ae376ce Mon Sep 17 00:00:00 2001 From: Ivan Mironov Date: Fri, 22 May 2026 11:10:24 +0000 Subject: [PATCH 04/18] Fix panic in select.select() when too many FDs specified (#7948) Also: * Add regression test into existing 
`extra_tests/snippets/stdlib_select.py` * Stop calculating nfds on Windows as it is ignored there Panic: thread 'main' (189598) panicked at /root/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/libc-0.2.186/src/unix/linux_like/mod.rs:1777:9: index out of bounds: the len is 16 but the index is 16 note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace thread 'main' (189598) panicked at library/core/src/panicking.rs:225:5: panic in a function that cannot unwind stack backtrace: 0: 0xaaab763a0b88 - <::print::DisplayBacktrace as core[f1abae5f1257fe69]::fmt::Display>::fmt 1: 0xaaab75590ff0 - core[f1abae5f1257fe69]::fmt::write 2: 0xaaab763a94fc - ::write_fmt 3: 0xaaab7638b714 - std[1934960bf7f41d0a]::panicking::default_hook::{closure#0} 4: 0xaaab7639b288 - std[1934960bf7f41d0a]::panicking::default_hook 5: 0xaaab7639b478 - std[1934960bf7f41d0a]::panicking::panic_with_hook 6: 0xaaab7638b7ec - std[1934960bf7f41d0a]::panicking::panic_handler::{closure#0} 7: 0xaaab76382654 - std[1934960bf7f41d0a]::sys::backtrace::__rust_end_short_backtrace:: 8: 0xaaab7638c504 - __rustc[b7425922bef61dcf]::rust_begin_unwind 9: 0xaaab754f778c - core[f1abae5f1257fe69]::panicking::panic_nounwind_fmt 10: 0xaaab754f7714 - core[f1abae5f1257fe69]::panicking::panic_nounwind 11: 0xaaab754f786c - core[f1abae5f1257fe69]::panicking::panic_cannot_unwind 12: 0xaaab75a6283c - rustpython_vm::function::builtin::,rustpython_vm::function::builtin::OwnedParam,rustpython_vm::function::builtin::OwnedParam,rustpython_vm::function::builtin::OwnedParam),R,rustpython_vm::vm::VirtualMachine> for F>::call_::h2471c8e242c9b51d 13: 0xaaab75db1e68 - rustpython_vm::types::slot::Callable::slot_call::hd1c1ad0ad14f306b 14: 0xaaab762c0a50 - rustpython_vm::protocol::callable::PyCallable::invoke::h9f6d571fca351ca6 15: 0xaaab75c550e8 - rustpython_vm::protocol::callable::::call_with_args::hed1f4a61aba2dced 16: 0xaaab762e7c24 - rustpython_vm::frame::ExecutingFrame::execute_call::h0ad3490dd74ed1e3 17: 
0xaaab762fed40 - rustpython_vm::frame::ExecutingFrame::run::hcf90f0950fc26812 18: 0xaaab761e6768 - rustpython_vm::vm::VirtualMachine::with_frame::hd49ba6fcdf2422e2 19: 0xaaab75c45398 - rustpython_vm::builtins::function::>::invoke_with_locals::h42de3d2316941ce2 20: 0xaaab76132a80 - rustpython_vm::builtins::function::vectorcall_function::h7331cb67b334e867 21: 0xaaab763369d8 - rustpython_vm::protocol::callable::::vectorcall::h9019c5d16685c89a 22: 0xaaab762f4b54 - rustpython_vm::frame::ExecutingFrame::execute_call_vectorcall::h120134e11a58c946 23: 0xaaab76302a7c - rustpython_vm::frame::ExecutingFrame::run::hcf90f0950fc26812 24: 0xaaab761e6768 - rustpython_vm::vm::VirtualMachine::with_frame::hd49ba6fcdf2422e2 25: 0xaaab761e7f24 - rustpython_vm::vm::VirtualMachine::run_code_obj::h354618be6e5cc553 26: 0xaaab761e2d18 - rustpython_vm::vm::python_run::file_run::::run_any_file::h783d3127fbc0b523 27: 0xaaab757d700c - rustpython::run_rustpython::h354efb8d817cefbf 28: 0xaaab757c79e0 - std::thread::local::LocalKey::with::hc9728e249843a926 29: 0xaaab757db860 - rustpython_vm::vm::interpreter::Interpreter::run::h42ac1fe9ed2287a2 30: 0xaaab757d7b30 - rustpython::run::hf14a209db5b4289c 31: 0xaaab757e2eb4 - rustpython::main::h1b59d8e13276ac48 32: 0xaaab757e2eec - std::sys::backtrace::__rust_begin_short_backtrace::h47e4b1f073f2155c 33: 0xaaab757e2ed4 - std::rt::lang_start::{{closure}}::h663a6c3dc7d80101 34: 0xaaab76399fd4 - std[1934960bf7f41d0a]::rt::lang_start_internal 35: 0xaaab757e2f44 - main 36: 0xfffed057655c - __libc_start_call_main 37: 0xfffed057663c - __libc_start_main@@GLIBC_2.34 38: 0xaaab755526f0 - _start 39: 0x0 - thread caused non-unwinding panic. aborting. 
Aborted (core dumped) cargo run --release -- extra_tests/snippets/stdlib_select.py --- crates/stdlib/src/select.rs | 30 +++++++++++++++++++++------ extra_tests/snippets/stdlib_select.py | 25 ++++++++++++++++++++++ 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/crates/stdlib/src/select.rs b/crates/stdlib/src/select.rs index 3ec4e62259..12e55db57f 100644 --- a/crates/stdlib/src/select.rs +++ b/crates/stdlib/src/select.rs @@ -5,7 +5,7 @@ pub(crate) use decl::module_def; use crate::vm::{ PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::PyListRef, }; -use rustpython_host_env::select::{self as host_select, FdSet, RawFd}; +use rustpython_host_env::select::{self as host_select, FdSet, RawFd, platform::FD_SETSIZE}; use std::io; #[derive(Traverse)] @@ -81,8 +81,22 @@ mod decl { let seq2set = |list: &PyObject| -> PyResult<(Vec, FdSet)> { let v: Vec = list.try_to_value(vm)?; + + let too_many_fds = cfg_select! { + windows => v.len() > FD_SETSIZE as usize, + _ => v.len() > FD_SETSIZE, + }; + if too_many_fds { + return Err(vm.new_value_error("too many file descriptors in select()")); + } + let mut fds = FdSet::new(); for fd in &v { + #[cfg(unix)] + if fd.fno as usize >= FD_SETSIZE { + return Err(vm.new_value_error("file descriptor out of range in select()")); + } + fds.insert(fd.fno); } Ok((v, fds)) @@ -97,11 +111,15 @@ mod decl { return Ok((empty.clone(), empty.clone(), empty)); } - let nfds: i32 = [&mut r, &mut w, &mut x] - .iter_mut() - .filter_map(|set| set.highest()) - .max() - .map_or(0, |n| n + 1) as _; + let nfds = cfg_select! 
{ + windows => 0, // value is ignored on windows + + _ => [&mut r, &mut w, &mut x] + .iter_mut() + .filter_map(|set| set.highest()) + .max() + .map_or(0, |n| n + 1) as _, + }; loop { let mut tv = timeout.map(host_select::sec_to_timeval); diff --git a/extra_tests/snippets/stdlib_select.py b/extra_tests/snippets/stdlib_select.py index d27bb82b1c..9afd95beca 100644 --- a/extra_tests/snippets/stdlib_select.py +++ b/extra_tests/snippets/stdlib_select.py @@ -4,6 +4,8 @@ from testutils import assert_raises +TOO_MANY_SELECT_FDS = 4096 + class Nope: pass @@ -42,3 +44,26 @@ def fileno(self): assert recvr in rres assert sendr in wres + +# Too many descriptors for select.select() +if sys.platform != "win32": + import resource + + soft_max_fds, hard_max_fds = resource.getrlimit(resource.RLIMIT_NOFILE) + if soft_max_fds != resource.RLIM_INFINITY: + # 100 additional fds should be enough for interpreter needs + need_fds = TOO_MANY_SELECT_FDS + 100 + + soft_max_fds = max(soft_max_fds, need_fds) + if hard_max_fds != resource.RLIM_INFINITY: + assert hard_max_fds >= soft_max_fds, ( + "Not enough file descriptors for this test" + ) + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_max_fds, hard_max_fds)) +sockets = [s for _ in range(TOO_MANY_SELECT_FDS // 2) for s in socket.socketpair()] +assert_raises(ValueError, select.select, sockets, [], [], 0) +del sockets +a, b = socket.socketpair() +# CPython disallows this on *nix systems too. 
+assert_raises(ValueError, select.select, [a] * TOO_MANY_SELECT_FDS, [], [], 0) +del a, b From c513b923dff332d3694fc10483106a8410bba15b Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Fri, 22 May 2026 14:10:49 +0300 Subject: [PATCH 05/18] Update `test_socketserver.py` to 3.14.5 (#7949) --- Lib/test/test_socketserver.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py index 6235c8e74c..2ca356606b 100644 --- a/Lib/test/test_socketserver.py +++ b/Lib/test/test_socketserver.py @@ -43,7 +43,7 @@ def receive(sock, n, timeout=test.support.SHORT_TIMEOUT): raise RuntimeError("timed out on %r" % (sock,)) -@test.support.requires_fork() # TODO: RUSTPYTHON, os.fork is currently only supported on Unix-based systems +@test.support.requires_fork() @contextlib.contextmanager def simple_subprocess(testcase): """Tests that a custom child process is not waited on (Issue 1540386)""" @@ -218,12 +218,16 @@ def test_ForkingUDPServer(self): self.dgram_examine) @requires_unix_sockets + @unittest.skipIf(test.support.is_apple_mobile and test.support.on_github_actions, + "gh-140702: Test fails regularly on iOS simulator on GitHub Actions") def test_UnixDatagramServer(self): self.run_server(socketserver.UnixDatagramServer, socketserver.DatagramRequestHandler, self.dgram_examine) @requires_unix_sockets + @unittest.skipIf(test.support.is_apple_mobile and test.support.on_github_actions, + "gh-140702: Test fails regularly on iOS simulator on GitHub Actions") def test_ThreadingUnixDatagramServer(self): self.run_server(socketserver.ThreadingUnixDatagramServer, socketserver.DatagramRequestHandler, From 1a013930a7aef70d80061222c1e0d7811d4a9225 Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Fri, 22 May 2026 14:11:27 +0300 Subject: [PATCH 06/18] Update `test_timeout.py` to 3.14.5 (#7950) --- Lib/test/test_timeout.py | 77 
+++++++++++++++------------------------- 1 file changed, 29 insertions(+), 48 deletions(-) diff --git a/Lib/test/test_timeout.py b/Lib/test/test_timeout.py index 70a0175d77..967d4ff7e1 100644 --- a/Lib/test/test_timeout.py +++ b/Lib/test/test_timeout.py @@ -5,9 +5,6 @@ from test import support from test.support import socket_helper -# This requires the 'network' resource as given on the regrtest command line. -skip_expected = not support.is_resource_enabled('network') - import time import errno import socket @@ -29,10 +26,8 @@ class CreationTestCase(unittest.TestCase): """Test case for socket.gettimeout() and socket.settimeout()""" def setUp(self): - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - - def tearDown(self): - self.sock.close() + self.sock = self.enterContext( + socket.socket(socket.AF_INET, socket.SOCK_STREAM)) def testObjectCreation(self): # Test Socket creation @@ -53,10 +48,10 @@ def testFloatReturnValue(self): def testReturnType(self): # Test return type of gettimeout() self.sock.settimeout(1) - self.assertEqual(type(self.sock.gettimeout()), type(1.0)) + self.assertIs(type(self.sock.gettimeout()), float) self.sock.settimeout(3.9) - self.assertEqual(type(self.sock.gettimeout()), type(1.0)) + self.assertIs(type(self.sock.gettimeout()), float) def testTypeCheck(self): # Test type checking by settimeout() @@ -116,8 +111,6 @@ class TimeoutTestCase(unittest.TestCase): def setUp(self): raise NotImplementedError() - tearDown = setUp - def _sock_operation(self, count, timeout, method, *args): """ Test the specified socket method. 
@@ -145,19 +138,16 @@ class TCPTimeoutTestCase(TimeoutTestCase): """TCP test case for socket.socket() timeout functions""" def setUp(self): - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock = self.enterContext( + socket.socket(socket.AF_INET, socket.SOCK_STREAM)) self.addr_remote = resolve_address('www.python.org.', 80) - def tearDown(self): - self.sock.close() - - @unittest.skipIf(True, 'need to replace these hosts; see bpo-35518') def testConnectTimeout(self): # Testing connect timeout is tricky: we need to have IP connectivity # to a host that silently drops our packets. We can't simulate this # from Python because it's a function of the underlying TCP/IP stack. - # So, the following Snakebite host has been defined: - blackhole = resolve_address('blackhole.snakebite.net', 56666) + # So, the following port on the pythontest.net host has been defined: + blackhole = resolve_address('pythontest.net', 56666) # Blackhole has been configured to silently drop any incoming packets. # No RSTs (for TCP) or ICMP UNREACH (for UDP/ICMP) will be sent back @@ -169,7 +159,7 @@ def testConnectTimeout(self): # to firewalling or general network configuration. In order to improve # our confidence in testing the blackhole, a corresponding 'whitehole' # has also been set up using one port higher: - whitehole = resolve_address('whitehole.snakebite.net', 56667) + whitehole = resolve_address('pythontest.net', 56667) # This address has been configured to immediately drop any incoming # packets as well, but it does it respectfully with regards to the @@ -183,35 +173,27 @@ def testConnectTimeout(self): # timeframe). 
# For the records, the whitehole/blackhole configuration has been set - # up using the 'pf' firewall (available on BSDs), using the following: + # up using the 'iptables' firewall, using the following rules: # - # ext_if="bge0" - # - # blackhole_ip="35.8.247.6" - # whitehole_ip="35.8.247.6" - # blackhole_port="56666" - # whitehole_port="56667" - # - # block return in log quick on $ext_if proto { tcp udp } \ - # from any to $whitehole_ip port $whitehole_port - # block drop in log quick on $ext_if proto { tcp udp } \ - # from any to $blackhole_ip port $blackhole_port + # -A INPUT -p tcp --destination-port 56666 -j DROP + # -A INPUT -p udp --destination-port 56666 -j DROP + # -A INPUT -p tcp --destination-port 56667 -j REJECT + # -A INPUT -p udp --destination-port 56667 -j REJECT # + # See https://github.com/python/psf-salt/blob/main/pillar/base/firewall/snakebite.sls + # for the current configuration. skip = True - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - timeout = support.LOOPBACK_TIMEOUT - sock.settimeout(timeout) - try: - sock.connect((whitehole)) - except TimeoutError: - pass - except OSError as err: - if err.errno == errno.ECONNREFUSED: - skip = False - finally: - sock.close() - del sock + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + try: + timeout = support.LOOPBACK_TIMEOUT + sock.settimeout(timeout) + sock.connect((whitehole)) + except TimeoutError: + pass + except OSError as err: + if err.errno == errno.ECONNREFUSED: + skip = False if skip: self.skipTest( @@ -278,10 +260,8 @@ class UDPTimeoutTestCase(TimeoutTestCase): """UDP test case for socket.socket() timeout functions""" def setUp(self): - self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - - def tearDown(self): - self.sock.close() + self.sock = self.enterContext( + socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) def testRecvfromTimeout(self): # Test recvfrom() timeout @@ -292,6 +272,7 @@ def testRecvfromTimeout(self): def setUpModule(): 
support.requires('network') + support.requires_working_socket(module=True) if __name__ == "__main__": From f3b83efceefd11e53aeea29f047dd2a4af25757f Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Fri, 22 May 2026 14:11:57 +0300 Subject: [PATCH 07/18] Update `test_structseq.py` to 3.14.5 (#7951) --- Lib/test/test_structseq.py | 244 +++++++++++++++++++++++++- crates/derive-impl/src/pystructseq.rs | 110 +++++------- 2 files changed, 281 insertions(+), 73 deletions(-) diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py index a9fe193028..8ef6dd2fee 100644 --- a/Lib/test/test_structseq.py +++ b/Lib/test/test_structseq.py @@ -1,6 +1,12 @@ +import copy +import gc import os +import pickle +import re +import textwrap import time import unittest +from test.support import script_helper class StructSeqTest(unittest.TestCase): @@ -37,7 +43,7 @@ def test_repr(self): # os.stat() gives a complicated struct sequence. st = os.stat(__file__) rep = repr(st) - self.assertTrue(rep.startswith("os.stat_result")) + self.assertStartsWith(rep, "os.stat_result") self.assertIn("st_mode=", rep) self.assertIn("st_ino=", rep) self.assertIn("st_dev=", rep) @@ -81,6 +87,7 @@ def test_fields(self): self.assertEqual(t.n_unnamed_fields, 0) self.assertEqual(t.n_fields, time._STRUCT_TM_ITEMS) + @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument dict def test_constructor(self): t = time.struct_time @@ -89,10 +96,72 @@ def test_constructor(self): self.assertRaises(TypeError, t, "123") self.assertRaises(TypeError, t, "123", dict={}) self.assertRaises(TypeError, t, "123456789", dict=None) + self.assertRaises(TypeError, t, seq="123456789", dict={}) + + self.assertEqual(t("123456789"), tuple("123456789")) + self.assertEqual(t("123456789", {}), tuple("123456789")) + self.assertEqual(t("123456789", dict={}), tuple("123456789")) + self.assertEqual(t(sequence="123456789", dict={}), tuple("123456789")) + + 
self.assertEqual(t("1234567890"), tuple("123456789")) + self.assertEqual(t("1234567890").tm_zone, "0") + self.assertEqual(t("123456789", {"tm_zone": "some zone"}), tuple("123456789")) + self.assertEqual(t("123456789", {"tm_zone": "some zone"}).tm_zone, "some zone") s = "123456789" self.assertEqual("".join(t(s)), s) + @unittest.expectedFailure # TODO: RUSTPYTHON; Wrong error message + def test_constructor_with_duplicate_fields(self): + t = time.struct_time + + error_message = re.escape("got duplicate or unexpected field name(s)") + with self.assertRaisesRegex(TypeError, error_message): + t("1234567890", dict={"tm_zone": "some zone"}) + with self.assertRaisesRegex(TypeError, error_message): + t("1234567890", dict={"tm_zone": "some zone", "tm_mon": 1}) + with self.assertRaisesRegex(TypeError, error_message): + t("1234567890", dict={"error": 0, "tm_zone": "some zone"}) + with self.assertRaisesRegex(TypeError, error_message): + t("1234567890", dict={"error": 0, "tm_zone": "some zone", "tm_mon": 1}) + + @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: expected at most 1 arguments, got 2 + def test_constructor_with_duplicate_unnamed_fields(self): + assert os.stat_result.n_unnamed_fields > 0 + n_visible_fields = os.stat_result.n_sequence_fields + + r = os.stat_result(range(n_visible_fields), {'st_atime': -1.0}) + self.assertEqual(r.st_atime, -1.0) + self.assertEqual(r, tuple(range(n_visible_fields))) + + r = os.stat_result((*range(n_visible_fields), -1.0)) + self.assertEqual(r.st_atime, -1.0) + self.assertEqual(r, tuple(range(n_visible_fields))) + + with self.assertRaisesRegex(TypeError, + re.escape("got duplicate or unexpected field name(s)")): + os.stat_result((*range(n_visible_fields), -1.0), {'st_atime': -1.0}) + + @unittest.expectedFailure # TODO: RUSTPYTHON; Wrong error message + def test_constructor_with_unknown_fields(self): + t = time.struct_time + + error_message = re.escape("got duplicate or unexpected field name(s)") + with 
self.assertRaisesRegex(TypeError, error_message): + t("123456789", dict={"tm_year": 0}) + with self.assertRaisesRegex(TypeError, error_message): + t("123456789", dict={"tm_year": 0, "tm_mon": 1}) + with self.assertRaisesRegex(TypeError, error_message): + t("123456789", dict={"tm_zone": "some zone", "tm_mon": 1}) + with self.assertRaisesRegex(TypeError, error_message): + t("123456789", dict={"tm_zone": "some zone", "error": 0}) + with self.assertRaisesRegex(TypeError, error_message): + t("123456789", dict={"error": 0, "tm_zone": "some zone", "tm_mon": 1}) + with self.assertRaisesRegex(TypeError, error_message): + t("123456789", dict={"error": 0}) + with self.assertRaisesRegex(TypeError, error_message): + t("123456789", dict={"tm_zone": "some zone", "error": 0}) + def test_eviltuple(self): class Exc(Exception): pass @@ -106,9 +175,80 @@ def __len__(self): self.assertRaises(Exc, time.struct_time, C()) - def test_reduce(self): + def test_pickling(self): t = time.gmtime() - x = t.__reduce__() + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + p = pickle.dumps(t, proto) + t2 = pickle.loads(p) + self.assertEqual(t2.__class__, t.__class__) + self.assertEqual(t2, t) + self.assertEqual(t2.tm_year, t.tm_year) + self.assertEqual(t2.tm_zone, t.tm_zone) + + @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: expected at most 1 arguments, got 2 + def test_pickling_with_unnamed_fields(self): + assert os.stat_result.n_unnamed_fields > 0 + + r = os.stat_result(range(os.stat_result.n_sequence_fields), + {'st_atime': 1.0, 'st_atime_ns': 2.0}) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + p = pickle.dumps(r, proto) + r2 = pickle.loads(p) + self.assertEqual(r2.__class__, r.__class__) + self.assertEqual(r2, r) + self.assertEqual(r2.st_mode, r.st_mode) + self.assertEqual(r2.st_atime, r.st_atime) + self.assertEqual(r2.st_atime_ns, r.st_atime_ns) + + def test_copying(self): + n_fields = time.struct_time.n_fields + t = time.struct_time([[i] for i in range(n_fields)]) + + t2 = 
copy.copy(t) + self.assertEqual(t2.__class__, t.__class__) + self.assertEqual(t2, t) + self.assertEqual(t2.tm_year, t.tm_year) + self.assertEqual(t2.tm_zone, t.tm_zone) + self.assertIs(t2[0], t[0]) + self.assertIs(t2.tm_year, t.tm_year) + + t3 = copy.deepcopy(t) + self.assertEqual(t3.__class__, t.__class__) + self.assertEqual(t3, t) + self.assertEqual(t3.tm_year, t.tm_year) + self.assertEqual(t3.tm_zone, t.tm_zone) + self.assertIsNot(t3[0], t[0]) + self.assertIsNot(t3.tm_year, t.tm_year) + + @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: expected at most 1 arguments, got 2 + def test_copying_with_unnamed_fields(self): + assert os.stat_result.n_unnamed_fields > 0 + + n_sequence_fields = os.stat_result.n_sequence_fields + r = os.stat_result([[i] for i in range(n_sequence_fields)], + {'st_atime': [1.0], 'st_atime_ns': [2.0]}) + + r2 = copy.copy(r) + self.assertEqual(r2.__class__, r.__class__) + self.assertEqual(r2, r) + self.assertEqual(r2.st_mode, r.st_mode) + self.assertEqual(r2.st_atime, r.st_atime) + self.assertEqual(r2.st_atime_ns, r.st_atime_ns) + self.assertIs(r2[0], r[0]) + self.assertIs(r2.st_mode, r.st_mode) + self.assertIs(r2.st_atime, r.st_atime) + self.assertIs(r2.st_atime_ns, r.st_atime_ns) + + r3 = copy.deepcopy(r) + self.assertEqual(r3.__class__, r.__class__) + self.assertEqual(r3, r) + self.assertEqual(r3.st_mode, r.st_mode) + self.assertEqual(r3.st_atime, r.st_atime) + self.assertEqual(r3.st_atime_ns, r.st_atime_ns) + self.assertIsNot(r3[0], r[0]) + self.assertIsNot(r3.st_mode, r.st_mode) + self.assertIsNot(r3.st_atime, r.st_atime) + self.assertIsNot(r3.st_atime_ns, r.st_atime_ns) def test_extended_getslice(self): # Test extended slicing by comparing with list slicing. 
@@ -133,6 +273,104 @@ def test_match_args_with_unnamed_fields(self): self.assertEqual(os.stat_result.n_unnamed_fields, 3) self.assertEqual(os.stat_result.__match_args__, expected_args) + def test_copy_replace_all_fields_visible(self): + assert os.times_result.n_unnamed_fields == 0 + assert os.times_result.n_sequence_fields == os.times_result.n_fields + + t = os.times() + + # visible fields + self.assertEqual(copy.replace(t), t) + self.assertIsInstance(copy.replace(t), os.times_result) + self.assertEqual(copy.replace(t, user=1.5), (1.5, *t[1:])) + self.assertEqual(copy.replace(t, system=2.5), (t[0], 2.5, *t[2:])) + self.assertEqual(copy.replace(t, user=1.5, system=2.5), (1.5, 2.5, *t[2:])) + + # unknown fields + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, error=-1) + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, user=1, error=-1) + + @unittest.expectedFailure # TODO: RUSTPYTHON; Wrong error message + def test_copy_replace_with_invisible_fields(self): + assert time.struct_time.n_unnamed_fields == 0 + assert time.struct_time.n_sequence_fields < time.struct_time.n_fields + + t = time.gmtime(0) + + # visible fields + t2 = copy.replace(t) + self.assertEqual(t2, (1970, 1, 1, 0, 0, 0, 3, 1, 0)) + self.assertIsInstance(t2, time.struct_time) + t3 = copy.replace(t, tm_year=2000) + self.assertEqual(t3, (2000, 1, 1, 0, 0, 0, 3, 1, 0)) + self.assertEqual(t3.tm_year, 2000) + t4 = copy.replace(t, tm_mon=2) + self.assertEqual(t4, (1970, 2, 1, 0, 0, 0, 3, 1, 0)) + self.assertEqual(t4.tm_mon, 2) + t5 = copy.replace(t, tm_year=2000, tm_mon=2) + self.assertEqual(t5, (2000, 2, 1, 0, 0, 0, 3, 1, 0)) + self.assertEqual(t5.tm_year, 2000) + self.assertEqual(t5.tm_mon, 2) + + # named invisible fields + self.assertHasAttr(t, 'tm_zone') + with self.assertRaisesRegex(AttributeError, 'readonly attribute'): + t.tm_zone = 'some other zone' + self.assertEqual(t2.tm_zone, t.tm_zone) + self.assertEqual(t3.tm_zone, t.tm_zone) 
+ self.assertEqual(t4.tm_zone, t.tm_zone) + t6 = copy.replace(t, tm_zone='some other zone') + self.assertEqual(t, t6) + self.assertEqual(t6.tm_zone, 'some other zone') + t7 = copy.replace(t, tm_year=2000, tm_zone='some other zone') + self.assertEqual(t7, (2000, 1, 1, 0, 0, 0, 3, 1, 0)) + self.assertEqual(t7.tm_year, 2000) + self.assertEqual(t7.tm_zone, 'some other zone') + + # unknown fields + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, error=2) + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, tm_year=2000, error=2) + with self.assertRaisesRegex(TypeError, 'unexpected field name'): + copy.replace(t, tm_zone='some other zone', error=2) + + def test_copy_replace_with_unnamed_fields(self): + assert os.stat_result.n_unnamed_fields > 0 + + r = os.stat_result(range(os.stat_result.n_sequence_fields)) + + error_message = re.escape('__replace__() is not supported') + with self.assertRaisesRegex(TypeError, error_message): + copy.replace(r) + with self.assertRaisesRegex(TypeError, error_message): + copy.replace(r, st_mode=1) + with self.assertRaisesRegex(TypeError, error_message): + copy.replace(r, error=2) + with self.assertRaisesRegex(TypeError, error_message): + copy.replace(r, st_mode=1, error=2) + + def test_reference_cycle(self): + # gh-122527: Check that a structseq that's part of a reference cycle + # with its own type doesn't crash. Previously, if the type's dictionary + # was cleared first, the structseq instance would crash in the + # destructor. 
+ script_helper.assert_python_ok("-c", textwrap.dedent(r""" + import time + t = time.gmtime() + type(t).refcyle = t + """)) + + def test_replace_gc_tracked(self): + # Verify that __replace__ results are properly GC-tracked + time_struct = time.gmtime(0) + lst = [] + replaced_struct = time_struct.__replace__(tm_year=lst) + lst.append(replaced_struct) + + self.assertTrue(gc.is_tracked(replaced_struct)) if __name__ == "__main__": unittest.main() diff --git a/crates/derive-impl/src/pystructseq.rs b/crates/derive-impl/src/pystructseq.rs index 4059aba63b..874f85741f 100644 --- a/crates/derive-impl/src/pystructseq.rs +++ b/crates/derive-impl/src/pystructseq.rs @@ -102,12 +102,12 @@ fn parse_fields(input: &mut DeriveInput) -> Result { bail_span!(input, "Only #[pystruct_sequence(...)] form is allowed"); }; - let idents: Vec<_> = l + let idents = l .nested .iter() .filter_map(|n| n.get_ident()) .cloned() - .collect(); + .collect::>(); for ident in idents { match ident.to_string().as_str() { @@ -205,7 +205,7 @@ pub(crate) fn impl_pystruct_sequence_data( let n_unnamed_fields = field_info.n_unnamed_fields(); // Generate field index constants for visible fields (with cfg guards) - let field_indices: Vec<_> = visible_fields + let field_indices = visible_fields .iter() .enumerate() .map(|(i, field)| { @@ -216,78 +216,58 @@ pub(crate) fn impl_pystruct_sequence_data( pub const #const_name: usize = #i; } }) - .collect(); + .collect::>(); // Generate field name entries with cfg guards for named fields - let named_field_names: Vec<_> = named_fields + let named_field_names = named_fields .iter() .map(|f| { let ident = &f.ident; let cfg_attrs = &f.cfg_attrs; - if cfg_attrs.is_empty() { - quote! { stringify!(#ident), } - } else { - quote! { - #(#cfg_attrs)* - { stringify!(#ident) }, - } + quote! 
{ + #(#cfg_attrs)* + { stringify!(#ident) }, } }) - .collect(); + .collect::>(); // Generate field name entries with cfg guards for skipped fields - let skipped_field_names: Vec<_> = skipped_fields + let skipped_field_names = skipped_fields .iter() .map(|f| { let ident = &f.ident; let cfg_attrs = &f.cfg_attrs; - if cfg_attrs.is_empty() { - quote! { stringify!(#ident), } - } else { - quote! { - #(#cfg_attrs)* - { stringify!(#ident) }, - } + quote! { + #(#cfg_attrs)* + { stringify!(#ident) }, } }) - .collect(); + .collect::>(); // Generate into_tuple items with cfg guards - let visible_tuple_items: Vec<_> = visible_fields + let visible_tuple_items = visible_fields .iter() .map(|f| { let ident = &f.ident; let cfg_attrs = &f.cfg_attrs; - if cfg_attrs.is_empty() { - quote! { - ::rustpython_vm::convert::ToPyObject::to_pyobject(self.#ident, vm), - } - } else { - quote! { - #(#cfg_attrs)* - { ::rustpython_vm::convert::ToPyObject::to_pyobject(self.#ident, vm) }, - } + quote! { + #(#cfg_attrs)* + { ::rustpython_vm::convert::ToPyObject::to_pyobject(self.#ident, vm) }, } }) - .collect(); + .collect::>(); - let skipped_tuple_items: Vec<_> = skipped_fields + let skipped_tuple_items = skipped_fields .iter() .map(|f| { let ident = &f.ident; let cfg_attrs = &f.cfg_attrs; - if cfg_attrs.is_empty() { - quote! { - ::rustpython_vm::convert::ToPyObject::to_pyobject(self.#ident, vm), - } - } else { - quote! { - #(#cfg_attrs)* - { ::rustpython_vm::convert::ToPyObject::to_pyobject(self.#ident, vm) }, - } + quote! 
{ + #(#cfg_attrs)* + { ::rustpython_vm::convert::ToPyObject::to_pyobject(self.#ident, vm) }, } }) - .collect(); + .collect::>(); // Generate TryFromObject impl only when try_from_object=true let try_from_object_impl = if try_from_object { @@ -317,44 +297,33 @@ pub(crate) fn impl_pystruct_sequence_data( // Generate try_from_elements trait override only when try_from_object=true let try_from_elements_trait_override = if try_from_object { - let visible_field_inits: Vec<_> = visible_fields + let visible_field_inits = visible_fields .iter() .map(|f| { let ident = &f.ident; let cfg_attrs = &f.cfg_attrs; - if cfg_attrs.is_empty() { - quote! { #ident: iter.next().unwrap().clone().try_into_value(vm)?, } - } else { - quote! { - #(#cfg_attrs)* - #ident: iter.next().unwrap().clone().try_into_value(vm)?, - } + quote! { + #(#cfg_attrs)* + #ident: iter.next().unwrap().clone().try_into_value(vm)?, } }) - .collect(); - let skipped_field_inits: Vec<_> = skipped_fields + .collect::>(); + + let skipped_field_inits = skipped_fields .iter() .map(|f| { let ident = &f.ident; let cfg_attrs = &f.cfg_attrs; - if cfg_attrs.is_empty() { - quote! { - #ident: match iter.next() { - Some(v) => v.clone().try_into_value(vm)?, - None => vm.ctx.none(), - }, - } - } else { - quote! { - #(#cfg_attrs)* - #ident: match iter.next() { - Some(v) => v.clone().try_into_value(vm)?, - None => vm.ctx.none(), - }, - } + quote! { + #(#cfg_attrs)* + #ident: match iter.next() { + Some(v) => v.clone().try_into_value(vm)?, + None => vm.ctx.none(), + }, } }) - .collect(); + .collect::>(); + quote! 
{ fn try_from_elements( elements: Vec<::rustpython_vm::PyObjectRef>, @@ -426,6 +395,7 @@ impl ItemMeta for PyStructSequenceMeta { fn from_inner(inner: ItemMetaInner) -> Self { Self { inner } } + fn inner(&self) -> &ItemMetaInner { &self.inner } From d3272e752ba5393c2969fbc3039227d2d8922ddb Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Sat, 23 May 2026 20:16:03 +0900 Subject: [PATCH 08/18] Align codegen metadata with CPython (#7952) --- .cspell.dict/cpython.txt | 5 + Lib/test/test_compile.py | 2 - Lib/test/test_dis.py | 6 - Lib/test/test_exceptions.py | 1 - Lib/test/test_inspect/test_inspect.py | 1 - Lib/test/test_py_compile.py | 2 - Lib/test/test_strtod.py | 4 - Lib/test/test_sys_settrace.py | 2 - Lib/test/test_unittest/test_async_case.py | 1 - crates/codegen/src/compile.rs | 5423 ++++++++++++++--- crates/codegen/src/ir.rs | 870 ++- crates/codegen/src/symboltable.rs | 63 +- crates/compiler-core/src/bytecode.rs | 5 +- crates/literal/src/float.rs | 152 +- ...code__tests__nested_double_async_with.snap | 16 +- crates/vm/src/builtins/function.rs | 88 +- crates/vm/src/frame.rs | 13 +- crates/vm/src/stdlib/builtins.rs | 41 +- scripts/dis_dump.py | 24 +- 19 files changed, 5522 insertions(+), 1197 deletions(-) diff --git a/.cspell.dict/cpython.txt b/.cspell.dict/cpython.txt index 688982cd7d..ffbed52121 100644 --- a/.cspell.dict/cpython.txt +++ b/.cspell.dict/cpython.txt @@ -5,6 +5,7 @@ argtypes asdl asname atopen +atext attro augassign badcert @@ -104,6 +105,7 @@ inlinedepth inplace inpos isbytecode +ishidden ismine ISPOINTER isoctal @@ -113,6 +115,7 @@ keeped kwnames kwonlyarg kwonlyargs +kwonlydefaults lasti libffi linearise @@ -164,6 +167,7 @@ patma peepholer phcount platstdlib +ploc posonlyarg posonlyargs prec @@ -209,6 +213,7 @@ staticbase stginfo storefast stringlib +stringized structseq subkwargs subparams diff --git a/Lib/test/test_compile.py b/Lib/test/test_compile.py index fd1743e670..052d2bfc04 100644 --- 
a/Lib/test/test_compile.py +++ b/Lib/test/test_compile.py @@ -1249,7 +1249,6 @@ def get_code_lines(self, code): last_line = line return res - @unittest.expectedFailure # TODO: RUSTPYTHON def test_lineno_attribute(self): def load_attr(): return ( @@ -1294,7 +1293,6 @@ def aug_store_attr(): code_lines = self.get_code_lines(func.__code__) self.assertEqual(lines, code_lines) - @unittest.expectedFailure # TODO: RUSTPYTHON; + [0] def test_line_number_genexp(self): def return_genexp(): diff --git a/Lib/test/test_dis.py b/Lib/test/test_dis.py index fcd6a6b8be..cedad5a0fb 100644 --- a/Lib/test/test_dis.py +++ b/Lib/test/test_dis.py @@ -1215,7 +1215,6 @@ def test_disassemble_fstring(self): def test_disassemble_with(self): self.do_disassembly_test(_with, dis_with) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_disassemble_asyncwith(self): self.do_disassembly_test(_asyncwith, dis_asyncwith) @@ -1991,26 +1990,22 @@ def test_first_line_set_to_None(self): actual = dis.get_instructions(simple, first_line=None) self.assertInstructionsEqual(list(actual), expected_opinfo_simple) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_outer(self): actual = dis.get_instructions(outer, first_line=expected_outer_line) self.assertInstructionsEqual(list(actual), expected_opinfo_outer) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_nested(self): with captured_stdout(): f = outer() actual = dis.get_instructions(f, first_line=expected_f_line) self.assertInstructionsEqual(list(actual), expected_opinfo_f) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_doubly_nested(self): with captured_stdout(): inner = outer()() actual = dis.get_instructions(inner, first_line=expected_inner_line) self.assertInstructionsEqual(list(actual), expected_opinfo_inner) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_jumpy(self): actual = dis.get_instructions(jumpy, first_line=expected_jumpy_line) self.assertInstructionsEqual(list(actual), expected_opinfo_jumpy) @@ -2314,7 +2309,6 @@ 
def test_iteration(self): via_generator = list(dis.get_instructions(obj)) self.assertInstructionsEqual(via_object, via_generator) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_explicit_first_line(self): actual = dis.Bytecode(outer, first_line=expected_outer_line) self.assertInstructionsEqual(list(actual), expected_opinfo_outer) diff --git a/Lib/test/test_exceptions.py b/Lib/test/test_exceptions.py index 10010ffa9b..7e79732a3b 100644 --- a/Lib/test/test_exceptions.py +++ b/Lib/test/test_exceptions.py @@ -2245,7 +2245,6 @@ def test_assertion_error_location(self): result = run_script(source) self.assertEqual(result[-3:], expected) - @unittest.expectedFailure # TODO: RUSTPYTHON @force_not_colorized def test_multiline_not_highlighted(self): cases = [ diff --git a/Lib/test/test_inspect/test_inspect.py b/Lib/test/test_inspect/test_inspect.py index 512adba281..f7a7c0cc82 100644 --- a/Lib/test/test_inspect/test_inspect.py +++ b/Lib/test/test_inspect/test_inspect.py @@ -237,7 +237,6 @@ class FakePackage: self.assertFalse(inspect.ispackage(FakePackage())) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: False is not true def test_iscoroutine(self): async_gen_coro = async_generator_function_example(1) gen_coro = gen_coroutine_function_example(1) diff --git a/Lib/test/test_py_compile.py b/Lib/test/test_py_compile.py index f00f24204b..c4788f47a0 100644 --- a/Lib/test/test_py_compile.py +++ b/Lib/test/test_py_compile.py @@ -132,7 +132,6 @@ def test_exceptions_propagate(self): finally: os.chmod(self.directory, mode.st_mode) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_bad_coding(self): bad_coding = os.path.join(os.path.dirname(__file__), 'tokenizedata', @@ -198,7 +197,6 @@ def test_invalidation_mode(self): fp.read(), 'test', {}) self.assertEqual(flags, 0b1) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_quiet(self): bad_coding = os.path.join(os.path.dirname(__file__), 'tokenizedata', diff --git a/Lib/test/test_strtod.py 
b/Lib/test/test_strtod.py index 03c8afa51e..f263b7ab4f 100644 --- a/Lib/test/test_strtod.py +++ b/Lib/test/test_strtod.py @@ -173,7 +173,6 @@ def test_halfway_cases(self): s = '{}e{}'.format(digits, exponent) self.check_strtod(s) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_boundaries(self): # boundaries expressed as triples (n, e, u), where # n*10**e is an approximation to the boundary value and @@ -194,7 +193,6 @@ def test_boundaries(self): u *= 10 e -= 1 - @unittest.expectedFailure # TODO: RUSTPYTHON def test_underflow_boundary(self): # test values close to 2**-1075, the underflow boundary; similar # to boundary_tests, except that the random error doesn't scale @@ -206,7 +204,6 @@ def test_underflow_boundary(self): s = '{}e{}'.format(digits, exponent) self.check_strtod(s) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_bigcomp(self): for ndigs in 5, 10, 14, 15, 16, 17, 18, 19, 20, 40, 41, 50: dig10 = 10**ndigs @@ -284,7 +281,6 @@ def negative_exp(n): self.assertEqual(float(negative_exp(20000)), 1.0) self.assertEqual(float(negative_exp(30000)), 1.0) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_particular(self): # inputs that produced crashes or incorrectly rounded results with # previous versions of dtoa.c, for various reasons diff --git a/Lib/test/test_sys_settrace.py b/Lib/test/test_sys_settrace.py index aa2d54ee16..7eef1290dc 100644 --- a/Lib/test/test_sys_settrace.py +++ b/Lib/test/test_sys_settrace.py @@ -1488,8 +1488,6 @@ def test_jump_in_nested_finally_3(output): output.append(11) output.append(12) - # TODO: RUSTPYTHON - @unittest.expectedFailure @jump_test(5, 11, [2, 4], (ValueError, 'after')) def test_no_jump_over_return_try_finally_in_finally_block(output): try: diff --git a/Lib/test/test_unittest/test_async_case.py b/Lib/test/test_unittest/test_async_case.py index 9b1678caf5..91d45283eb 100644 --- a/Lib/test/test_unittest/test_async_case.py +++ b/Lib/test/test_unittest/test_async_case.py @@ -296,7 +296,6 @@ async def 
on_cleanup2(self): test.doCleanups() self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup2', 'cleanup1']) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_deprecation_of_return_val_from_test(self): # Issue 41322 - deprecate return of value that is not None from a test class Nothing: diff --git a/crates/codegen/src/compile.rs b/crates/codegen/src/compile.rs index 6dc9fdd4bd..41dbdac112 100644 --- a/crates/codegen/src/compile.rs +++ b/crates/codegen/src/compile.rs @@ -246,6 +246,8 @@ enum ComprehensionLoopControl { loop_block: BlockIdx, if_cleanup_block: BlockIdx, after_block: BlockIdx, + iter_range: TextRange, + backedge_range: TextRange, is_async: bool, end_async_for_target: BlockIdx, }, @@ -637,6 +639,41 @@ impl Compiler { } } + fn mark_conditional_ifexp_orelse_entry_block(&mut self, block: BlockIdx) { + if block != BlockIdx::NULL { + self.current_code_info().blocks[block.idx()].conditional_ifexp_orelse_entry = true; + } + } + + fn instruction_count_snapshot(&mut self) -> Vec { + self.current_code_info() + .blocks + .iter() + .map(|block| block.instructions.len()) + .collect() + } + + fn mark_new_conditional_jump_locations_since( + &mut self, + snapshot: &[usize], + target: BlockIdx, + range: TextRange, + ) { + let source = self.source_file.to_source_code(); + let location = source.source_location(range.start(), PositionEncoding::Utf8); + let end_location = source.source_location(range.end(), PositionEncoding::Utf8); + for (idx, block) in self.current_code_info().blocks.iter_mut().enumerate() { + let start = snapshot.get(idx).copied().unwrap_or(0); + for instr in block.instructions.iter_mut().skip(start) { + if instr.target == target && ir::is_conditional_jump(&instr.instr) { + instr.location = location; + instr.end_location = end_location; + instr.preserve_tobool_jump_location = true; + } + } + } + } + fn new(opts: CompileOpts, source_file: SourceFile, code_name: &str) -> Self { let module_code = ir::CodeInfo { // CPython convention: 
top-level module / interactive / @@ -656,7 +693,7 @@ impl Compiler { metadata: ir::CodeUnitMetadata { name: code_name.to_string(), qualname: Some(code_name.to_string()), - consts: IndexSet::default(), + consts: Default::default(), names: IndexSet::default(), varnames: IndexSet::default(), cellvars: IndexSet::default(), @@ -896,6 +933,7 @@ impl Compiler { ast::Expr::ListComp(ast::ExprListComp { generators, .. }) | ast::Expr::SetComp(ast::ExprSetComp { generators, .. }) | ast::Expr::DictComp(ast::ExprDictComp { generators, .. }) + | ast::Expr::Generator(ast::ExprGenerator { generators, .. }) if generators.iter().any(|generator| generator.is_async) => { self.found = true; @@ -1629,6 +1667,7 @@ impl Compiler { fn compile_module_annotation_setup_sequence( &mut self, body: &[ast::Stmt], + loc: TextRange, ) -> CompileResult<()> { let (saved_blocks, saved_current_block) = { let code = self.current_code_info(); @@ -1638,7 +1677,7 @@ impl Compiler { ) }; - let result = self.compile_module_annotate(body); + let result = self.compile_module_annotate(body, Some(loc)); let annotations_blocks = { let code = self.current_code_info(); @@ -1658,6 +1697,7 @@ impl Compiler { if let Some(lower) = &s.lower { self.compile_expression(lower)?; } else { + self.set_source_range(s.range); self.emit_load_const(ConstantData::None); } @@ -1665,6 +1705,7 @@ impl Compiler { if let Some(upper) = &s.upper { self.compile_expression(upper)?; } else { + self.set_source_range(s.range); self.emit_load_const(ConstantData::None); } @@ -2171,8 +2212,14 @@ impl Compiler { /// Load arguments for super() optimization onto the stack /// Stack result: [global_super, class, self] - fn load_args_for_super(&mut self, super_type: &SuperCallType<'_>) -> CompileResult<()> { + fn load_args_for_super( + &mut self, + super_type: &SuperCallType<'_>, + super_name_range: TextRange, + super_call_range: TextRange, + ) -> CompileResult<()> { // 1. 
Load global super + self.set_source_range(super_name_range); self.compile_name("super", NameUsage::Load)?; match super_type { @@ -2187,6 +2234,7 @@ impl Compiler { SuperCallType::ZeroArg => { // 0-arg: load __class__ cell and first parameter // Load __class__ from cell/free variable + self.set_source_range(super_call_range); let scope = self.get_ref_type("__class__").map_err(|e| self.error(e))?; let idx = match scope { SymbolScope::Cell => self.get_cell_var_index("__class__"), @@ -2211,6 +2259,7 @@ impl Compiler { "super(): no arguments and no first parameter".to_owned(), )) })?; + self.set_source_range(super_call_range); self.compile_name(&first_param, NameUsage::Load)?; } } @@ -2349,7 +2398,7 @@ impl Compiler { } // Initialize u_metadata fields - let (flags, posonlyarg_count, arg_count, kwonlyarg_count) = match scope_type { + let (mut flags, posonlyarg_count, arg_count, kwonlyarg_count) = match scope_type { CompilerScope::Module => (bytecode::CodeFlags::empty(), 0, 0, 0), CompilerScope::Class => (bytecode::CodeFlags::empty(), 0, 0, 0), CompilerScope::Function | CompilerScope::AsyncFunction | CompilerScope::Lambda => ( @@ -2378,13 +2427,30 @@ impl Compiler { ), }; - // Set CO_NESTED for scopes defined inside another function/class/etc. - // (i.e., not at module level) - let flags = if self.code_stack.len() > 1 { + if ste.is_method { + flags |= bytecode::CodeFlags::METHOD; + } + + // CPython sets CO_NESTED from symtable's ste_nested, not merely + // from lexical depth: module-level class methods are CO_METHOD but + // not CO_NESTED. 
+ let mut flags = if ste.is_nested + && matches!( + scope_type, + CompilerScope::Function + | CompilerScope::AsyncFunction + | CompilerScope::Lambda + | CompilerScope::Comprehension + | CompilerScope::Annotation + | CompilerScope::TypeParams + ) { flags | bytecode::CodeFlags::NESTED } else { flags }; + if self.future_annotations { + flags |= bytecode::CodeFlags::FUTURE_ANNOTATIONS; + } // Get private name from parent scope let private = if !self.code_stack.is_empty() { @@ -2404,7 +2470,7 @@ impl Compiler { metadata: ir::CodeUnitMetadata { name: name.to_owned(), qualname: None, // Will be set below - consts: IndexSet::default(), + consts: Default::default(), names: IndexSet::default(), varnames: varname_cache, cellvars: cellvar_cache, @@ -2470,7 +2536,9 @@ impl Compiler { scope_type == CompilerScope::AsyncFunction || self.current_symbol_table().is_generator; if is_gen { emit!(self, Instruction::ReturnGenerator); + self.mark_last_line_only_location(lineno); emit!(self, Instruction::PopTop); + self.mark_last_line_only_location(lineno); } // CPython: LOCATION(lineno, lineno, 0, 0) @@ -2510,6 +2578,8 @@ impl Compiler { match_success_jump: false, break_continue_cleanup_jump: false, for_loop_break_cleanup_jump: false, + preserve_tobool_jump_location: false, + preserve_store_fast_store_fast_jump_location: false, }); } @@ -2529,6 +2599,7 @@ impl Compiler { let i_varnum: oparg::VarNum = u32::try_from(oldindex).expect("too many cellvars").into(); emit!(self, Instruction::MakeCell { i: i_varnum }); + self.set_no_location(); } } @@ -2540,6 +2611,7 @@ impl Compiler { n: u32::try_from(nfrees).expect("too many freevars"), } ); + self.set_no_location(); } } @@ -2568,8 +2640,12 @@ impl Compiler { // enter_scope sets default values based on scope_type, but push_output // allows callers to specify exact values if let Some(info) = self.code_stack.last_mut() { - // Preserve NESTED flag set by enter_scope - info.flags = flags | (info.flags & bytecode::CodeFlags::NESTED); + // Preserve 
flags computed from the symbol-table context. + info.flags = flags + | (info.flags + & (bytecode::CodeFlags::NESTED + | bytecode::CodeFlags::METHOD + | bytecode::CodeFlags::FUTURE_ANNOTATIONS)); info.metadata.argcount = arg_count; info.metadata.posonlyargcount = posonlyarg_count; info.metadata.kwonlyargcount = kwonlyarg_count; @@ -2643,6 +2719,7 @@ impl Compiler { fn enter_annotation_scope( &mut self, _func_name: &str, + loc: TextRange, ) -> CompileResult> { if !self.push_annotation_symbol_table() { return Ok(None); @@ -2657,6 +2734,7 @@ impl Compiler { in_async_scope: false, }; + self.set_source_range(loc); let key = self.symbol_table_stack.len() - 1; let lineno = self.get_source_line_number().get(); self.enter_scope( @@ -2769,9 +2847,26 @@ impl Compiler { code.fblock.pop().expect("fblock stack underflow") } + fn set_unwind_source_range(&mut self, loc: Option) { + if let Some(range) = loc { + self.set_source_range(range); + } + } + + fn mark_unwind_no_location(&mut self, loc: Option) { + if loc.is_none() { + self.set_no_location(); + } + } + /// Unwind a single fblock, emitting cleanup code /// preserve_tos: if true, preserve the top of stack (e.g., return value) - fn unwind_fblock(&mut self, info: &FBlockInfo, preserve_tos: bool) -> CompileResult<()> { + fn unwind_fblock( + &mut self, + info: &FBlockInfo, + preserve_tos: bool, + loc: &mut Option, + ) -> CompileResult<()> { match info.fb_type { FBlockType::WhileLoop | FBlockType::ExceptionHandler @@ -2785,13 +2880,19 @@ impl Compiler { // When returning from a for-loop, CPython swaps the preserved // value with the iterator and uses POP_TOP for loop cleanup. 
if preserve_tos { + self.set_unwind_source_range(*loc); emit!(self, Instruction::Swap { i: 2 }); + self.mark_unwind_no_location(*loc); } + self.set_unwind_source_range(*loc); emit!(self, Instruction::PopTop); + self.mark_unwind_no_location(*loc); } FBlockType::TryExcept => { + self.set_unwind_source_range(*loc); emit!(self, PseudoInstruction::PopBlock); + self.mark_unwind_no_location(*loc); } FBlockType::FinallyTry => { @@ -2804,71 +2905,113 @@ impl Compiler { FBlockType::FinallyEnd => { // codegen_unwind_fblock(FINALLY_END) if preserve_tos { + self.set_unwind_source_range(*loc); emit!(self, Instruction::Swap { i: 2 }); + self.mark_unwind_no_location(*loc); } + self.set_unwind_source_range(*loc); emit!(self, Instruction::PopTop); // exc_value + self.mark_unwind_no_location(*loc); if preserve_tos { + self.set_unwind_source_range(*loc); emit!(self, Instruction::Swap { i: 2 }); + self.mark_unwind_no_location(*loc); } + self.set_unwind_source_range(*loc); emit!(self, PseudoInstruction::PopBlock); + self.mark_unwind_no_location(*loc); + self.set_unwind_source_range(*loc); emit!(self, Instruction::PopExcept); + self.mark_unwind_no_location(*loc); } FBlockType::With | FBlockType::AsyncWith => { // Stack: [..., exit_func, self_exit, return_value (if preserve_tos)] - self.set_source_range(info.fb_range); + // CPython codegen_unwind_fblock() assigns *ploc = info->fb_loc + // for WITH/ASYNC_WITH cleanup and then makes following unwind + // instructions artificial with *ploc = NO_LOCATION. 
+ *loc = Some(info.fb_range); + self.set_unwind_source_range(*loc); emit!(self, PseudoInstruction::PopBlock); if preserve_tos { // Rotate return value below the exit pair // [exit_func, self_exit, value] → [value, exit_func, self_exit] + self.set_unwind_source_range(*loc); emit!(self, Instruction::Swap { i: 3 }); // [value, self_exit, exit_func] + self.set_unwind_source_range(*loc); emit!(self, Instruction::Swap { i: 2 }); // [value, exit_func, self_exit] } // Call exit_func(self_exit, None, None, None) + self.set_unwind_source_range(*loc); self.emit_load_const(ConstantData::None); + self.set_unwind_source_range(*loc); self.emit_load_const(ConstantData::None); + self.set_unwind_source_range(*loc); self.emit_load_const(ConstantData::None); + self.set_unwind_source_range(*loc); emit!(self, Instruction::Call { argc: 3 }); // For async with, await the result if matches!(info.fb_type, FBlockType::AsyncWith) { + self.set_unwind_source_range(*loc); emit!(self, Instruction::GetAwaitable { r#where: 2 }); + self.set_unwind_source_range(*loc); self.emit_load_const(ConstantData::None); let _ = self.compile_yield_from_sequence(true)?; } // Pop the __exit__ result + self.set_unwind_source_range(*loc); emit!(self, Instruction::PopTop); + *loc = None; } FBlockType::HandlerCleanup => { // codegen_unwind_fblock(HANDLER_CLEANUP) if let FBlockDatum::ExceptionName(_) = info.fb_datum { // Named handler: PopBlock for inner SETUP_CLEANUP + self.set_unwind_source_range(*loc); emit!(self, PseudoInstruction::PopBlock); + self.mark_unwind_no_location(*loc); } if preserve_tos { + self.set_unwind_source_range(*loc); emit!(self, Instruction::Swap { i: 2 }); + self.mark_unwind_no_location(*loc); } // PopBlock for outer SETUP_CLEANUP (ExceptionHandler) + self.set_unwind_source_range(*loc); emit!(self, PseudoInstruction::PopBlock); + self.mark_unwind_no_location(*loc); + self.set_unwind_source_range(*loc); emit!(self, Instruction::PopExcept); + self.mark_unwind_no_location(*loc); // If there's an 
exception name, clean it up if let FBlockDatum::ExceptionName(ref name) = info.fb_datum { + self.set_unwind_source_range(*loc); self.emit_load_const(ConstantData::None); + self.mark_unwind_no_location(*loc); + self.set_unwind_source_range(*loc); self.store_name(name)?; + self.mark_unwind_no_location(*loc); + self.set_unwind_source_range(*loc); self.compile_name(name, NameUsage::Delete)?; + self.mark_unwind_no_location(*loc); } } FBlockType::PopValue => { if preserve_tos { + self.set_unwind_source_range(*loc); emit!(self, Instruction::Swap { i: 2 }); + self.mark_unwind_no_location(*loc); } + self.set_unwind_source_range(*loc); emit!(self, Instruction::PopTop); + self.mark_unwind_no_location(*loc); } } Ok(()) @@ -2881,7 +3024,7 @@ impl Compiler { &mut self, preserve_tos: bool, stop_at_loop: bool, - ) -> CompileResult { + ) -> CompileResult> { // Collect the info we need, with indices for FinallyTry blocks #[derive(Clone)] enum UnwindInfo { @@ -2925,15 +3068,17 @@ impl Compiler { } // Process each fblock - let mut unwound_finally = false; + let mut unwind_loc = Some(self.current_source_range); for info in unwind_infos { match info { UnwindInfo::Normal(fblock_info) => { - self.unwind_fblock(&fblock_info, preserve_tos)?; + self.unwind_fblock(&fblock_info, preserve_tos, &mut unwind_loc)?; } UnwindInfo::FinallyTry { body, fblock_idx } => { // codegen_unwind_fblock(FINALLY_TRY) + self.set_unwind_source_range(unwind_loc); emit!(self, PseudoInstruction::PopBlock); + self.mark_unwind_no_location(unwind_loc); // Temporarily remove the FinallyTry fblock so nested return/break/continue // in the finally body won't see it again @@ -2950,7 +3095,7 @@ impl Compiler { } self.compile_statements(&body)?; - unwound_finally = true; + unwind_loc = None; if preserve_tos { self.pop_fblock(FBlockType::PopValue); @@ -2963,7 +3108,7 @@ impl Compiler { } } - Ok(unwound_finally) + Ok(unwind_loc) } // could take impl Into>, but everything is borrowed from ast structs; we never @@ -3120,6 
+3265,11 @@ impl Compiler { let size_before = self.code_stack.len(); // Set future_annotations from symbol table (detected during symbol table scan) self.future_annotations = symbol_table.future_annotations; + if self.future_annotations { + self.current_code_info() + .flags + .insert(bytecode::CodeFlags::FUTURE_ANNOTATIONS); + } // Module-level __conditional_annotations__ cell let has_module_cond_ann = Self::scope_needs_conditional_annotations_cell(&symbol_table); @@ -3141,10 +3291,12 @@ impl Compiler { self.emit_resume_for_scope(CompilerScope::Module, 1); emit!(self, PseudoInstruction::AnnotationsPlaceholder); - let (doc, statements) = split_doc(&body.body, &self.opts); + let (doc, statements) = split_doc_with_range(&body.body, &self.opts); + let module_start_loc = self.module_start_location(&body.body); // Handle annotation bookkeeping before the docstring assignment, as // codegen_body() does after _PyCodegen_Module() inserts the prefix set. if Self::find_ann(statements) { + self.set_source_range(module_start_loc); if Self::scope_needs_conditional_annotations_cell(self.current_symbol_table()) { emit!(self, Instruction::BuildSet { count: 0 }); self.store_name("__conditional_annotations__")?; @@ -3155,19 +3307,22 @@ impl Compiler { } } - if let Some(value) = doc { + if let Some((value, range)) = doc { + let saved_range = self.current_source_range; + self.set_source_range(range); self.emit_load_const(ConstantData::Str { value: value.into(), }); let doc = self.name("__doc__"); - emit!(self, Instruction::StoreName { namei: doc }) + emit!(self, Instruction::StoreName { namei: doc }); + self.set_source_range(saved_range); } // Compile all statements self.compile_statements(statements)?; if Self::find_ann(statements) && !self.future_annotations { - self.compile_module_annotation_setup_sequence(statements)?; + self.compile_module_annotation_setup_sequence(statements, module_start_loc)?; } assert_eq!(self.code_stack.len(), size_before); @@ -3187,13 +3342,20 @@ impl 
Compiler { self.interactive = true; // Set future_annotations from symbol table (detected during symbol table scan) self.future_annotations = symbol_table.future_annotations; + if self.future_annotations { + self.current_code_info() + .flags + .insert(bytecode::CodeFlags::FUTURE_ANNOTATIONS); + } self.symbol_table_stack.push(symbol_table); + let module_start_loc = self.module_start_location(body); self.emit_resume_for_scope(CompilerScope::Module, 1); emit!(self, PseudoInstruction::AnnotationsPlaceholder); // Handle annotations based on future_annotations flag if Self::find_ann(body) { + self.set_source_range(module_start_loc); if self.future_annotations { // PEP 563: Initialize __annotations__ dict emit!(self, Instruction::SetupAnnotations); @@ -3246,7 +3408,7 @@ impl Compiler { }; if Self::find_ann(body) && !self.future_annotations { - self.compile_module_annotation_setup_sequence(body)?; + self.compile_module_annotation_setup_sequence(body, module_start_loc)?; } self.emit_return_value(); @@ -3901,8 +4063,9 @@ impl Compiler { body, orelse, is_async, + range, .. - }) => self.compile_for(target, iter, body, orelse, *is_async)?, + }) => self.compile_for(target, iter, body, orelse, *is_async, *range)?, ast::Stmt::Match(ast::StmtMatch { subject, cases, .. }) => { self.compile_match(subject, cases)? } @@ -3981,11 +4144,14 @@ impl Compiler { type_params.as_deref(), arguments.as_deref(), )?, - ast::Stmt::Assert(ast::StmtAssert { test, msg, .. }) => { + ast::Stmt::Assert(ast::StmtAssert { + test, msg, range, .. + }) => { // if some flag, ignore all assert statements! 
if self.opts.optimize == 0 { let after_block = self.new_block(); self.compile_jump_if(test, true, after_block)?; + self.set_source_range(*range); emit!( self, Instruction::LoadCommonConstant { @@ -3994,8 +4160,10 @@ impl Compiler { ); if let Some(e) = msg { self.compile_expression(e)?; + self.set_source_range(*range); emit!(self, Instruction::Call { argc: 0 }); } + self.set_source_range(test.range()); emit!( self, Instruction::RaiseVarargs { @@ -4044,10 +4212,7 @@ impl Compiler { match value { Some(v) => { if self.ctx.func == FunctionContext::AsyncFunction - && self - .current_code_info() - .flags - .contains(bytecode::CodeFlags::GENERATOR) + && self.current_symbol_table().is_generator { return Err(self.error_ranged( CodegenErrorType::AsyncReturnValue, @@ -4060,9 +4225,11 @@ impl Compiler { None }; let preserve_tos = folded_constant.is_none(); + let mut return_range = stmt_range; if preserve_tos { self.compile_expression(v)?; } else { + return_range = v.range(); self.set_source_range(v.range()); emit!(self, Instruction::Nop); } @@ -4071,16 +4238,17 @@ impl Compiler { if source.line_index(v.range().start()) != source.line_index(stmt_range.start()) { + return_range = stmt_range; self.set_source_range(stmt_range); emit!(self, Instruction::Nop); } - self.set_source_range(stmt_range); - let unwound_finally = self.unwind_fblock_stack(preserve_tos, false)?; - if !unwound_finally { - self.set_source_range(stmt_range); + self.set_source_range(return_range); + let unwind_loc = self.unwind_fblock_stack(preserve_tos, false)?; + if let Some(loc) = unwind_loc { + self.set_source_range(loc); } match folded_constant { - Some(constant) if unwound_finally => { + Some(constant) if unwind_loc.is_none() => { self.emit_return_const_no_location(constant); } Some(constant) => { @@ -4089,7 +4257,7 @@ impl Compiler { } None => { self.emit_return_value(); - if unwound_finally { + if unwind_loc.is_none() { self.set_no_location(); } } @@ -4099,12 +4267,12 @@ impl Compiler { 
self.set_source_range(stmt_range); emit!(self, Instruction::Nop); // Unwind fblock stack with preserve_tos=false (no value to preserve) - let unwound_finally = self.unwind_fblock_stack(false, false)?; - if unwound_finally { - self.emit_return_const_no_location(ConstantData::None); - } else { - self.set_source_range(stmt_range); + let unwind_loc = self.unwind_fblock_stack(false, false)?; + if let Some(loc) = unwind_loc { + self.set_source_range(loc); self.emit_return_const(ConstantData::None); + } else { + self.emit_return_const_no_location(ConstantData::None); } } } @@ -4112,7 +4280,12 @@ impl Compiler { let dead = self.new_block(); self.switch_to_block(dead); } - ast::Stmt::Assign(ast::StmtAssign { targets, value, .. }) => { + ast::Stmt::Assign(ast::StmtAssign { + targets, + value, + range, + .. + }) => { let folded_ifexp_assignment = matches!( value.as_ref(), ast::Expr::If(ast::ExprIf { test, .. }) @@ -4144,6 +4317,7 @@ impl Compiler { for (i, target) in targets.iter().enumerate() { if i + 1 != targets.len() { + self.set_source_range(*range); emit!(self, Instruction::Copy { i: 1 }); } self.compile_store(target)?; @@ -4157,9 +4331,16 @@ impl Compiler { annotation, value, simple, + range, .. }) => { - self.compile_annotated_assign(target, annotation, value.as_deref(), *simple)?; + self.compile_annotated_assign( + target, + annotation, + value.as_deref(), + *simple, + *range, + )?; // Bare annotations in function scope emit no code; restore // source range so subsequent instructions keep the correct line. if value.is_none() && self.ctx.in_func() { @@ -4178,6 +4359,7 @@ impl Compiler { name, type_params, value, + range, .. 
}) => { let Some(name) = name.as_name_expr() else { @@ -4207,7 +4389,7 @@ impl Compiler { value: name_string.clone().into(), }); self.compile_type_params(type_params)?; - self.compile_typealias_value_closure(&name_string, value)?; + self.compile_typealias_value_closure(&name_string, value, *range)?; emit!(self, Instruction::BuildTuple { count: 3 }); emit!( self, @@ -4227,7 +4409,7 @@ impl Compiler { value: name_string.clone().into(), }); self.emit_load_const(ConstantData::None); - self.compile_typealias_value_closure(&name_string, value)?; + self.compile_typealias_value_closure(&name_string, value, *range)?; emit!(self, Instruction::BuildTuple { count: 3 }); emit!( self, @@ -4245,32 +4427,43 @@ impl Compiler { } fn compile_delete(&mut self, expression: &ast::Expr) -> CompileResult<()> { - match &expression { - ast::Expr::Name(ast::ExprName { id, .. }) => { - self.compile_name(id.as_str(), NameUsage::Delete)? - } - ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { - self.compile_expression(value)?; - let namei = self.name(attr.as_str()); - emit!(self, Instruction::DeleteAttr { namei }); - } - ast::Expr::Subscript(ast::ExprSubscript { - value, slice, ctx, .. - }) => { - self.compile_subscript(value, slice, *ctx)?; - } - ast::Expr::Tuple(ast::ExprTuple { elts, .. }) - | ast::Expr::List(ast::ExprList { elts, .. }) => { - for element in elts { - self.compile_delete(element)?; + let prev_source_range = self.current_source_range; + self.set_source_range(expression.range()); + let result = (|| -> CompileResult<()> { + match &expression { + ast::Expr::Name(ast::ExprName { id, .. }) => { + self.compile_name(id.as_str(), NameUsage::Delete)? } + ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. 
}) => { + self.compile_expression(value)?; + let namei = self.name(attr.as_str()); + self.set_source_range(self.update_start_location_to_match_attr( + expression.range(), + expression.range(), + attr.as_str(), + )); + emit!(self, Instruction::DeleteAttr { namei }); + } + ast::Expr::Subscript(ast::ExprSubscript { + value, slice, ctx, .. + }) => { + self.compile_subscript(value, slice, *ctx)?; + } + ast::Expr::Tuple(ast::ExprTuple { elts, .. }) + | ast::Expr::List(ast::ExprList { elts, .. }) => { + for element in elts { + self.compile_delete(element)?; + } + } + ast::Expr::BinOp(_) | ast::Expr::UnaryOp(_) => { + return Err(self.error(CodegenErrorType::Delete("expression"))); + } + _ => return Err(self.error(CodegenErrorType::Delete(expression.python_name()))), } - ast::Expr::BinOp(_) | ast::Expr::UnaryOp(_) => { - return Err(self.error(CodegenErrorType::Delete("expression"))); - } - _ => return Err(self.error(CodegenErrorType::Delete(expression.python_name()))), - } - Ok(()) + Ok(()) + })(); + self.set_source_range(prev_source_range); + result } fn enter_function(&mut self, name: &str, parameters: &ast::Parameters) -> CompileResult<()> { @@ -4327,7 +4520,8 @@ impl Compiler { /// Apply decorators: each decorator calls the function below it. 
/// Stack: [dec1, dec2, func] → CALL 0 → [dec1, dec2(func)] → CALL 0 → [dec1(dec2(func))] fn apply_decorators(&mut self, decorator_list: &[ast::Decorator]) { - for _ in decorator_list { + for decorator in decorator_list.iter().rev() { + self.set_source_range(decorator.expression.range()); emit!(self, Instruction::Call { argc: 0 }); } } @@ -4339,6 +4533,8 @@ impl Compiler { name: &str, allow_starred: bool, ) -> CompileResult<()> { + let expr_range = expr.range(); + self.set_source_range(expr_range); self.emit_load_const(ConstantData::Tuple { elements: vec![ConstantData::Integer { value: 1.into() }], }); @@ -4373,6 +4569,7 @@ impl Compiler { if allow_starred && matches!(expr, ast::Expr::Starred(_)) { if let ast::Expr::Starred(starred) = expr { self.compile_expression(&starred.value)?; + self.set_source_range(expr_range); emit!(self, Instruction::UnpackSequence { count: 1 }); } } else { @@ -4380,12 +4577,14 @@ impl Compiler { } // Return value + self.set_source_range(expr_range); emit!(self, Instruction::ReturnValue); // Exit scope and create closure let code = self.exit_scope(); self.ctx = prev_ctx; + self.set_source_range(expr_range); self.make_closure( code, bytecode::MakeFunctionFlags::from([bytecode::MakeFunctionFlag::Defaults]), @@ -4398,7 +4597,9 @@ impl Compiler { &mut self, alias_name: &str, value: &ast::Expr, + alias_range: TextRange, ) -> CompileResult<()> { + self.set_source_range(alias_range); self.emit_load_const(ConstantData::Tuple { elements: vec![ConstantData::Integer { value: 1.into() }], }); @@ -4422,10 +4623,12 @@ impl Compiler { }; self.compile_expression(value)?; + self.set_source_range(alias_range); emit!(self, Instruction::ReturnValue); let code = self.exit_scope(); self.ctx = prev_ctx; + self.set_source_range(alias_range); self.make_closure( code, bytecode::MakeFunctionFlags::from([bytecode::MakeFunctionFlag::Defaults]), @@ -4444,8 +4647,10 @@ impl Compiler { name, bound, default, + range, .. 
}) => { + self.set_source_range(*range); self.emit_load_const(ConstantData::Str { value: name.as_str().into(), }); @@ -4453,6 +4658,7 @@ impl Compiler { if let Some(expr) = &bound { self.compile_type_param_bound_or_default(expr, name.as_str(), false)?; + self.set_source_range(*range); let intrinsic = if expr.is_tuple_expr() { bytecode::IntrinsicFunction2::TypeVarWithConstraint } else { @@ -4474,6 +4680,7 @@ impl Compiler { name.as_str(), false, )?; + self.set_source_range(*range); emit!( self, Instruction::CallIntrinsic2 { @@ -4482,10 +4689,17 @@ impl Compiler { ); } + self.set_source_range(*range); emit!(self, Instruction::Copy { i: 1 }); self.store_name(name.as_ref())?; } - ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, default, .. }) => { + ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { + name, + default, + range, + .. + }) => { + self.set_source_range(*range); self.emit_load_const(ConstantData::Str { value: name.as_str().into(), }); @@ -4502,6 +4716,7 @@ impl Compiler { name.as_str(), false, )?; + self.set_source_range(*range); emit!( self, Instruction::CallIntrinsic2 { @@ -4510,12 +4725,17 @@ impl Compiler { ); } + self.set_source_range(*range); emit!(self, Instruction::Copy { i: 1 }); self.store_name(name.as_ref())?; } ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { - name, default, .. + name, + default, + range, + .. 
}) => { + self.set_source_range(*range); self.emit_load_const(ConstantData::Str { value: name.as_str().into(), }); @@ -4533,6 +4753,7 @@ impl Compiler { name.as_str(), true, )?; + self.set_source_range(*range); emit!( self, Instruction::CallIntrinsic2 { @@ -4541,11 +4762,15 @@ impl Compiler { ); } + self.set_source_range(*range); emit!(self, Instruction::Copy { i: 1 }); self.store_name(name.as_ref())?; } }; } + if let Some(first) = type_params.type_params.first() { + self.set_source_range(first.range()); + } emit!( self, Instruction::BuildTuple { @@ -4736,8 +4961,10 @@ impl Compiler { // SETUP_CLEANUP before PUSH_EXC_INFO if let Some(cleanup) = finally_cleanup_block { emit!(self, PseudoInstruction::SetupCleanup { delta: cleanup }); + self.set_no_location(); } emit!(self, Instruction::PushExcInfo); + self.set_no_location(); if let Some(cleanup) = finally_cleanup_block { self.push_fblock(FBlockType::FinallyEnd, cleanup, cleanup)?; } @@ -4761,8 +4988,11 @@ impl Compiler { if let Some(cleanup) = finally_cleanup_block { self.switch_to_block(cleanup); emit!(self, Instruction::Copy { i: 3 }); + self.set_no_location(); emit!(self, Instruction::PopExcept); + self.set_no_location(); emit!(self, Instruction::Reraise { depth: 1 }); + self.set_no_location(); } if preserve_finally_exit_empty_label @@ -4867,6 +5097,7 @@ impl Compiler { if let Some(exc_type) = type_ { self.compile_expression(exc_type)?; + self.set_source_range(*handler_range); emit!(self, Instruction::CheckExcMatch); emit!( self, @@ -4920,12 +5151,17 @@ impl Compiler { ); self.switch_to_block(cleanup_end); + self.set_no_location(); if let Some(alias) = name { self.emit_load_const(ConstantData::None); + self.set_no_location(); self.store_name(alias.as_str())?; + self.set_no_location(); self.compile_name(alias.as_str(), NameUsage::Delete)?; + self.set_no_location(); } emit!(self, Instruction::Reraise { depth: 1 }); + self.set_no_location(); self.switch_to_block(handler_normal_exit); } @@ -5074,10 +5310,13 @@ impl 
Compiler { self.switch_to_block(cleanup); // COPY 3: copy the exception from position 3 emit!(self, Instruction::Copy { i: 3 }); + self.set_no_location(); // POP_EXCEPT: restore prev_exc as current exception emit!(self, Instruction::PopExcept); + self.set_no_location(); // RERAISE 1: reraise with lasti from stack emit!(self, Instruction::Reraise { depth: 1 }); + self.set_no_location(); } // End block - continuation point after try-finally @@ -5307,6 +5546,7 @@ impl Compiler { if let Some(exc_type) = type_ { self.compile_expression(exc_type)?; + self.set_source_range(*handler_range); emit!(self, Instruction::CheckExcMatch); emit!( self, @@ -5708,7 +5948,9 @@ impl Compiler { delta: finally_cleanup_block } ); + self.set_no_location(); emit!(self, Instruction::PushExcInfo); + self.set_no_location(); self.push_fblock( FBlockType::FinallyEnd, finally_cleanup_block, @@ -5721,8 +5963,11 @@ impl Compiler { self.switch_to_block(finally_cleanup_block); emit!(self, Instruction::Copy { i: 3 }); + self.set_no_location(); emit!(self, Instruction::PopExcept); + self.set_no_location(); emit!(self, Instruction::Reraise { depth: 1 }); + self.set_no_location(); self.switch_to_block(exit_block); if preserve_finally_exit_empty_label { @@ -6120,6 +6365,7 @@ impl Compiler { self, PseudoInstruction::JumpNoInterrupt { delta: exit_block } ); + self.set_no_location(); // Restore sub_tables for exception path compilation if let Some(cursor) = sub_table_cursor @@ -6131,9 +6377,11 @@ impl Compiler { // Exception handler path self.switch_to_block(finally_block); emit!(self, Instruction::PushExcInfo); + self.set_no_location(); if let Some(cleanup) = finally_cleanup_block { emit!(self, PseudoInstruction::SetupCleanup { delta: cleanup }); + self.set_no_location(); self.push_fblock(FBlockType::FinallyEnd, cleanup, cleanup)?; } @@ -6141,6 +6389,7 @@ impl Compiler { if finally_cleanup_block.is_some() { emit!(self, PseudoInstruction::PopBlock); + self.set_no_location(); 
self.pop_fblock(FBlockType::FinallyEnd); } @@ -6150,8 +6399,11 @@ impl Compiler { if let Some(cleanup) = finally_cleanup_block { self.switch_to_block(cleanup); emit!(self, Instruction::Copy { i: 3 }); + self.set_no_location(); emit!(self, Instruction::PopExcept); + self.set_no_location(); emit!(self, Instruction::Reraise { depth: 1 }); + self.set_no_location(); } } @@ -6173,6 +6425,7 @@ impl Compiler { fn compile_default_arguments( &mut self, parameters: &ast::Parameters, + loc: TextRange, ) -> CompileResult { let mut funcflags = bytecode::MakeFunctionFlags::new(); @@ -6188,6 +6441,7 @@ impl Compiler { for default in &defaults { self.compile_expression(default)?; } + self.set_source_range(loc); emit!( self, Instruction::BuildTuple { @@ -6208,11 +6462,13 @@ impl Compiler { if !kw_with_defaults.is_empty() { // Compile kwdefaults and build dict for (arg, default) in &kw_with_defaults { + self.set_source_range(loc); self.emit_load_const(ConstantData::Str { value: self.mangle(arg.name.as_str()).into_owned().into(), }); self.compile_expression(default)?; } + self.set_source_range(loc); emit!( self, Instruction::BuildMap { @@ -6234,10 +6490,8 @@ impl Compiler { body: &[ast::Stmt], is_async: bool, funcflags: bytecode::MakeFunctionFlags, + closure_range: TextRange, ) -> CompileResult<()> { - // Save source range so MAKE_FUNCTION gets the `def` line, not the body's last line - let saved_range = self.current_source_range; - // Always enter function scope self.enter_function(name, parameters)?; self.current_code_info() @@ -6272,6 +6526,11 @@ impl Compiler { } ); self.set_no_location(); + // CPython's codegen_wrap_in_stopiteration_handler() inserts + // SETUP_CLEANUP at instruction-sequence index 0, so after the + // generator prefix is inserted the protected range begins at the + // function-start RESUME. 
+ self.move_last_instruction_before_scope_start_resume(); self.push_fblock(FBlockType::StopIteration, handler_block, handler_block)?; Some(handler_block) } else { @@ -6279,14 +6538,15 @@ impl Compiler { }; // Handle docstring - store in co_consts[0] if present - let (doc_str, body) = split_doc(body, &self.opts); + let (doc_info, body) = split_doc_with_range(body, &self.opts); + let doc_str = doc_info.as_ref().map(|(doc, _)| doc); if let Some(doc) = &doc_str { // Docstring present: store in co_consts[0] and set HAS_DOCSTRING flag self.current_code_info() .metadata .consts .insert_full(ConstantData::Str { - value: doc.to_string().into(), + value: (*doc).to_string().into(), }); self.current_code_info().flags |= bytecode::CodeFlags::HAS_DOCSTRING; } @@ -6329,7 +6589,7 @@ impl Compiler { let code = self.exit_scope(); self.ctx = prev_ctx; - self.set_source_range(saved_range); + self.set_source_range(closure_range); // Create function object with closure self.make_closure(code, funcflags)?; @@ -6348,6 +6608,7 @@ impl Compiler { func_name: &str, parameters: &ast::Parameters, returns: Option<&ast::Expr>, + func_range: TextRange, ) -> CompileResult { let has_signature_annotations = parameters .args @@ -6364,7 +6625,7 @@ impl Compiler { } // Try to enter annotation scope - returns None if no annotation_block exists - let Some(saved_ctx) = self.enter_annotation_scope(func_name)? else { + let Some(saved_ctx) = self.enter_annotation_scope(func_name, func_range)? 
else { return Ok(false); }; @@ -6395,6 +6656,7 @@ impl Compiler { for param in parameters_iter { if let Some(annotation) = ¶m.annotation { + self.set_source_range(func_range); self.emit_load_const(ConstantData::Str { value: self.mangle(param.name.as_str()).into_owned().into(), }); @@ -6404,6 +6666,7 @@ impl Compiler { // Handle return annotation if let Some(annotation) = returns { + self.set_source_range(func_range); self.emit_load_const(ConstantData::Str { value: "return".into(), }); @@ -6411,6 +6674,7 @@ impl Compiler { } // Build the map and return it + self.set_source_range(func_range); emit!( self, Instruction::BuildMap { @@ -6423,6 +6687,7 @@ impl Compiler { let annotate_code = self.exit_annotation_scope(saved_ctx); // Make a closure from the code object + self.set_source_range(func_range); self.make_closure(annotate_code, bytecode::MakeFunctionFlags::new())?; Ok(true) @@ -6484,9 +6749,55 @@ impl Compiler { annotations } + fn compile_annotation_for_symbol_cursor_only( + &mut self, + annotation: &ast::Expr, + ) -> CompileResult<()> { + let code_stack_len = self.code_stack.len(); + let code_info = self.current_code_info(); + let saved_blocks = code_info.blocks.clone(); + let saved_current_block = code_info.current_block; + let saved_annotations_blocks = code_info.annotations_blocks.clone(); + let saved_metadata = code_info.metadata.clone(); + let saved_static_attributes = code_info.static_attributes.clone(); + let saved_in_inlined_comp = code_info.in_inlined_comp; + let saved_fblock = code_info.fblock.clone(); + let saved_in_conditional_block = code_info.in_conditional_block; + let saved_in_final_with_cleanup_statement = code_info.in_final_with_cleanup_statement; + let saved_in_try_else_orelse = code_info.in_try_else_orelse; + let saved_next_conditional_annotation_index = code_info.next_conditional_annotation_index; + let saved_source_range = self.current_source_range; + + self.do_not_emit_bytecode += 1; + let result = self.compile_annotation(annotation); + 
self.do_not_emit_bytecode -= 1; + + debug_assert_eq!(self.code_stack.len(), code_stack_len); + let code_info = self.current_code_info(); + code_info.blocks = saved_blocks; + code_info.current_block = saved_current_block; + code_info.annotations_blocks = saved_annotations_blocks; + code_info.metadata = saved_metadata; + code_info.static_attributes = saved_static_attributes; + code_info.in_inlined_comp = saved_in_inlined_comp; + code_info.fblock = saved_fblock; + code_info.in_conditional_block = saved_in_conditional_block; + code_info.in_final_with_cleanup_statement = saved_in_final_with_cleanup_statement; + code_info.in_try_else_orelse = saved_in_try_else_orelse; + code_info.next_conditional_annotation_index = saved_next_conditional_annotation_index; + self.current_source_range = saved_source_range; + + result + } + /// Compile module-level __annotate__ function (PEP 649) /// Returns true if __annotate__ was created and stored - fn compile_module_annotate(&mut self, body: &[ast::Stmt]) -> CompileResult { + fn compile_module_annotate( + &mut self, + body: &[ast::Stmt], + loc: Option, + ) -> CompileResult { + let loc = loc.unwrap_or(self.current_source_range); let annotations = Self::collect_annotations(body); let simple_annotation_count = annotations .iter() @@ -6517,6 +6828,7 @@ impl Compiler { }; // Enter annotation scope for code generation + self.set_source_range(loc); let key = self.symbol_table_stack.len() - 1; let lineno = self.get_source_line_number().get(); self.enter_scope( @@ -6535,6 +6847,7 @@ impl Compiler { // Emit format validation: if format > VALUE_WITH_FAKE_GLOBALS: raise NotImplementedError self.emit_format_validation(); + self.set_source_range(loc); emit!(self, Instruction::BuildMap { count: 0 }); let mut simple_idx = 0usize; @@ -6543,6 +6856,7 @@ impl Compiler { target, annotation, simple, + range, .. 
} = stmt; let simple_name = if *simple { @@ -6556,10 +6870,7 @@ impl Compiler { if simple_name.is_none() { if !self.future_annotations { - self.do_not_emit_bytecode += 1; - let result = self.compile_annotation(annotation); - self.do_not_emit_bytecode -= 1; - result?; + self.compile_annotation_for_symbol_cursor_only(annotation)?; } continue; } @@ -6568,6 +6879,7 @@ impl Compiler { let name = simple_name.expect("missing simple annotation name"); if has_conditional { + self.set_source_range(*range); self.emit_load_const(ConstantData::Integer { value: simple_idx.into(), }); @@ -6593,10 +6905,12 @@ impl Compiler { } self.compile_annotation(annotation)?; + self.set_source_range(*range); emit!(self, Instruction::Copy { i: 2 }); self.emit_load_const(ConstantData::Str { value: self.mangle(name).into_owned().into(), }); + self.set_source_range(loc); emit!(self, Instruction::StoreSubscr); simple_idx += 1; @@ -6605,6 +6919,7 @@ impl Compiler { } } + self.set_source_range(loc); emit!(self, Instruction::ReturnValue); // Exit annotation scope - pop symbol table, restore to parent's annotation_block, and get code @@ -6624,6 +6939,7 @@ impl Compiler { ); // Make a closure from the code object + self.set_source_range(loc); self.make_closure(annotate_code, bytecode::MakeFunctionFlags::new())?; // Store as __annotate_func__ for classes, __annotate__ for modules @@ -6632,6 +6948,7 @@ impl Compiler { } else { "__annotate__" }; + self.set_source_range(loc); self.store_name(name)?; Ok(true) @@ -6649,14 +6966,19 @@ impl Compiler { is_async: bool, type_params: Option<&ast::TypeParams>, ) -> CompileResult<()> { - // Save the source range of the `def` line before compiling decorators/defaults, - // so that the function code object gets the correct co_firstlineno. - let def_source_range = self.current_source_range; + // CPython's FunctionDef/AsyncFunctionDef LOC(s) starts at the + // definition line even when decorators are present. 
+ let stmt_source_range = self.current_source_range; + let def_source_range = self.decorated_definition_range( + stmt_source_range, + decorator_list, + if is_async { "async def " } else { "def " }, + ); self.prepare_decorators(decorator_list)?; // compile defaults and return funcflags - let funcflags = self.compile_default_arguments(parameters)?; + let funcflags = self.compile_default_arguments(parameters, def_source_range)?; // Restore the `def` line range so that enter_function → push_output → get_source_line_number() // records the `def` keyword's line as co_firstlineno, not the last default-argument line. @@ -6726,13 +7048,21 @@ impl Compiler { // Compile annotations as closure (PEP 649) let mut annotations_flag = bytecode::MakeFunctionFlags::new(); - if self.compile_annotations_closure(name, parameters, returns)? { + if self.compile_annotations_closure(name, parameters, returns, def_source_range)? { annotations_flag.insert(bytecode::MakeFunctionFlag::Annotate); } // Compile function body + self.set_source_range(stmt_source_range); let final_funcflags = funcflags | annotations_flag; - self.compile_function_body(name, parameters, body, is_async, final_funcflags)?; + self.compile_function_body( + name, + parameters, + body, + is_async, + final_funcflags, + def_source_range, + )?; // Handle type params if present if is_generic { @@ -6786,6 +7116,7 @@ impl Compiler { self.apply_decorators(decorator_list); // Store the function + self.set_source_range(def_source_range); self.store_name(name)?; Ok(()) @@ -7058,7 +7389,9 @@ impl Compiler { self.code_stack.last_mut().unwrap().private = Some(name.to_owned()); // 2. 
Set up class namespace - let (doc_str, body) = split_doc(body, &self.opts); + let (doc_str, body) = split_doc_with_range(body, &self.opts); + let class_body_prefix_range = self.source_line_start_range(firstlineno); + self.set_source_range(class_body_prefix_range); // Load __name__ and store as __module__ self.load_name("__name__")?; @@ -7104,16 +7437,19 @@ impl Compiler { } // Store __doc__ only if there's an explicit docstring. - if let Some(doc) = doc_str { + if let Some((doc, range)) = doc_str { + let saved_range = self.current_source_range; + self.set_source_range(range); self.emit_load_const(ConstantData::Str { value: doc.into() }); self.store_name("__doc__")?; + self.set_source_range(saved_range); } // 3. Compile the class body self.compile_statements(body)?; if Self::find_ann(body) && !self.future_annotations { - self.compile_module_annotate(body)?; + self.compile_module_annotate(body, Some(class_body_prefix_range))?; } // 4. Handle __classcell__ if needed @@ -7190,6 +7526,11 @@ impl Compiler { type_params: Option<&ast::TypeParams>, arguments: Option<&ast::Arguments>, ) -> CompileResult<()> { + // CPython's ClassDef LOC(s) starts at the class line even when + // decorators are present. + let stmt_source_range = self.current_source_range; + let class_source_range = + self.decorated_definition_range(stmt_source_range, decorator_list, "class "); self.prepare_decorators(decorator_list)?; let is_generic = type_params.is_some(); @@ -7233,6 +7574,7 @@ impl Compiler { // Compile type parameters and store them in the synthetic cell that // generic class bodies close over. 
self.compile_type_params(type_params.unwrap())?; + self.set_source_range(class_source_range); self.store_name(".type_params")?; } @@ -7246,6 +7588,7 @@ impl Compiler { }; let class_code = self.compile_class_body(name, body, type_params, firstlineno)?; self.ctx = prev_ctx; + self.set_source_range(class_source_range); // Step 3: Generate the rest of the code for the call if is_generic { @@ -7260,6 +7603,7 @@ impl Compiler { // Create .generic_base after the class function and name are on the // stack so the remaining call shape matches CPython's ordering. + self.set_source_range(class_source_range); self.load_name(".type_params")?; emit!( self, @@ -7267,6 +7611,7 @@ impl Compiler { func: bytecode::IntrinsicFunction1::SubscriptGeneric } ); + self.set_source_range(class_source_range); self.store_name(".generic_base")?; // Compile bases and call __build_class__ @@ -7303,10 +7648,13 @@ impl Compiler { } // Add .generic_base as final element + self.set_source_range(class_source_range); self.load_name(".generic_base")?; + self.set_source_range(class_source_range); emit!(self, Instruction::ListAppend { i: 1 }); // Convert list to tuple + self.set_source_range(class_source_range); emit!( self, Instruction::CallIntrinsic1 { @@ -7316,6 +7664,7 @@ impl Compiler { self.compile_call_function_ex_keywords( arguments.map_or(&[][..], |args| &args.keywords[..]), + class_source_range, )?; emit!(self, Instruction::CallFunctionEx); } else if has_double_star { @@ -7324,7 +7673,9 @@ impl Compiler { self.compile_expression(arg)?; } } + self.set_source_range(class_source_range); self.load_name(".generic_base")?; + self.set_source_range(class_source_range); emit!( self, Instruction::BuildTuple { @@ -7332,7 +7683,10 @@ impl Compiler { .map_or(0, |args| u32::try_from(args.args.len()).unwrap()) } ); - self.compile_call_function_ex_keywords(&arguments.unwrap().keywords[..])?; + self.compile_call_function_ex_keywords( + &arguments.unwrap().keywords[..], + class_source_range, + )?; emit!(self, 
Instruction::CallFunctionEx); } else { // Simple case: no starred bases, no **kwargs @@ -7347,6 +7701,7 @@ impl Compiler { }; // Load .generic_base as the last base + self.set_source_range(class_source_range); self.load_name(".generic_base")?; let nargs = 2 + u32::try_from(base_count).expect("too many base classes") + 1; @@ -7365,9 +7720,11 @@ impl Compiler { }); self.compile_expression(&keyword.value)?; } + self.set_source_range(class_source_range); self.emit_load_const(ConstantData::Tuple { elements: kwarg_names, }); + self.set_source_range(class_source_range); emit!( self, Instruction::CallKw { @@ -7377,11 +7734,13 @@ impl Compiler { } ); } else { + self.set_source_range(class_source_range); emit!(self, Instruction::Call { argc: nargs }); } } // Return the created class + self.set_source_range(class_source_range); self.emit_return_value(); // Exit type params scope and wrap in function @@ -7389,8 +7748,11 @@ impl Compiler { self.ctx = saved_ctx; // Execute the type params function + self.set_source_range(class_source_range); self.make_closure(type_params_code, bytecode::MakeFunctionFlags::new())?; + self.set_source_range(class_source_range); emit!(self, Instruction::PushNull); + self.set_source_range(class_source_range); emit!(self, Instruction::Call { argc: 0 }); } else { // Non-generic class: standard path @@ -7402,14 +7764,16 @@ impl Compiler { self.emit_load_const(ConstantData::Str { value: name.into() }); if let Some(arguments) = arguments { - self.codegen_call_helper(2, arguments, self.current_source_range)?; + self.codegen_call_helper(2, arguments, class_source_range, None)?; } else { + self.set_source_range(class_source_range); emit!(self, Instruction::Call { argc: 2 }); } } // Step 4: Apply decorators and store (common to both paths) self.apply_decorators(decorator_list); + self.set_source_range(class_source_range); self.store_name(name) } @@ -7879,9 +8243,14 @@ impl Compiler { // to be in the exception table for these instructions. 
// If we cleared fblock, exceptions here would propagate uncaught. self.switch_to_block(cleanup_block); + // CPython codegen_with_except_finish() emits POP_EXCEPT_AND_RERAISE + // with NO_LOCATION at this cleanup label. emit!(self, Instruction::Copy { i: 3 }); + self.set_no_location(); emit!(self, Instruction::PopExcept); + self.set_no_location(); emit!(self, Instruction::Reraise { depth: 1 }); + self.set_no_location(); // ===== After block ===== self.switch_to_block(after_block); @@ -7903,6 +8272,7 @@ impl Compiler { body: &[ast::Stmt], orelse: &[ast::Stmt], is_async: bool, + for_range: TextRange, ) -> CompileResult<()> { self.enter_conditional_block(); @@ -7935,9 +8305,11 @@ impl Compiler { if self.ctx.func != FunctionContext::AsyncFunction { return Err(self.error(CodegenErrorType::InvalidAsyncFor)); } + self.set_source_range(iter.range()); emit!(self, Instruction::GetAiter); self.switch_to_block(for_block); + self.set_source_range(for_range); // codegen_async_for: push fblock BEFORE SETUP_FINALLY self.push_fblock(FBlockType::ForLoop, for_block, after_block)?; @@ -8066,25 +8438,35 @@ impl Compiler { && elts.len() <= usize::try_from(STACK_USE_GUIDELINE).unwrap() && !elts.iter().any(|e| matches!(e, ast::Expr::Starred(_))) { - if let Some(folded) = self.try_fold_constant_collection(elts, CollectionType::List)? 
{ - self.emit_load_const(folded); - } else { - for elt in elts { - self.compile_expression(elt)?; - } - emit!( - self, - Instruction::BuildTuple { - count: u32::try_from(elts.len()).expect("too many elements"), - } - ); + for elt in elts { + self.compile_expression(elt)?; } + self.set_source_range(iter.range()); + emit!( + self, + Instruction::BuildList { + count: u32::try_from(elts.len()).expect("too many elements"), + } + ); return Ok(()); } self.compile_expression(iter) } + fn compile_comprehension_iter(&mut self, generator: &ast::Comprehension) -> CompileResult<()> { + let saved_range = self.current_source_range; + self.compile_for_iterable_expression(&generator.iter, generator.is_async)?; + self.set_source_range(generator.iter.range()); + if generator.is_async { + emit!(self, Instruction::GetAiter); + } else { + emit!(self, Instruction::GetIter); + } + self.set_source_range(saved_range); + Ok(()) + } + fn singleton_comprehension_assignment_iter(iter: &ast::Expr) -> Option<&ast::Expr> { let elts = match iter { ast::Expr::List(ast::ExprList { elts, .. }) => elts, @@ -8476,6 +8858,7 @@ impl Compiler { // Compile the class expression. self.compile_expression(&match_class.cls)?; + self.set_source_range(p.range); // Create a new tuple of attribute names. let mut attr_names = vec![]; @@ -8652,8 +9035,9 @@ impl Compiler { seen.insert(key_repr); } - self.compile_expression(key)?; + self.compile_match_pattern_expr(key)?; } + self.set_source_range(p.range); } // Stack: [subject, key1, key2, ..., key_n] @@ -8781,6 +9165,7 @@ impl Compiler { pc.fail_pop.clear(); pc.on_top = 0; // Emit a COPY(1) instruction before compiling the alternative. + self.set_source_range(alt.range()); emit!(self, Instruction::Copy { i: 1 }); self.compile_pattern(alt, pc)?; @@ -8954,7 +9339,7 @@ impl Compiler { // Match CPython codegen_pattern_value(): compare, then normalize to bool // before the fail jump. Late IR folding will collapse COMPARE_OP+TO_BOOL // into COMPARE_OP bool(...) 
when applicable. - self.compile_expression(&p.value)?; + self.compile_match_pattern_expr(&p.value)?; emit!( self, Instruction::CompareOp { @@ -9085,6 +9470,7 @@ impl Compiler { for (i, m) in cases.iter().enumerate().take(case_count) { // Only copy the subject if not on the last case if i != case_count - 1 { + self.set_source_range(m.pattern.range()); emit!(self, Instruction::Copy { i: 1 }); } @@ -9139,6 +9525,7 @@ impl Compiler { } else { emit!(self, PseudoInstruction::Jump { delta: end }); } + self.set_no_location(); if let Some(last) = self.current_block().instructions.last_mut() { last.match_success_jump = true; } @@ -9381,6 +9768,7 @@ impl Compiler { fn compile_annotation(&mut self, annotation: &ast::Expr) -> CompileResult<()> { if self.future_annotations { + self.set_source_range(annotation.range()); self.emit_load_const(ConstantData::Str { value: UnparseExpr::new(annotation, &self.source_file) .to_string() @@ -9445,6 +9833,7 @@ impl Compiler { annotation: &ast::Expr, value: Option<&ast::Expr>, simple: bool, + loc: TextRange, ) -> CompileResult<()> { // Perform the actual assignment first if let Some(value) = value { @@ -9461,6 +9850,7 @@ impl Compiler { // PEP 563: Store stringified annotation directly to __annotations__ // Compile annotation as string self.compile_annotation(annotation)?; + self.set_source_range(loc); // Load __annotations__ let annotations_name = self.name("__annotations__"); emit!( @@ -9470,10 +9860,12 @@ impl Compiler { } ); // Load the variable name + self.set_source_range(loc); self.emit_load_const(ConstantData::Str { value: self.mangle(id.as_str()).into_owned().into(), }); // Store: __annotations__[name] = annotation + self.set_source_range(loc); emit!(self, Instruction::StoreSubscr); } else { // PEP 649: Handle conditional annotations @@ -9535,6 +9927,11 @@ impl Compiler { ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. 
}) => { self.maybe_add_static_attribute_to_class(value, attr.as_str()); self.compile_expression(value)?; + self.set_source_range(self.update_start_location_to_match_attr( + target.range(), + target.range(), + attr.as_str(), + )); let namei = self.name(attr.as_str()); emit!(self, Instruction::StoreAttr { namei }); } @@ -9603,15 +10000,25 @@ impl Compiler { op: &ast::Operator, value: &ast::Expr, ) -> CompileResult<()> { + let stmt_range = self.current_source_range; + let target_range = target.range(); enum AugAssignKind<'a> { - Name { id: &'a str }, - Subscript { use_slice_opt: bool }, - Attr { idx: bytecode::NameIdx }, + Name { + id: &'a str, + }, + Subscript { + use_slice_opt: bool, + }, + Attr { + idx: bytecode::NameIdx, + attr_range: TextRange, + }, } let kind = match &target { ast::Expr::Name(ast::ExprName { id, .. }) => { let id = id.as_str(); + self.set_source_range(target_range); self.compile_name(id, NameUsage::Load)?; AugAssignKind::Name { id } } @@ -9623,6 +10030,7 @@ impl Compiler { }) => { let use_slice_opt = slice.should_use_slice_optimization(); self.compile_expression(value)?; + self.set_source_range(target_range); if use_slice_opt { let ast::Expr::Slice(slice_expr) = slice.as_ref() else { unreachable!( @@ -9630,12 +10038,14 @@ impl Compiler { ); }; self.compile_slice_two_parts(slice_expr)?; + self.set_source_range(target_range); emit!(self, Instruction::Copy { i: 3 }); emit!(self, Instruction::Copy { i: 3 }); emit!(self, Instruction::Copy { i: 3 }); emit!(self, Instruction::BinarySlice); } else { self.compile_expression(slice)?; + self.set_source_range(target_range); emit!(self, Instruction::Copy { i: 2 }); emit!(self, Instruction::Copy { i: 2 }); emit!( @@ -9650,10 +10060,14 @@ impl Compiler { ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. 
}) => { let attr = attr.as_str(); self.compile_expression(value)?; + let attr_range = + self.update_start_location_to_match_attr(target_range, target_range, attr); + self.set_source_range(attr_range); emit!(self, Instruction::Copy { i: 1 }); let idx = self.name(attr); + self.set_source_range(attr_range); self.emit_load_attr(idx); - AugAssignKind::Attr { idx } + AugAssignKind::Attr { idx, attr_range } } _ => { return Err(self.error(CodegenErrorType::Assign(target.python_name()))); @@ -9661,14 +10075,17 @@ impl Compiler { }; self.compile_expression(value)?; + self.set_source_range(stmt_range); self.compile_op(op, true); match kind { AugAssignKind::Name { id } => { // stack: RESULT + self.set_source_range(target_range); self.compile_name(id, NameUsage::Store)?; } AugAssignKind::Subscript { use_slice_opt } => { + self.set_source_range(target_range); if use_slice_opt { // stack: CONTAINER START STOP RESULT emit!(self, Instruction::Swap { i: 4 }); @@ -9682,8 +10099,9 @@ impl Compiler { emit!(self, Instruction::StoreSubscr); } } - AugAssignKind::Attr { idx } => { + AugAssignKind::Attr { idx, attr_range } => { // stack: CONTAINER RESULT + self.set_source_range(attr_range); emit!(self, Instruction::Swap { i: 2 }); emit!(self, Instruction::StoreAttr { namei: idx }); } @@ -9766,6 +10184,10 @@ impl Compiler { self.compile_jump_if_inner(body, condition, target_block, source_range)?; emit!(self, PseudoInstruction::JumpNoInterrupt { delta: end }); self.set_no_location(); + // CPython emits this jump with NO_LOCATION in codegen_jump_if() + // and flowgraph.c::propagate_line_numbers() copies the previous + // body-expression location onto it. 
+ self.copy_previous_location_to_last_instruction(); self.switch_to_block(next2); self.compile_jump_if_inner(orelse, condition, target_block, source_range)?; @@ -9793,6 +10215,12 @@ impl Compiler { && matches!(&comparators[0], ast::Expr::NoneLiteral(_)) => { self.compile_expression(left)?; + // CPython codegen first emits LOAD_CONST None; IS_OP; POP_JUMP... + // and flowgraph.c::basicblock_optimize_load_const folds it into + // POP_JUMP_IF_NONE / POP_JUMP_IF_NOT_NONE. Register None here + // to preserve CPython's co_consts ordering even though we emit + // the folded jump directly. + self.arg_constant(ConstantData::None); let source = self.source_file.to_source_code(); let comparator_line = source.line_index(comparators[0].range().start()); let left_line = source.line_index(left.range().start()); @@ -9807,6 +10235,7 @@ impl Compiler { // is not None + jump_if_false → POP_JUMP_IF_NONE // is not None + jump_if_true → POP_JUMP_IF_NOT_NONE let jump_if_none = condition != is_not; + self.set_source_range(source_range.unwrap_or_else(|| expression.range())); if jump_if_none { emit!( self, @@ -9831,6 +10260,7 @@ impl Compiler { self.disable_load_fast_borrow_for_block(target_block); } self.compile_expression(expression)?; + self.set_source_range(expression.range()); emit!(self, Instruction::ToBool); if condition { emit!( @@ -9867,39 +10297,48 @@ impl Compiler { /// Compile a boolean operation as an expression. /// This means, that the last value remains on the stack. fn compile_bool_op(&mut self, op: &ast::BoolOp, values: &[ast::Expr]) -> CompileResult<()> { + let boolop_range = self.current_source_range; fn flatten_same_boolop_values<'a>( op: &ast::BoolOp, - value: &'a ast::Expr, - out: &mut Vec<&'a ast::Expr>, + values: &'a [ast::Expr], + current_range: TextRange, + outer_pop_range: Option, + out: &mut Vec<(&'a ast::Expr, Option)>, ) { - if let ast::Expr::BoolOp(ast::ExprBoolOp { - op: inner_op, - values, - .. 
- }) = value - && inner_op == op - { - for value in values { - flatten_same_boolop_values(op, value, out); + for (idx, value) in values.iter().enumerate() { + let is_last = idx + 1 == values.len(); + let pop_range = if is_last { + outer_pop_range + } else { + Some(current_range) + }; + if let ast::Expr::BoolOp(ast::ExprBoolOp { + op: inner_op, + values, + .. + }) = value + && inner_op == op + { + flatten_same_boolop_values(op, values, value.range(), pop_range, out); + } else { + out.push((value, pop_range)); } - } else { - out.push(value); } } let mut flattened = Vec::with_capacity(values.len()); - for value in values { - flatten_same_boolop_values(op, value, &mut flattened); - } + flatten_same_boolop_values(op, values, boolop_range, None, &mut flattened); let after_block = self.new_block(); - let (last_value, prefix_values) = flattened.split_last().unwrap(); + let ((last_value, _), prefix_values) = flattened.split_last().unwrap(); - for value in prefix_values { + for &(value, pop_range) in prefix_values { let continue_block = self.new_block(); self.compile_expression(value)?; + self.set_source_range(boolop_range); self.emit_short_circuit_test(op, after_block); self.switch_to_block(continue_block); + self.set_source_range(pop_range.expect("prefix boolop value must have pop range")); emit!(self, Instruction::PopTop); } @@ -9908,28 +10347,6 @@ impl Compiler { Ok(()) } - fn compile_bool_op_with_head_constant( - &mut self, - op: &ast::BoolOp, - head: ConstantData, - tail: &[ast::Expr], - ) -> CompileResult<()> { - self.emit_load_const(head); - self.mark_last_instruction_folded_from_nonliteral_expr(); - if tail.is_empty() { - return Ok(()); - } - - let after_block = self.new_block(); - for value in tail { - self.emit_short_circuit_test(op, after_block); - emit!(self, Instruction::PopTop); - self.compile_expression(value)?; - } - self.switch_to_block(after_block); - Ok(()) - } - /// Emit `Copy 1` + conditional jump for short-circuit evaluation. 
/// For `And`, emits `PopJumpIfFalse`; for `Or`, emits `PopJumpIfTrue`. fn emit_short_circuit_test(&mut self, op: &ast::BoolOp, target: BlockIdx) { @@ -9945,7 +10362,7 @@ impl Compiler { } } - fn compile_dict(&mut self, items: &[ast::DictItem]) -> CompileResult<()> { + fn compile_dict(&mut self, items: &[ast::DictItem], range: TextRange) -> CompileResult<()> { let has_unpacking = items.iter().any(|item| item.key.is_none()); if !has_unpacking { @@ -9961,6 +10378,7 @@ impl Compiler { self.compile_expression(item.key.as_ref().unwrap())?; self.compile_expression(&item.value)?; } + self.set_source_range(range); emit!( self, Instruction::BuildMap { @@ -9987,11 +10405,13 @@ impl Compiler { (total_map_add, 0usize) }; + self.set_source_range(range); emit!(self, Instruction::BuildMap { count: 0 }); let mut idx = 0; for chunk_i in 0..big_count { if chunk_i > 0 { + self.set_source_range(range); emit!(self, Instruction::BuildMap { count: 0 }); } let chunk_size = if idx + BIG_MAP_CHUNK <= n - tail_count { @@ -10002,9 +10422,11 @@ impl Compiler { for item in &items[idx..idx + chunk_size] { self.compile_expression(item.key.as_ref().unwrap())?; self.compile_expression(&item.value)?; + self.set_source_range(range); emit!(self, Instruction::MapAdd { i: 1 }); } if chunk_i > 0 { + self.set_source_range(range); emit!(self, Instruction::DictUpdate { i: 1 }); } idx += chunk_size; @@ -10016,12 +10438,14 @@ impl Compiler { self.compile_expression(item.key.as_ref().unwrap())?; self.compile_expression(&item.value)?; } + self.set_source_range(range); emit!( self, Instruction::BuildMap { count: tail_count.to_u32(), } ); + self.set_source_range(range); emit!(self, Instruction::DictUpdate { i: 1 }); } } @@ -10040,8 +10464,10 @@ impl Compiler { () => { #[allow(unused_assignments)] if elements > 0 { + self.set_source_range(range); emit!(self, Instruction::BuildMap { count: elements }); if have_dict { + self.set_source_range(range); emit!(self, Instruction::DictUpdate { i: 1 }); } else { have_dict = 
true; @@ -10061,16 +10487,19 @@ impl Compiler { // ** unpacking entry flush_pending!(); if !have_dict { + self.set_source_range(range); emit!(self, Instruction::BuildMap { count: 0 }); have_dict = true; } self.compile_expression(&item.value)?; + self.set_source_range(range); emit!(self, Instruction::DictUpdate { i: 1 }); } } flush_pending!(); if !have_dict { + self.set_source_range(range); emit!(self, Instruction::BuildMap { count: 0 }); } @@ -10162,26 +10591,7 @@ impl Compiler { | ast::Expr::BooleanLiteral(_) | ast::Expr::NoneLiteral(_) | ast::Expr::EllipsisLiteral(_) - ) || matches!(expr, ast::Expr::FString(fstring) if Self::fstring_value_is_const(&fstring.value)) - } - - fn fstring_value_is_const(fstring: &ast::FStringValue) -> bool { - for part in fstring { - if !Self::fstring_part_is_const(part) { - return false; - } - } - true - } - - fn fstring_part_is_const(part: &ast::FStringPart) -> bool { - match part { - ast::FStringPart::Literal(_) => true, - ast::FStringPart::FString(fstring) => fstring - .elements - .iter() - .all(|element| matches!(element, ast::InterpolatedStringElement::Literal(_))), - } + ) } fn compile_expression(&mut self, expression: &ast::Expr) -> CompileResult<()> { @@ -10189,16 +10599,6 @@ impl Compiler { let range = expression.range(); self.set_source_range(range); - if let ast::Expr::Subscript(ast::ExprSubscript { - ctx: ast::ExprContext::Load, - .. - }) = expression - && let Some(constant) = self.try_fold_constant_expr(expression)? - { - self.emit_load_const(constant); - return Ok(()); - } - if matches!(expression, ast::Expr::BinOp(_)) && let Some(constant) = self.try_fold_constant_expr(expression)? { @@ -10211,25 +10611,31 @@ impl Compiler { { let mut simplified_prefix = 0usize; let mut last_constant = None; - let mut retained_head = None; + let mut last_constant_range = None; for value in values { let Some(constant) = self.try_fold_constant_expr(value)? 
else { break; }; if !Self::boolop_fast_fold_literal(value) { - retained_head = Some(constant); - simplified_prefix += 1; break; } + // CPython codegen_boolop() emits each literal with + // ADDOP_LOAD_CONST before flowgraph.c folds the constant + // branch away. Register it here so remove_unused_consts() + // preserves the same first-constant ordering. + self.arg_constant(constant.clone()); let is_truthy = Self::constant_truthiness(&constant); last_constant = Some(constant); + last_constant_range = Some(value.range()); match op { ast::BoolOp::Or if is_truthy => { + self.set_source_range(last_constant_range.expect("missing boolop range")); self.emit_load_const(last_constant.expect("missing boolop constant")); self.mark_last_instruction_folded_from_nonliteral_expr(); return Ok(()); } ast::BoolOp::And if !is_truthy => { + self.set_source_range(last_constant_range.expect("missing boolop range")); self.emit_load_const(last_constant.expect("missing boolop constant")); self.mark_last_instruction_folded_from_nonliteral_expr(); return Ok(()); @@ -10240,11 +10646,8 @@ impl Compiler { } } - if let Some(head) = retained_head { - self.compile_bool_op_with_head_constant(op, head, &values[simplified_prefix..])?; - return Ok(()); - } if simplified_prefix == values.len() { + self.set_source_range(last_constant_range.expect("missing boolop range")); self.emit_load_const(last_constant.expect("missing folded boolop constant")); self.mark_last_instruction_folded_from_nonliteral_expr(); return Ok(()); @@ -10284,10 +10687,25 @@ impl Compiler { self.compile_subscript(value, slice, *ctx)?; } ast::Expr::UnaryOp(ast::ExprUnaryOp { op, operand, .. }) => { - self.compile_expression(operand)?; - - // Restore full expression range before emitting the operation - self.set_source_range(range); + if let ( + ast::UnaryOp::Not, + ast::Expr::Compare(ast::ExprCompare { + left, + ops, + comparators, + .. 
+ }), + ) = (op, operand.as_ref()) + && ops.len() == 1 + { + self.set_source_range(range); + self.compile_compare(left, ops, comparators)?; + } else { + self.compile_expression(operand)?; + } + + // Restore full expression range before emitting the operation + self.set_source_range(range); match op { ast::UnaryOp::UAdd => emit!( self, @@ -10308,11 +10726,14 @@ impl Compiler { if let Some(super_type) = self.can_optimize_super_call(value, attr.as_str()) { // super().attr or super(cls, self).attr optimization // Stack: [global_super, class, self] → LOAD_SUPER_ATTR → [attr] - // Set source range to super() call for arg-loading instructions - let super_range = value.range(); - self.set_source_range(super_range); - self.load_args_for_super(&super_type)?; - self.set_source_range(super_range); + let ast::Expr::Call(ast::ExprCall { + func: super_func, .. + }) = value.as_ref() + else { + unreachable!("can_optimize_super_call only accepts calls"); + }; + self.load_args_for_super(&super_type, super_func.range(), value.range())?; + self.set_source_range(range); let idx = self.name(attr.as_str()); match super_type { SuperCallType::TwoArg { .. } => { @@ -10325,6 +10746,11 @@ impl Compiler { } else { // Normal attribute access self.compile_expression(value)?; + self.set_source_range(self.update_start_location_to_match_attr( + range, + range, + attr.as_str(), + )); let idx = self.name(attr.as_str()); self.emit_load_attr(idx); } @@ -10349,29 +10775,37 @@ impl Compiler { ast::Expr::Set(ast::ExprSet { elts, .. }) => { self.starunpack_helper(elts, 0, CollectionType::Set)?; } - ast::Expr::Dict(ast::ExprDict { items, .. }) => { - self.compile_dict(items)?; + ast::Expr::Dict(ast::ExprDict { items, range, .. }) => { + self.compile_dict(items, *range)?; } ast::Expr::Slice(ast::ExprSlice { - lower, upper, step, .. + lower, + upper, + step, + range, + .. }) => { if let Some(folded_const) = self.try_fold_constant_slice( lower.as_deref(), upper.as_deref(), step.as_deref(), )? 
{ + self.set_source_range(*range); self.emit_load_const(folded_const); return Ok(()); } - let mut compile_bound = |bound: Option<&ast::Expr>| match bound { - Some(exp) => self.compile_expression(exp), - None => { - self.emit_load_const(ConstantData::None); - Ok(()) - } - }; - compile_bound(lower.as_deref())?; - compile_bound(upper.as_deref())?; + if let Some(lower) = lower { + self.compile_expression(lower)?; + } else { + self.set_source_range(*range); + self.emit_load_const(ConstantData::None); + } + if let Some(upper) = upper { + self.compile_expression(upper)?; + } else { + self.set_source_range(*range); + self.emit_load_const(ConstantData::None); + } if let Some(step) = step { self.compile_expression(step)?; } @@ -10379,6 +10813,7 @@ impl Compiler { Some(_) => BuildSliceArgCount::Three, None => BuildSliceArgCount::Two, }; + self.set_source_range(*range); emit!(self, Instruction::BuildSlice { argc }); } ast::Expr::Yield(ast::ExprYield { value, .. }) => { @@ -10390,6 +10825,7 @@ impl Compiler { Some(expression) => self.compile_expression(expression)?, Option::None => self.emit_load_const(ConstantData::None), }; + self.set_source_range(range); if self.ctx.func == FunctionContext::AsyncFunction { emit!( self, @@ -10412,6 +10848,7 @@ impl Compiler { return Err(self.error(CodegenErrorType::InvalidAwait)); } self.compile_expression(value)?; + self.set_source_range(range); emit!(self, Instruction::GetAwaitable { r#where: 0 }); self.emit_load_const(ConstantData::None); let _ = self.compile_yield_from_sequence(true)?; @@ -10428,13 +10865,17 @@ impl Compiler { } self.mark_generator(); self.compile_expression(value)?; + self.set_source_range(range); emit!(self, Instruction::GetYieldFromIter); self.emit_load_const(ConstantData::None); let _ = self.compile_yield_from_sequence(false)?; } ast::Expr::Name(ast::ExprName { id, .. }) => self.load_name(id.as_str())?, ast::Expr::Lambda(ast::ExprLambda { - parameters, body, .. + parameters, + body, + range, + .. 
}) => { let default_params = ast::Parameters::default(); let params = parameters.as_deref().unwrap_or(&default_params); @@ -10456,6 +10897,7 @@ impl Compiler { for element in &defaults { self.compile_expression(element)?; } + self.set_source_range(*range); emit!(self, Instruction::BuildTuple { count: size }); } @@ -10471,11 +10913,13 @@ impl Compiler { if have_kwdefaults { let default_kw_count = kw_with_defaults.len(); for (arg, default) in &kw_with_defaults { + self.set_source_range(*range); self.emit_load_const(ConstantData::Str { value: self.mangle(arg.name.as_str()).into_owned().into(), }); self.compile_expression(default)?; } + self.set_source_range(*range); emit!( self, Instruction::BuildMap { @@ -10504,13 +10948,21 @@ impl Compiler { in_async_scope: false, }; - // Lambda cannot have docstrings, so no None is added to co_consts - self.compile_expression(body)?; + self.set_source_range(body.range()); self.emit_return_value(); + // _PyCodegen_AddReturnAtEnd() appends a no-location + // return-None epilogue even after lambda's explicit + // RETURN_VALUE. It is later removed as unreachable, but + // remove_unused_consts() keeps None when it was the first + // constant in an otherwise constant-free lambda. 
+ if self.current_code_info().metadata.consts.is_empty() { + self.arg_constant(ConstantData::None); + } let code = self.exit_scope(); // Create lambda function with closure + self.set_source_range(*range); self.make_closure(code, func_flags)?; self.ctx = prev_ctx; @@ -10532,6 +10984,7 @@ impl Compiler { generators, &|compiler, collection_add_i| { compiler.compile_comprehension_element(elt)?; + compiler.set_source_range(elt.range()); emit!( compiler, Instruction::ListAppend { @@ -10543,6 +10996,8 @@ impl Compiler { ComprehensionType::List, Self::contains_await(elt) || Self::generators_contain_await(generators), *range, + elt.range(), + elt.range(), )?; } ast::Expr::SetComp(ast::ExprSetComp { @@ -10562,6 +11017,7 @@ impl Compiler { generators, &|compiler, collection_add_i| { compiler.compile_comprehension_element(elt)?; + compiler.set_source_range(elt.range()); emit!( compiler, Instruction::SetAdd { @@ -10573,6 +11029,8 @@ impl Compiler { ComprehensionType::Set, Self::contains_await(elt) || Self::generators_contain_await(generators), *range, + elt.range(), + elt.range(), )?; } ast::Expr::DictComp(ast::ExprDictComp { @@ -10596,6 +11054,10 @@ impl Compiler { compiler.compile_expression(key)?; compiler.compile_expression(value)?; + compiler.set_source_range(TextRange::new( + key.range().start(), + value.range().end(), + )); emit!( compiler, Instruction::MapAdd { @@ -10610,6 +11072,8 @@ impl Compiler { || Self::contains_await(value) || Self::generators_contain_await(generators), *range, + TextRange::new(key.range().start(), value.range().end()), + key.range(), )?; } ast::Expr::Generator(ast::ExprGenerator { @@ -10618,47 +11082,7 @@ impl Compiler { range, .. 
}) => { - // Check if element or generators contain async content - // This makes the generator expression into an async generator - let element_contains_await = - Self::contains_await(elt) || Self::generators_contain_await(generators); - self.compile_comprehension( - "", - None, - generators, - &|compiler, _collection_add_i| { - // Compile the element expression - // Note: if element is an async comprehension, compile_expression - // already handles awaiting it, so we don't need to await again here - compiler.compile_comprehension_element(elt)?; - - compiler.mark_generator(); - if compiler.ctx.func == FunctionContext::AsyncFunction { - emit!( - compiler, - Instruction::CallIntrinsic1 { - func: bytecode::IntrinsicFunction1::AsyncGenWrap - } - ); - } - // arg=0: direct yield (wrapped for async generators) - emit!(compiler, Instruction::YieldValue { arg: 0 }); - emit!( - compiler, - Instruction::Resume { - context: oparg::ResumeContext::from( - oparg::ResumeLocation::AfterYield - ) - } - ); - emit!(compiler, Instruction::PopTop); - - Ok(()) - }, - ComprehensionType::Generator, - element_contains_await, - *range, - )?; + self.compile_generator_expression(elt, generators, *range)?; } ast::Expr::Starred(ast::ExprStarred { value, .. 
}) => { if self.in_annotation { @@ -10682,6 +11106,9 @@ impl Compiler { .map(Self::constant_truthiness); let else_block = self.new_block(); let after_block = self.new_block(); + if self.current_code_info().in_conditional_block > 0 { + self.mark_conditional_ifexp_orelse_entry_block(else_block); + } self.compile_jump_if(test, false, else_block)?; // True case @@ -10721,7 +11148,7 @@ impl Compiler { target, value, node_index: _, - range: _, + range, }) => { // Walrus targets in inlined comps should NOT be hidden from locals() if self.current_code_info().in_inlined_comp @@ -10733,8 +11160,10 @@ impl Compiler { info.metadata.fast_hidden_final.swap_remove(name.as_ref()); } self.compile_expression(value)?; + self.set_source_range(*range); emit!(self, Instruction::Copy { i: 1 }); self.compile_store(target)?; + self.set_source_range(target.range()); } ast::Expr::FString(fstring) => { self.compile_expr_fstring(fstring)?; @@ -10812,6 +11241,7 @@ impl Compiler { &mut self, kind: BuiltinGeneratorCallKind, generator_expr: &ast::Expr, + loc: TextRange, end: BlockIdx, ) -> CompileResult<()> { let common_constant = match kind { @@ -10825,6 +11255,7 @@ impl Compiler { let cleanup = self.new_block(); // Stack: [func] — copy function for identity check + self.set_source_range(loc); emit!(self, Instruction::Copy { i: 1 }); emit!( self, @@ -10837,45 +11268,64 @@ impl Compiler { emit!(self, Instruction::PopTop); if matches!(kind, BuiltinGeneratorCallKind::Tuple) { + self.set_source_range(loc); emit!(self, Instruction::BuildList { count: 0 }); } let sub_table_cursor = self.symbol_table_stack.last().map(|t| t.next_sub_table); - self.compile_expression(generator_expr)?; + if let Some(range) = self.cpython_implicit_call_generator_range(generator_expr) { + self.compile_expression_with_generator_range(generator_expr, range)?; + } else { + self.compile_expression(generator_expr)?; + } if let Some(cursor) = sub_table_cursor && let Some(current_table) = self.symbol_table_stack.last_mut() { 
current_table.next_sub_table = cursor; } self.switch_to_block(loop_block); + self.set_source_range(loc); emit!(self, Instruction::ForIter { delta: cleanup }); match kind { BuiltinGeneratorCallKind::Tuple => { + self.set_source_range(loc); emit!(self, Instruction::ListAppend { i: 2 }); + self.set_source_range(loc); emit!(self, PseudoInstruction::Jump { delta: loop_block }); } BuiltinGeneratorCallKind::All => { + self.set_source_range(loc); emit!(self, Instruction::ToBool); emit!(self, Instruction::PopJumpIfTrue { delta: loop_block }); + self.set_source_range(loc); emit!(self, Instruction::PopIter); + self.set_source_range(loc); self.emit_load_const(ConstantData::Boolean { value: false }); + self.set_source_range(loc); emit!(self, PseudoInstruction::Jump { delta: end }); } BuiltinGeneratorCallKind::Any => { + self.set_source_range(loc); emit!(self, Instruction::ToBool); emit!(self, Instruction::PopJumpIfFalse { delta: loop_block }); + self.set_source_range(loc); emit!(self, Instruction::PopIter); + self.set_source_range(loc); self.emit_load_const(ConstantData::Boolean { value: true }); + self.set_source_range(loc); emit!(self, PseudoInstruction::Jump { delta: end }); } } self.switch_to_block(cleanup); + self.set_source_range(loc); emit!(self, Instruction::EndFor); + self.set_source_range(loc); emit!(self, Instruction::PopIter); match kind { BuiltinGeneratorCallKind::Tuple => { + self.set_source_range(loc); emit!( self, Instruction::CallIntrinsic1 { @@ -10884,12 +11334,15 @@ impl Compiler { ); } BuiltinGeneratorCallKind::All => { + self.set_source_range(loc); self.emit_load_const(ConstantData::Boolean { value: true }); } BuiltinGeneratorCallKind::Any => { + self.set_source_range(loc); self.emit_load_const(ConstantData::Boolean { value: false }); } } + self.set_source_range(loc); emit!(self, PseudoInstruction::Jump { delta: end }); self.switch_to_block(fallback); @@ -10910,13 +11363,27 @@ impl Compiler { // super().method() or super(cls, self).method() optimization // 
CALL path: [global_super, class, self] → LOAD_SUPER_METHOD → [method, self] // CALL_FUNCTION_EX path: [global_super, class, self] → LOAD_SUPER_ATTR → [attr] - // Set source range to the super() call for LOAD_GLOBAL/LOAD_DEREF/etc. - let super_range = value.range(); - self.set_source_range(super_range); - self.load_args_for_super(&super_type)?; - self.set_source_range(super_range); + let ast::Expr::Call(ast::ExprCall { + func: super_func, .. + }) = value.as_ref() + else { + unreachable!("can_optimize_super_call only accepts calls"); + }; + self.load_args_for_super(&super_type, super_func.range(), value.range())?; + let attr_access_range = self.update_start_location_to_match_attr( + func.range(), + func.range(), + attr.as_str(), + ); + let method_call_range = self.update_start_location_to_match_attr( + call_range, + func.range(), + attr.as_str(), + ); + self.set_source_range(attr_access_range); let idx = self.name(attr.as_str()); if uses_ex_call { + self.set_source_range(func.range()); match super_type { SuperCallType::TwoArg { .. } => { self.emit_load_super_attr(idx); @@ -10928,11 +11395,11 @@ impl Compiler { // CPython's Attribute_kind super path emits an attr-line // NOP after LOAD_SUPER_ATTR, even when the call later uses // CALL_FUNCTION_EX for starred arguments. - self.set_source_range(attr.range()); + self.set_source_range(attr_access_range); emit!(self, Instruction::Nop); - self.set_source_range(super_range); + self.set_source_range(func.range()); emit!(self, Instruction::PushNull); - self.codegen_call_helper(0, args, call_range)?; + self.codegen_call_helper(0, args, call_range, None)?; } else { match super_type { SuperCallType::TwoArg { .. 
} => { @@ -10943,14 +11410,25 @@ impl Compiler { } } // NOP for line tracking at .method( line - self.set_source_range(attr.range()); + self.set_source_range(attr_access_range); emit!(self, Instruction::Nop); // CALL at .method( line (not the full expression line) - self.codegen_call_helper(0, args, attr.range())?; + self.codegen_call_helper(0, args, method_call_range, Some(attr_access_range))?; } } else { self.compile_expression(value)?; let idx = self.name(attr.as_str()); + let attr_access_range = self.update_start_location_to_match_attr( + func.range(), + func.range(), + attr.as_str(), + ); + let method_call_range = self.update_start_location_to_match_attr( + call_range, + func.range(), + attr.as_str(), + ); + self.set_source_range(attr_access_range); // Imported names and CALL_FUNCTION_EX-style calls use plain // LOAD_ATTR + PUSH_NULL; other names use method-call mode. // Check current scope and enclosing scopes for IMPORTED flag. @@ -10962,7 +11440,11 @@ impl Compiler { } else { self.emit_load_attr_method(idx); } - self.codegen_call_helper(0, args, call_range)?; + if is_import || uses_ex_call { + self.codegen_call_helper(0, args, call_range, None)?; + } else { + self.codegen_call_helper(0, args, method_call_range, Some(attr_access_range))?; + } } } else if let Some(kind) = (!uses_ex_call) .then(|| self.detect_builtin_generator_call(func, args)) @@ -10970,17 +11452,17 @@ impl Compiler { { let end = self.new_block(); self.compile_expression(func)?; - self.optimize_builtin_generator_call(kind, &args.args[0], end)?; - self.set_source_range(call_range); + self.optimize_builtin_generator_call(kind, &args.args[0], func.range(), end)?; + self.set_source_range(func.range()); emit!(self, Instruction::PushNull); - self.codegen_call_helper(0, args, call_range)?; + self.codegen_call_helper(0, args, call_range, None)?; self.switch_to_block(end); } else { // Regular call: push func, then NULL for self_or_null slot // Stack layout: [func, NULL, args...] 
- same as method call [func, self, args...] self.compile_expression(func)?; emit!(self, Instruction::PushNull); - self.codegen_call_helper(0, args, call_range)?; + self.codegen_call_helper(0, args, call_range, None)?; } Ok(()) } @@ -11002,6 +11484,7 @@ impl Compiler { keywords: &[ast::Keyword], begin: usize, end: usize, + call_range: TextRange, ) -> CompileResult<()> { let n = end - begin; assert!(n > 0); @@ -11010,22 +11493,26 @@ impl Compiler { let big = n * 2 > STACK_USE_GUIDELINE as usize; if big { + self.set_source_range(call_range); emit!(self, Instruction::BuildMap { count: 0 }); } for kw in &keywords[begin..end] { // Key first, then value - this is critical! + self.set_source_range(call_range); self.emit_load_const(ConstantData::Str { value: kw.arg.as_ref().unwrap().as_str().into(), }); self.compile_expression(&kw.value)?; if big { + self.set_source_range(call_range); emit!(self, Instruction::MapAdd { i: 1 }); } } if !big { + self.set_source_range(call_range); emit!(self, Instruction::BuildMap { count: n.to_u32() }); } @@ -11040,6 +11527,7 @@ impl Compiler { additional_positional: u32, arguments: &ast::Arguments, call_range: TextRange, + kw_names_range: Option, ) -> CompileResult<()> { let nelts = arguments.args.len(); let nkwelts = arguments.keywords.len(); @@ -11058,8 +11546,18 @@ impl Compiler { if !has_starred && !has_double_star && !too_big { // Simple call path: no * or ** args + let implicit_generator_range = + if additional_positional == 0 && nelts == 1 && nkwelts == 0 { + self.cpython_implicit_call_generator_range(&arguments.args[0]) + } else { + None + }; for arg in &arguments.args { - self.compile_expression(arg)?; + if let Some(range) = implicit_generator_range { + self.compile_expression_with_generator_range(arg, range)?; + } else { + self.compile_expression(arg)?; + } } if nkwelts > 0 { @@ -11073,11 +11571,12 @@ impl Compiler { } // Restore call expression range for kwnames and CALL_KW - self.set_source_range(call_range); + 
self.set_source_range(kw_names_range.unwrap_or(call_range)); self.emit_load_const(ConstantData::Tuple { elements: kwarg_names, }); + self.set_source_range(call_range); let argc = additional_positional + nelts.to_u32() + nkwelts.to_u32(); emit!(self, Instruction::CallKw { argc }); } else { @@ -11104,23 +11603,21 @@ impl Compiler { } self.set_source_range(call_range); let positional_count = additional_positional + nelts.to_u32(); - if positional_count == 0 { - self.emit_load_const(ConstantData::Tuple { elements: vec![] }); - } else { - emit!( - self, - Instruction::BuildTuple { - count: positional_count - } - ); - } + emit!( + self, + Instruction::BuildTuple { + count: positional_count + } + ); } else { // Use starunpack_helper to build a list, then convert to tuple + self.set_source_range(call_range); self.starunpack_helper( &arguments.args, additional_positional, CollectionType::List, )?; + self.set_source_range(call_range); emit!( self, Instruction::CallIntrinsic1 { @@ -11129,7 +11626,7 @@ impl Compiler { ); } - self.compile_call_function_ex_keywords(&arguments.keywords)?; + self.compile_call_function_ex_keywords(&arguments.keywords, call_range)?; self.set_source_range(call_range); emit!(self, Instruction::CallFunctionEx); @@ -11141,8 +11638,10 @@ impl Compiler { fn compile_call_function_ex_keywords( &mut self, keywords: &[ast::Keyword], + call_range: TextRange, ) -> CompileResult<()> { if keywords.is_empty() { + self.set_source_range(call_range); emit!(self, Instruction::PushNull); return Ok(()); } @@ -11153,8 +11652,9 @@ impl Compiler { for (i, keyword) in keywords.iter().enumerate() { if keyword.arg.is_none() { if nseen > 0 { - self.codegen_subkwargs(keywords, i - nseen, i)?; + self.codegen_subkwargs(keywords, i - nseen, i, call_range)?; if have_dict { + self.set_source_range(call_range); emit!(self, Instruction::DictMerge { i: 1 }); } have_dict = true; @@ -11162,11 +11662,13 @@ impl Compiler { } if !have_dict { + self.set_source_range(call_range); emit!(self, 
Instruction::BuildMap { count: 0 }); have_dict = true; } self.compile_expression_without_const_boolop_folding(&keyword.value)?; + self.set_source_range(call_range); emit!(self, Instruction::DictMerge { i: 1 }); } else { nseen += 1; @@ -11174,8 +11676,9 @@ impl Compiler { } if nseen > 0 { - self.codegen_subkwargs(keywords, keywords.len() - nseen, keywords.len())?; + self.codegen_subkwargs(keywords, keywords.len() - nseen, keywords.len(), call_range)?; if have_dict { + self.set_source_range(call_range); emit!(self, Instruction::DictMerge { i: 1 }); } have_dict = true; @@ -11197,6 +11700,173 @@ impl Compiler { }) } + fn compile_expression_with_generator_range( + &mut self, + expression: &ast::Expr, + range: TextRange, + ) -> CompileResult<()> { + if let ast::Expr::Generator(ast::ExprGenerator { + elt, generators, .. + }) = expression + { + self.set_source_range(range); + self.compile_generator_expression(elt, generators, range) + } else { + self.compile_expression(expression) + } + } + + fn cpython_implicit_call_generator_range(&self, expression: &ast::Expr) -> Option { + if !matches!(expression, ast::Expr::Generator(_)) { + return None; + } + let range = expression.range(); + let source = self.source_file.source_text().as_bytes(); + let start = range.start().to_usize(); + let end = range.end().to_usize(); + if source.get(start) == Some(&b'(') + && !Self::starts_with_parenthesized_generator_element(source, start, end) + { + return None; + } + + let mut open = start; + while open > 0 && source[open - 1].is_ascii_whitespace() { + open -= 1; + } + if open == 0 || source[open - 1] != b'(' { + return None; + } + + let mut close = end; + while close < source.len() && source[close].is_ascii_whitespace() { + close += 1; + } + if source.get(close) != Some(&b')') { + return None; + } + + let adjusted_start = u32::try_from(open - 1).ok()?; + let adjusted_end = u32::try_from(close + 1).ok()?; + Some(TextRange::new( + TextSize::from(adjusted_start), + TextSize::from(adjusted_end), 
+ )) + } + + fn starts_with_parenthesized_generator_element( + source: &[u8], + start: usize, + end: usize, + ) -> bool { + let mut depth = 0usize; + let mut i = start; + while i < end { + match source[i] { + b'(' | b'[' | b'{' => depth += 1, + b')' | b']' | b'}' => { + if depth == 0 { + return false; + } + depth -= 1; + if depth == 0 { + return Self::next_token_is_for(source, i + 1, end); + } + } + b'\'' | b'"' => i = Self::skip_python_string_literal(source, i), + _ => {} + } + i += 1; + } + false + } + + fn skip_python_string_literal(source: &[u8], quote: usize) -> usize { + let quote_byte = source[quote]; + let triple = source.get(quote + 1) == Some("e_byte) + && source.get(quote + 2) == Some("e_byte); + let mut i = quote + if triple { 3 } else { 1 }; + while i < source.len() { + if source[i] == b'\\' { + i += 2; + continue; + } + if triple { + if source[i] == quote_byte + && source.get(i + 1) == Some("e_byte) + && source.get(i + 2) == Some("e_byte) + { + return i + 2; + } + } else if source[i] == quote_byte { + return i; + } + i += 1; + } + source.len().saturating_sub(1) + } + + fn next_token_is_for(source: &[u8], mut i: usize, end: usize) -> bool { + while i < end && source[i].is_ascii_whitespace() { + i += 1; + } + source.get(i..i + 3) == Some(b"for") + && source + .get(i + 3) + .is_none_or(|byte| !byte.is_ascii_alphanumeric() && *byte != b'_') + } + + fn compile_generator_expression( + &mut self, + elt: &ast::Expr, + generators: &[ast::Comprehension], + range: TextRange, + ) -> CompileResult<()> { + // Check if element or generators contain async content + // This makes the generator expression into an async generator + let element_contains_await = + Self::contains_await(elt) || Self::generators_contain_await(generators); + self.compile_comprehension( + "", + None, + generators, + &|compiler, _collection_add_i| { + // Compile the element expression + // Note: if element is an async comprehension, compile_expression + // already handles awaiting it, so we 
don't need to await again here + compiler.compile_comprehension_element(elt)?; + + compiler.mark_generator(); + if compiler.ctx.func == FunctionContext::AsyncFunction { + compiler.set_source_range(elt.range()); + emit!( + compiler, + Instruction::CallIntrinsic1 { + func: bytecode::IntrinsicFunction1::AsyncGenWrap + } + ); + } + // arg=0: direct yield (wrapped for async generators) + compiler.set_source_range(elt.range()); + emit!(compiler, Instruction::YieldValue { arg: 0 }); + emit!( + compiler, + Instruction::Resume { + context: oparg::ResumeContext::from(oparg::ResumeLocation::AfterYield) + } + ); + emit!(compiler, Instruction::PopTop); + + Ok(()) + }, + ComprehensionType::Generator, + element_contains_await, + range, + elt.range(), + elt.range(), + ) + } + fn consume_next_sub_table(&mut self) -> CompileResult<()> { { let _ = self.push_symbol_table()?; @@ -11328,7 +11998,11 @@ impl Compiler { self.enter_scope(obj_name, scope_type, key, lineno.to_u32())?; if let Some(info) = self.code_stack.last_mut() { - info.flags = flags | (info.flags & bytecode::CodeFlags::NESTED); + info.flags = flags + | (info.flags + & (bytecode::CodeFlags::NESTED + | bytecode::CodeFlags::METHOD + | bytecode::CodeFlags::FUTURE_ANNOTATIONS)); info.metadata.argcount = arg_count; info.metadata.posonlyargcount = posonlyarg_count; info.metadata.kwonlyargcount = kwonlyarg_count; @@ -11346,6 +12020,8 @@ impl Compiler { comprehension_type: ComprehensionType, element_contains_await: bool, comprehension_range: TextRange, + element_range: TextRange, + outer_backedge_range: TextRange, ) -> CompileResult<()> { let prev_ctx = self.ctx; let has_an_async_gen = generators.iter().any(|g| g.is_async); @@ -11393,8 +12069,7 @@ impl Compiler { init_collection, generators, compile_element, - has_an_async_gen, - comprehension_range, + (comprehension_range, element_range, outer_backedge_range), ); } @@ -11424,10 +12099,11 @@ impl Compiler { // scope itself. 
Peek past those nested scopes so we can enter the // correct comprehension table here, then let the real outermost // iterator compile consume its nested scopes later in parent scope. - self.push_output_with_symbol_table(comp_table, flags, 1, 1, 0, name)?; + self.push_output_with_symbol_table(comp_table, flags, 0, 1, 0, name)?; // Set qualname for comprehension self.set_qualname(); + self.set_source_range(comprehension_range); let arg0 = self.varname(".0"); @@ -11444,6 +12120,11 @@ impl Compiler { } ); self.set_no_location(); + // CPython's codegen_wrap_in_stopiteration_handler() inserts + // SETUP_CLEANUP at instruction-sequence index 0, so after the + // generator prefix is inserted the protected range begins at the + // comprehension-start RESUME. + self.move_last_instruction_before_scope_start_resume(); self.push_fblock(FBlockType::StopIteration, handler_block, handler_block)?; Some(handler_block) } else { @@ -11469,7 +12150,13 @@ impl Compiler { if !generator.ifs.is_empty() { let if_cleanup_block = self.new_block(); for if_condition in &generator.ifs { + let snapshot = self.instruction_count_snapshot(); self.compile_jump_if(if_condition, false, if_cleanup_block)?; + self.mark_new_conditional_jump_locations_since( + &snapshot, + if_cleanup_block, + element_range, + ); } let body_block = self.new_block(); self.switch_to_block(body_block); @@ -11487,14 +12174,7 @@ impl Compiler { emit!(self, Instruction::LoadFast { var_num: arg0 }); } else { // Evaluate iterated item: - self.compile_for_iterable_expression(&generator.iter, generator.is_async)?; - - // Get iterator / turn item into an iterator - if generator.is_async { - emit!(self, Instruction::GetAiter); - } else { - emit!(self, Instruction::GetIter); - } + self.compile_comprehension_iter(generator)?; } self.switch_to_block(loop_block); @@ -11515,14 +12195,24 @@ impl Compiler { self.pop_fblock(FBlockType::AsyncComprehensionGenerator); self.compile_store(&generator.target)?; } else { + let saved_range = 
self.current_source_range; + self.set_source_range(generator.iter.range()); emit!(self, Instruction::ForIter { delta: after_block }); + self.set_source_range(saved_range); self.compile_store(&generator.target)?; } real_loop_depth += 1; + let backedge_range = if gen_index + 1 == generators.len() { + element_range + } else { + outer_backedge_range + }; loop_labels.push(ComprehensionLoopControl::Iteration { loop_block, if_cleanup_block, after_block, + iter_range: generator.iter.range(), + backedge_range, is_async: generator.is_async, end_async_for_target, }); @@ -11530,7 +12220,13 @@ impl Compiler { // CPython always lowers comprehension guards through codegen_jump_if // and leaves constant-folding to later CFG optimization passes. for if_condition in &generator.ifs { + let snapshot = self.instruction_count_snapshot(); self.compile_jump_if(if_condition, false, if_cleanup_block)?; + self.mark_new_conditional_jump_locations_since( + &snapshot, + if_cleanup_block, + element_range, + ); } if !generator.ifs.is_empty() { let body_block = self.new_block(); @@ -11546,20 +12242,26 @@ impl Compiler { loop_block, if_cleanup_block, after_block, + iter_range, + backedge_range, is_async, end_async_for_target, } => { + self.set_source_range(backedge_range); emit!(self, PseudoInstruction::Jump { delta: loop_block }); self.switch_to_block(if_cleanup_block); + self.set_source_range(backedge_range); emit!(self, PseudoInstruction::Jump { delta: loop_block }); self.switch_to_block(after_block); if is_async { + self.set_source_range(comprehension_range); // EndAsyncFor pops both the exception and the aiter // (handler depth is before GetANext, so aiter is at handler depth) self.emit_end_async_for(end_async_for_target); } else { + self.set_source_range(iter_range); // END_FOR + POP_ITER pattern (CPython 3.14) emit!(self, Instruction::EndFor); emit!(self, Instruction::PopIter); @@ -11599,24 +12301,18 @@ impl Compiler { self.ctx = prev_ctx; // Create comprehension function with closure + 
self.set_source_range(comprehension_range); self.make_closure(code, bytecode::MakeFunctionFlags::new())?; - // Evaluate iterated item: - self.compile_for_iterable_expression(&outermost.iter, outermost.is_async)?; + // Evaluate iterated item and get its iterator. + self.compile_comprehension_iter(outermost)?; self.symbol_table_stack .last_mut() .expect("no current symbol table") .next_sub_table += 1; - // Get iterator / turn item into an iterator - // Use is_async from the first generator, not has_an_async_gen which covers ALL generators - if outermost.is_async { - emit!(self, Instruction::GetAiter); - } else { - emit!(self, Instruction::GetIter); - }; - // Call just created function: + self.set_source_range(comprehension_range); emit!(self, Instruction::Call { argc: 0 }); if is_async_list_set_dict_comprehension { emit!(self, Instruction::GetAwaitable { r#where: 0 }); @@ -11635,9 +12331,9 @@ impl Compiler { init_collection: Option, generators: &[ast::Comprehension], compile_element: &dyn Fn(&mut Self, usize) -> CompileResult<()>, - has_async: bool, - comprehension_range: TextRange, + ranges: (TextRange, TextRange, TextRange), ) -> CompileResult<()> { + let (comprehension_range, element_range, outer_backedge_range) = ranges; fn collect_bound_names(target: &ast::Expr, out: &mut Vec) { match target { ast::Expr::Name(ast::ExprName { id, .. }) => out.push(id.to_string()), @@ -11658,10 +12354,7 @@ impl Compiler { // nested scopes (e.g. lambdas) whose sub_tables sit at the current // position in the parent's list. Those must be consumed before we // splice in the comprehension's own children. 
- self.compile_for_iterable_expression( - &generators[0].iter, - has_async && generators[0].is_async, - )?; + self.compile_comprehension_iter(&generators[0])?; self.symbol_table_stack .last_mut() .expect("no current symbol table") @@ -11691,12 +12384,6 @@ impl Compiler { current_table.sub_tables.insert(insert_pos + i, st.clone()); } } - if has_async && generators[0].is_async { - emit!(self, Instruction::GetAiter); - } else { - emit!(self, Instruction::GetIter); - } - let mut source_order_bound_names = Vec::new(); for generator in generators { collect_bound_names(&generator.target, &mut source_order_bound_names); @@ -11810,18 +12497,12 @@ impl Compiler { ); } - // Step 4: Create the collection (list/set/dict) - if let Some(init_collection) = init_collection { - self._emit(init_collection, OpArg::new(0), BlockIdx::NULL); - // SWAP to get iterator on top - emit!(self, Instruction::Swap { i: 2 }); - } - - // Set up exception handler for cleanup on exception - let cleanup_block = self.new_block(); - let end_block = self.new_block(); - - if !pushed_locals.is_empty() { + // CPython's codegen_push_inlined_comprehension_locals() + // installs the virtual cleanup before codegen_comprehension() + // emits BUILD_LIST/BUILD_SET/BUILD_MAP for the result object. 
+ let cleanup_blocks = if !pushed_locals.is_empty() { + let cleanup_block = self.new_block(); + let end_block = self.new_block(); emit!( self, PseudoInstruction::SetupFinally { @@ -11829,6 +12510,16 @@ impl Compiler { } ); self.push_fblock(FBlockType::TryExcept, cleanup_block, end_block)?; + Some((cleanup_block, end_block)) + } else { + None + }; + + // Step 4: Create the collection (list/set/dict) + if let Some(init_collection) = init_collection { + self._emit(init_collection, OpArg::new(0), BlockIdx::NULL); + // SWAP to get iterator on top + emit!(self, Instruction::Swap { i: 2 }); } // Step 5: Compile the comprehension loop(s) @@ -11846,7 +12537,13 @@ impl Compiler { if !generator.ifs.is_empty() { let if_cleanup_block = self.new_block(); for if_condition in &generator.ifs { + let snapshot = self.instruction_count_snapshot(); self.compile_jump_if(if_condition, false, if_cleanup_block)?; + self.mark_new_conditional_jump_locations_since( + &snapshot, + if_cleanup_block, + element_range, + ); } let body_block = self.new_block(); self.switch_to_block(body_block); @@ -11861,12 +12558,7 @@ impl Compiler { let after_block = self.new_block(); if i > 0 { - self.compile_for_iterable_expression(&generator.iter, generator.is_async)?; - if generator.is_async { - emit!(self, Instruction::GetAiter); - } else { - emit!(self, Instruction::GetIter); - } + self.compile_comprehension_iter(generator)?; } self.switch_to_block(loop_block); @@ -11894,10 +12586,17 @@ impl Compiler { } real_loop_depth += 1; + let backedge_range = if i + 1 == generators.len() { + element_range + } else { + outer_backedge_range + }; loop_labels.push(ComprehensionLoopControl::Iteration { loop_block, if_cleanup_block, after_block, + iter_range: generator.iter.range(), + backedge_range, is_async: generator.is_async, end_async_for_target, }); @@ -11905,7 +12604,13 @@ impl Compiler { // CPython always lowers comprehension guards through codegen_jump_if // and leaves constant-folding to later CFG optimization 
passes. for if_condition in &generator.ifs { + let snapshot = self.instruction_count_snapshot(); self.compile_jump_if(if_condition, false, if_cleanup_block)?; + self.mark_new_conditional_jump_locations_since( + &snapshot, + if_cleanup_block, + element_range, + ); } } @@ -11919,18 +12624,24 @@ impl Compiler { loop_block, if_cleanup_block, after_block, + iter_range, + backedge_range, is_async, end_async_for_target, } => { + self.set_source_range(backedge_range); emit!(self, PseudoInstruction::Jump { delta: loop_block }); self.switch_to_block(if_cleanup_block); + self.set_source_range(backedge_range); emit!(self, PseudoInstruction::Jump { delta: loop_block }); self.switch_to_block(after_block); if is_async { + self.set_source_range(comprehension_range); self.emit_end_async_for(end_async_for_target); } else { + self.set_source_range(iter_range); emit!(self, Instruction::EndFor); emit!(self, Instruction::PopIter); } @@ -11943,8 +12654,9 @@ impl Compiler { // Step 8: Clean up - restore saved locals (and cell values) self.set_source_range(comprehension_range); - if total_stack_items > 0 { + if let Some((cleanup_block, end_block)) = cleanup_blocks { emit!(self, PseudoInstruction::PopBlock); + self.set_no_location(); self.pop_fblock(FBlockType::TryExcept); // Match CPython codegen_pop_inlined_comprehension_locals(): @@ -11955,14 +12667,18 @@ impl Compiler { self, PseudoInstruction::JumpNoInterrupt { delta: end_block } ); + self.set_no_location(); // Exception cleanup path self.switch_to_block(cleanup_block); // Stack: [saved_values..., collection, exception] emit!(self, Instruction::Swap { i: 2 }); + self.set_no_location(); emit!(self, Instruction::PopTop); // Pop incomplete collection + self.set_no_location(); // Restore locals and cell values + self.set_source_range(comprehension_range); emit!( self, Instruction::Swap { @@ -11975,9 +12691,11 @@ impl Compiler { } // Re-raise the exception emit!(self, Instruction::Reraise { depth: 0 }); + self.set_no_location(); // Normal 
end path self.switch_to_block(end_block); + self.set_source_range(comprehension_range); } // SWAP result to TOS (above saved values) @@ -12064,6 +12782,8 @@ impl Compiler { match_success_jump: false, break_continue_cleanup_jump: false, for_loop_break_cleanup_jump: false, + preserve_tobool_jump_location: false, + preserve_store_fast_store_fast_jump_location: false, }); } @@ -12091,6 +12811,21 @@ impl Compiler { } } + fn copy_previous_location_to_last_instruction(&mut self) { + let instructions = &mut self.current_block().instructions; + let Some(last_idx) = instructions.len().checked_sub(1) else { + return; + }; + let Some(previous_idx) = last_idx.checked_sub(1) else { + return; + }; + let previous = instructions[previous_idx]; + let last = &mut instructions[last_idx]; + last.location = previous.location; + last.end_location = previous.end_location; + last.lineno_override = previous.lineno_override; + } + fn force_remove_last_no_location_nop(&mut self) { if let Some(info) = self.current_block().instructions.last_mut() { info.remove_no_location_nop = true; @@ -12106,12 +12841,49 @@ impl Compiler { } } + fn move_last_instruction_before_scope_start_resume(&mut self) { + let instructions = &mut self.current_block().instructions; + let Some(last_idx) = instructions.len().checked_sub(1) else { + return; + }; + let Some(resume_idx) = + instructions[..last_idx] + .iter() + .rposition(|info| match info.instr.real() { + Some(Instruction::Resume { context }) => { + matches!( + context.get(info.arg).location(), + oparg::ResumeLocation::AtFuncStart + ) + } + _ => false, + }) + else { + return; + }; + + let instruction = instructions.remove(last_idx); + instructions.insert(resume_idx, instruction); + } + fn mark_last_no_location_exit(&mut self) { if let Some(last) = self.current_block().instructions.last_mut() { last.no_location_exit = true; } } + fn mark_last_line_only_location(&mut self, lineno: u32) { + if let Some(last) = self.current_block().instructions.last_mut() { + let 
location = SourceLocation { + line: OneIndexed::new(lineno as usize).unwrap_or(OneIndexed::MIN), + character_offset: OneIndexed::MIN, + }; + last.location = location; + last.end_location = location; + last.lineno_override = Some(ir::LINE_ONLY_LOCATION_OVERRIDE); + } + } + fn mark_last_break_continue_cleanup_jump(&mut self) { if let Some(last) = self.current_block().instructions.last_mut() { last.break_continue_cleanup_jump = true; @@ -12307,89 +13079,22 @@ impl Compiler { ) -> Option { let (left_int, left_is_bool) = Self::constant_as_fold_int(left)?; let (right_int, right_is_bool) = Self::constant_as_fold_int(right)?; - let zero = BigInt::from(0); - if !left_is_bool && !right_is_bool { + if !(left_is_bool && right_is_bool) { return None; } match op { - ast::Operator::Add => Some(ConstantData::Integer { - value: left_int + right_int, + ast::Operator::BitAnd => Some(ConstantData::Boolean { + value: !left_int.is_zero() & !right_int.is_zero(), }), - ast::Operator::Sub => Some(ConstantData::Integer { - value: left_int - right_int, + ast::Operator::BitOr => Some(ConstantData::Boolean { + value: !left_int.is_zero() | !right_int.is_zero(), }), - ast::Operator::Mult => Some(ConstantData::Integer { - value: left_int * right_int, + ast::Operator::BitXor => Some(ConstantData::Boolean { + value: !left_int.is_zero() ^ !right_int.is_zero(), }), - ast::Operator::Div => { - if right_int.is_zero() { - return None; - } - Some(ConstantData::Float { - value: left_int.to_f64()? 
/ right_int.to_f64()?, - }) - } - ast::Operator::FloorDiv => { - if right_int.is_zero() || left_int < zero || right_int < zero { - return None; - } - Some(ConstantData::Integer { - value: left_int / right_int, - }) - } - ast::Operator::Mod => { - if right_int.is_zero() || left_int < zero || right_int < zero { - return None; - } - Some(ConstantData::Integer { - value: left_int % right_int, - }) - } - ast::Operator::Pow => { - let exponent = right_int.to_u32()?; - if exponent > 128 { - return None; - } - Some(ConstantData::Integer { - value: left_int.pow(exponent), - }) - } - ast::Operator::BitAnd => { - if left_is_bool && right_is_bool { - Some(ConstantData::Boolean { - value: !left_int.is_zero() & !right_int.is_zero(), - }) - } else { - Some(ConstantData::Integer { - value: left_int & right_int, - }) - } - } - ast::Operator::BitOr => { - if left_is_bool && right_is_bool { - Some(ConstantData::Boolean { - value: !left_int.is_zero() | !right_int.is_zero(), - }) - } else { - Some(ConstantData::Integer { - value: left_int | right_int, - }) - } - } - ast::Operator::BitXor => { - if left_is_bool && right_is_bool { - Some(ConstantData::Boolean { - value: !left_int.is_zero() ^ !right_int.is_zero(), - }) - } else { - Some(ConstantData::Integer { - value: left_int ^ right_int, - }) - } - } - ast::Operator::MatMult | ast::Operator::LShift | ast::Operator::RShift => None, + _ => None, } } @@ -12578,6 +13283,152 @@ impl Compiler { })) } + fn try_compile_ast_constant( + &mut self, + expr: &ast::Expr, + ) -> CompileResult> { + Ok(Some(match expr { + ast::Expr::NumberLiteral(num) => match &num.value { + ast::Number::Int(int) => ConstantData::Integer { + value: ruff_int_to_bigint(int).map_err(|e| self.error(e))?, + }, + ast::Number::Float(value) => ConstantData::Float { value: *value }, + ast::Number::Complex { real, imag } => ConstantData::Complex { + value: Complex::new(*real, *imag), + }, + }, + ast::Expr::StringLiteral(s) => ConstantData::Str { + value: 
self.compile_string_value(s), + }, + ast::Expr::BytesLiteral(b) => ConstantData::Bytes { + value: b.value.bytes().collect(), + }, + ast::Expr::BooleanLiteral(b) => ConstantData::Boolean { value: b.value }, + ast::Expr::NoneLiteral(_) => ConstantData::None, + ast::Expr::EllipsisLiteral(_) => ConstantData::Ellipsis, + _ => return Ok(None), + })) + } + + fn try_negate_match_pattern_constant(constant: ConstantData) -> Option { + match constant { + ConstantData::Integer { value } => Some(ConstantData::Integer { value: -value }), + ConstantData::Float { value } => Some(ConstantData::Float { value: -value }), + ConstantData::Complex { value } => Some(ConstantData::Complex { value: -value }), + ConstantData::Boolean { value } => Some(ConstantData::Integer { + value: -BigInt::from(u8::from(value)), + }), + _ => None, + } + } + + fn constant_as_match_pattern_complex(constant: &ConstantData) -> Option> { + match constant { + ConstantData::Integer { value } => Some(Complex::new(value.to_f64()?, 0.0)), + ConstantData::Float { value } => Some(Complex::new(*value, 0.0)), + ConstantData::Complex { value } => Some(*value), + ConstantData::Boolean { value } => Some(Complex::new(f64::from(u8::from(*value)), 0.0)), + _ => None, + } + } + + fn try_fold_match_pattern_binop( + op: ast::Operator, + left: &ConstantData, + right: &ConstantData, + ) -> Option { + if let (ConstantData::Integer { value: left }, ConstantData::Integer { value: right }) = + (left, right) + { + return match op { + ast::Operator::Add => Some(ConstantData::Integer { + value: left + right, + }), + ast::Operator::Sub => Some(ConstantData::Integer { + value: left - right, + }), + _ => None, + }; + } + + let left_is_complex = matches!(left, ConstantData::Complex { .. }); + let right_is_complex = matches!(right, ConstantData::Complex { .. 
}); + if left_is_complex || right_is_complex { + let left = Self::constant_as_match_pattern_complex(left)?; + let right = Self::constant_as_match_pattern_complex(right)?; + let value = match op { + ast::Operator::Add => Complex::new(left.re + right.re, left.im + right.im), + ast::Operator::Sub => { + let imag = if !left_is_complex && right_is_complex { + -right.im + } else { + left.im - right.im + }; + Complex::new(left.re - right.re, imag) + } + _ => return None, + }; + return Some(ConstantData::Complex { value }); + } + + let left = Self::constant_as_match_pattern_complex(left)?; + let right = Self::constant_as_match_pattern_complex(right)?; + match op { + ast::Operator::Add => Some(ConstantData::Float { + value: left.re + right.re, + }), + ast::Operator::Sub => Some(ConstantData::Float { + value: left.re - right.re, + }), + _ => None, + } + } + + fn try_fold_match_pattern_const_expr( + &mut self, + expr: &ast::Expr, + ) -> CompileResult> { + // CPython 3.14 ast_preprocess.c::fold_const_match_patterns() + // folds only the constant forms needed by match patterns before + // codegen_pattern_value()/codegen_pattern_mapping_key() visit them. + Ok(match expr { + ast::Expr::UnaryOp(ast::ExprUnaryOp { + op: ast::UnaryOp::USub, + operand, + .. + }) => { + let Some(constant) = self.try_compile_ast_constant(operand)? else { + return Ok(None); + }; + Self::try_negate_match_pattern_constant(constant) + } + ast::Expr::BinOp(ast::ExprBinOp { + left, op, right, .. + }) if matches!(op, ast::Operator::Add | ast::Operator::Sub) => { + let Some(left) = (match self.try_fold_match_pattern_const_expr(left)? { + Some(constant) => Some(constant), + None => self.try_compile_ast_constant(left)?, + }) else { + return Ok(None); + }; + let Some(right) = self.try_compile_ast_constant(right)? 
else { + return Ok(None); + }; + Self::try_fold_match_pattern_binop(*op, &left, &right) + } + _ => None, + }) + } + + fn compile_match_pattern_expr(&mut self, expr: &ast::Expr) -> CompileResult<()> { + if let Some(constant) = self.try_fold_match_pattern_const_expr(expr)? { + self.emit_load_const(constant); + } else { + self.compile_expression(expr)?; + } + Ok(()) + } + fn emit_load_const(&mut self, constant: ConstantData) { let idx = self.arg_constant(constant); self.emit_arg(idx, |consti| Instruction::LoadConst { consti }) @@ -13111,6 +13962,65 @@ impl Compiler { self.current_source_range = range; } + fn decorated_definition_range( + &self, + statement_range: TextRange, + decorator_list: &[ast::Decorator], + keyword: &str, + ) -> TextRange { + let Some(last_decorator) = decorator_list.last() else { + return statement_range; + }; + let search_start = last_decorator.expression.range().end(); + if search_start >= statement_range.end() { + return statement_range; + } + let search_range = TextRange::new(search_start, statement_range.end()); + let source = self.source_file.slice(search_range); + let Some(keyword_offset) = source.find(keyword) else { + return statement_range; + }; + let Ok(keyword_offset) = u32::try_from(keyword_offset) else { + return statement_range; + }; + TextRange::new( + search_start + TextSize::new(keyword_offset), + statement_range.end(), + ) + } + + fn update_start_location_to_match_attr( + &self, + loc_range: TextRange, + attr_range: TextRange, + attr: &str, + ) -> TextRange { + let source = self.source_file.to_source_code(); + if source.line_index(loc_range.start()) == source.line_index(attr_range.end()) { + return loc_range; + } + let Ok(attr_len) = u32::try_from(attr.len()) else { + return TextRange::new(loc_range.start(), loc_range.end()); + }; + let attr_len = TextSize::new(attr_len); + if attr_len > attr_range.len() { + return TextRange::new(loc_range.start(), loc_range.end()); + } + TextRange::new(attr_range.end() - attr_len, 
loc_range.end()) + } + + fn source_line_start_range(&self, lineno: u32) -> TextRange { + let source = self.source_file.to_source_code(); + let line = OneIndexed::new(lineno as usize).unwrap_or(OneIndexed::MIN); + let start = source.line_start(line); + TextRange::new(start, start) + } + + fn module_start_location(&self, body: &[ast::Stmt]) -> TextRange { + body.first() + .map_or_else(|| self.source_line_start_range(1), Ranged::range) + } + fn get_source_line_number(&mut self) -> OneIndexed { self.source_file .to_source_code() @@ -13118,7 +14028,14 @@ impl Compiler { } fn mark_generator(&mut self) { - self.current_code_info().flags |= bytecode::CodeFlags::GENERATOR + let is_async = self.ctx.func == FunctionContext::AsyncFunction; + let flags = &mut self.current_code_info().flags; + if is_async { + flags.remove(bytecode::CodeFlags::COROUTINE); + flags.insert(bytecode::CodeFlags::ASYNC_GENERATOR); + } else { + flags.insert(bytecode::CodeFlags::GENERATOR); + } } /// Whether the expression contains an await expression and @@ -13190,18 +14107,25 @@ impl Compiler { let mut element_count = 0; let mut pending_literal = None; + let mut pending_literal_range = None; let mut pending_literal_no_location = false; for part in fstring { self.compile_fstring_part_into( part, &mut pending_literal, + &mut pending_literal_range, &mut pending_literal_no_location, &mut element_count, false, )?; } - self.set_source_range(fstring_range); - self.finish_fstring(pending_literal, pending_literal_no_location, element_count); + self.finish_fstring( + pending_literal, + pending_literal_range, + pending_literal_no_location, + element_count, + Some(fstring_range), + ); Ok(()) } @@ -13215,17 +14139,24 @@ impl Compiler { let mut element_count = 0; let mut pending_literal = None; + let mut pending_literal_range = None; let mut pending_literal_no_location = false; for part in fstring { self.compile_fstring_part_into( part, &mut pending_literal, + &mut pending_literal_range, &mut 
pending_literal_no_location, &mut element_count, true, )?; } - self.finish_fstring_join(pending_literal, pending_literal_no_location, element_count); + self.finish_fstring_join( + pending_literal, + pending_literal_range, + pending_literal_no_location, + element_count, + ); Ok(()) } @@ -13233,6 +14164,7 @@ impl Compiler { &mut self, part: &ast::FStringPart, pending_literal: &mut Option, + pending_literal_range: &mut Option, pending_literal_no_location: &mut bool, element_count: &mut u32, append_to_join_list: bool, @@ -13241,10 +14173,11 @@ impl Compiler { ast::FStringPart::Literal(string) => { let value = self.compile_fstring_part_literal_value(string); if pending_literal.is_none() { - self.set_source_range(string.range); + *pending_literal_range = Some(string.range); *pending_literal_no_location = string.range == TextRange::default(); *pending_literal = Some(value); } else if let Some(pending) = pending_literal.as_mut() { + Self::extend_pending_literal_range(pending_literal_range, string.range); *pending_literal_no_location &= string.range == TextRange::default(); pending.push_wtf8(value.as_ref()); } @@ -13254,7 +14187,7 @@ impl Compiler { fstring.flags, &fstring.elements, pending_literal, - pending_literal_no_location, + (pending_literal_range, pending_literal_no_location), element_count, append_to_join_list, ), @@ -13264,12 +14197,15 @@ impl Compiler { fn finish_fstring( &mut self, mut pending_literal: Option, + mut pending_literal_range: Option, mut pending_literal_no_location: bool, mut element_count: u32, + fstring_range: Option, ) { let keep_empty = element_count == 0; self.emit_pending_fstring_literal( &mut pending_literal, + &mut pending_literal_range, &mut pending_literal_no_location, &mut element_count, keep_empty, @@ -13277,10 +14213,16 @@ impl Compiler { ); if element_count == 0 { + if let Some(fstring_range) = fstring_range { + self.set_source_range(fstring_range); + } self.emit_load_const(ConstantData::Str { value: Wtf8Buf::new(), }); } else if 
element_count > 1 { + if let Some(fstring_range) = fstring_range { + self.set_source_range(fstring_range); + } emit!( self, Instruction::BuildString { @@ -13293,12 +14235,14 @@ impl Compiler { fn finish_fstring_join( &mut self, mut pending_literal: Option, + mut pending_literal_range: Option, mut pending_literal_no_location: bool, mut element_count: u32, ) { let keep_empty = element_count == 0; self.emit_pending_fstring_literal( &mut pending_literal, + &mut pending_literal_range, &mut pending_literal_no_location, &mut element_count, keep_empty, @@ -13310,6 +14254,7 @@ impl Compiler { fn emit_pending_fstring_literal( &mut self, pending_literal: &mut Option, + pending_literal_range: &mut Option, pending_literal_no_location: &mut bool, element_count: &mut u32, keep_empty: bool, @@ -13318,6 +14263,7 @@ impl Compiler { let Some(value) = pending_literal.take() else { return; }; + let range = pending_literal_range.take(); let no_location = *pending_literal_no_location; *pending_literal_no_location = false; @@ -13328,6 +14274,9 @@ impl Compiler { return; } + if let Some(range) = range { + self.set_source_range(range); + } self.emit_load_const(ConstantData::Str { value }); if no_location { self.set_no_location(); @@ -13338,6 +14287,18 @@ impl Compiler { } } + fn extend_pending_literal_range(pending: &mut Option, range: TextRange) { + let Some(existing) = pending else { + *pending = Some(range); + return; + }; + if *existing == TextRange::default() { + *existing = range; + } else if range != TextRange::default() { + *existing = TextRange::new(existing.start(), range.end()); + } + } + fn count_fstring_parts(&self, fstring: &[ast::FStringPart]) -> u32 { let mut element_count = 0; let mut pending_literal = None; @@ -13393,6 +14354,7 @@ impl Compiler { &mut self, flags: ast::FStringFlags, fstring_elements: &ast::InterpolatedStringElements, + fstring_range: Option, ) -> CompileResult<()> { if self.count_fstring_elements(flags, fstring_elements) > STACK_USE_GUIDELINE { return 
self.compile_fstring_elements_joined(flags, fstring_elements); @@ -13400,16 +14362,23 @@ impl Compiler { let mut element_count = 0; let mut pending_literal: Option = None; + let mut pending_literal_range: Option = None; let mut pending_literal_no_location = false; self.compile_fstring_elements_into( flags, fstring_elements, &mut pending_literal, - &mut pending_literal_no_location, + (&mut pending_literal_range, &mut pending_literal_no_location), &mut element_count, false, )?; - self.finish_fstring(pending_literal, pending_literal_no_location, element_count); + self.finish_fstring( + pending_literal, + pending_literal_range, + pending_literal_no_location, + element_count, + fstring_range, + ); Ok(()) } @@ -13427,37 +14396,58 @@ impl Compiler { let mut element_count = 0; let mut pending_literal: Option = None; + let mut pending_literal_range: Option = None; let mut pending_literal_no_location = false; self.compile_fstring_elements_into( flags, fstring_elements, &mut pending_literal, - &mut pending_literal_no_location, + (&mut pending_literal_range, &mut pending_literal_no_location), &mut element_count, true, )?; - self.finish_fstring_join(pending_literal, pending_literal_no_location, element_count); + self.finish_fstring_join( + pending_literal, + pending_literal_range, + pending_literal_no_location, + element_count, + ); Ok(()) } + fn cpython_format_spec_range(&self, range: TextRange) -> TextRange { + let start = range.start().to_usize(); + if start == 0 { + return range; + } + let source = self.source_file.source_text().as_bytes(); + if source.get(start - 1) == Some(&b':') { + TextRange::new(range.start() - TextSize::new(1), range.end()) + } else { + range + } + } + fn compile_fstring_elements_into( &mut self, flags: ast::FStringFlags, fstring_elements: &ast::InterpolatedStringElements, pending_literal: &mut Option, - pending_literal_no_location: &mut bool, + pending_literal_meta: (&mut Option, &mut bool), element_count: &mut u32, append_to_join_list: bool, ) -> 
CompileResult<()> { + let (pending_literal_range, pending_literal_no_location) = pending_literal_meta; for element in fstring_elements { match element { ast::InterpolatedStringElement::Literal(string) => { let value = self.compile_fstring_literal_value(string, flags); if pending_literal.is_none() { - self.set_source_range(string.range); + *pending_literal_range = Some(string.range); *pending_literal_no_location = string.range == TextRange::default(); *pending_literal = Some(value); } else if let Some(pending) = pending_literal.as_mut() { + Self::extend_pending_literal_range(pending_literal_range, string.range); *pending_literal_no_location &= string.range == TextRange::default(); pending.push_wtf8(value.as_ref()); } @@ -13472,18 +14462,33 @@ impl Compiler { if let Some(ast::DebugText { leading, trailing }) = &fstring_expr.debug_text { let range = fstring_expr.expression.range(); + let leading = strip_fstring_debug_comments(leading); + let trailing = strip_fstring_debug_comments(trailing); let source = self.source_file.slice(range); - let text = [ - strip_fstring_debug_comments(leading).as_str(), - source, - strip_fstring_debug_comments(trailing).as_str(), - ] - .concat(); + let text = [leading.as_str(), source, trailing.as_str()].concat(); + let debug_text_range = TextRange::new( + range.start() + - TextSize::new( + u32::try_from(leading.len()) + .expect("debug f-string leading text too long"), + ), + range.end() + + TextSize::new( + u32::try_from(trailing.len()) + .expect("debug f-string trailing text too long"), + ), + ); let text: Wtf8Buf = text.into(); if pending_literal.is_none() { + *pending_literal_range = Some(debug_text_range); *pending_literal_no_location = false; *pending_literal = Some(Wtf8Buf::new()); + } else { + Self::extend_pending_literal_range( + pending_literal_range, + debug_text_range, + ); } pending_literal.as_mut().unwrap().push_wtf8(text.as_ref()); @@ -13499,6 +14504,7 @@ impl Compiler { self.emit_pending_fstring_literal( pending_literal, + 
pending_literal_range, pending_literal_no_location, element_count, false, @@ -13507,22 +14513,32 @@ impl Compiler { self.compile_expression(&fstring_expr.expression)?; + let formatted_value_range = fstring_expr.range; match conversion { ConvertValueOparg::None => {} ConvertValueOparg::Str | ConvertValueOparg::Repr | ConvertValueOparg::Ascii => { + self.set_source_range(formatted_value_range); emit!(self, Instruction::ConvertValue { oparg: conversion }) } } match &fstring_expr.format_spec { Some(format_spec) => { - self.compile_fstring_elements(flags, &format_spec.elements)?; - + let format_spec_range = + self.cpython_format_spec_range(format_spec.range); + self.compile_fstring_elements( + flags, + &format_spec.elements, + Some(format_spec_range), + )?; + + self.set_source_range(formatted_value_range); emit!(self, Instruction::FormatWithSpec); } None => { + self.set_source_range(formatted_value_range); emit!(self, Instruction::FormatSimple); } } @@ -13714,7 +14730,11 @@ impl Compiler { let has_format_spec = interp.format_spec.is_some(); if let Some(format_spec) = &interp.format_spec { - self.compile_fstring_elements(ast::FStringFlags::empty(), &format_spec.elements)?; + self.compile_fstring_elements( + ast::FStringFlags::empty(), + &format_spec.elements, + Some(format_spec.range), + )?; } // CPython keeps bit 1 set in BUILD_INTERPOLATION's oparg and uses @@ -13820,17 +14840,20 @@ fn expandtabs(input: &str, tab_size: usize) -> String { expanded_str } -fn split_doc<'a>(body: &'a [ast::Stmt], opts: &CompileOpts) -> (Option, &'a [ast::Stmt]) { +fn split_doc_with_range<'a>( + body: &'a [ast::Stmt], + opts: &CompileOpts, +) -> (Option<(String, TextRange)>, &'a [ast::Stmt]) { if let Some((ast::Stmt::Expr(expr), body_rest)) = body.split_first() { let doc_comment = match &*expr.value { - ast::Expr::StringLiteral(value) => Some(&value.value), + ast::Expr::StringLiteral(value) => Some((&value.value, expr.value.range())), // f-strings are not allowed in Python doc comments. 
ast::Expr::FString(_) => None, _ => None, }; - if let Some(doc) = doc_comment { + if let Some((doc, range)) = doc_comment { return if opts.optimize < 2 { - (Some(clean_doc(doc.to_str())), body_rest) + (Some((clean_doc(doc.to_str()), range)), body_rest) } else { (None, body_rest) }; @@ -13839,6 +14862,12 @@ fn split_doc<'a>(body: &'a [ast::Stmt], opts: &CompileOpts) -> (Option, (None, body) } +#[cfg(test)] +fn split_doc<'a>(body: &'a [ast::Stmt], opts: &CompileOpts) -> (Option, &'a [ast::Stmt]) { + let (doc, body) = split_doc_with_range(body, opts); + (doc.map(|(doc, _)| doc), body) +} + pub fn ruff_int_to_bigint(int: &ast::Int) -> Result { if let Some(small) = int.as_u64() { Ok(BigInt::from(small)) @@ -14113,6 +15142,46 @@ mod tests { compiler.exit_scope() } + #[test] + fn test_empty_module_implicit_return_inherits_resume_location_like_cpython() { + let code = compile_exec(""); + // CPython 3.14 codegen emits the implicit LOAD_CONST/RETURN_VALUE with + // NO_LOCATION, then flowgraph.c::propagate_line_numbers() propagates + // the module RESUME location, whose line is 0. 
+ assert_eq!(code.linetable.as_ref(), &[0xf2, 0x03, 0x01, 0x01, 0x01]); + } + + #[test] + fn test_redundant_nop_location_copies_full_location_like_cpython() { + let code = compile_exec( + "\ +def f(x, y, z): + while x: + if y: + pass + elif z: + if y < 0: + return y + if z: + y = y + 1 + elif y: + return 1 + return -1 +", + ); + let f = find_code(&code, "f").expect("missing function code"); + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdf, 0x0a, 0x0b, 0xdf, 0x0b, 0x0c, 0xd9, 0x0c, 0x10, 0xdf, 0x0d, 0x0e, + 0xd8, 0x0f, 0x10, 0x90, 0x31, 0x8c, 0x75, 0xd8, 0x17, 0x18, 0x90, 0x08, 0xdf, 0x0f, + 0x10, 0xd8, 0x14, 0x15, 0x98, 0x01, 0x95, 0x45, 0x92, 0x01, 0xf1, 0x03, 0x00, 0x10, + 0x11, 0xe7, 0x0d, 0x0e, 0x89, 0x51, 0xd9, 0x13, 0x14, 0xd8, 0x0b, 0x0d, 0x80, 0x49, + ], + "CPython basicblock_remove_redundant_nops() copies the full NOP location into a following no-location jump" + ); + } + fn scan_program_symbol_table(source: &str) -> SymbolTable { let source_file = SourceFileBuilder::new("source_path", source).finish(); let parsed = ruff_python_parser::parse( @@ -14224,6 +15293,7 @@ mod tests { } ); compiler.set_no_location(); + compiler.move_last_instruction_before_scope_start_resume(); compiler .push_fblock(FBlockType::StopIteration, handler_block, handler_block) .unwrap(); @@ -16056,542 +17126,2118 @@ def f(buffer, pos, last_char): }) } - fn non_cache_instructions(code: &CodeObject) -> impl Iterator { - code.instructions - .iter() - .filter(|unit| !matches!(unit.op, Instruction::Cache)) - } - - fn varname_index(code: &CodeObject, name: &str) -> usize { - code.varnames - .iter() - .position(|varname| varname.as_str() == name) - .unwrap_or_else(|| panic!("missing {name} local")) + fn find_direct_child_code<'a>(code: &'a CodeObject, name: &str) -> Option<&'a CodeObject> { + code.constants.iter().find_map(|constant| { + if let ConstantData::Code { code } = constant { + (code.obj_name == name).then_some(code.as_ref()) + } else { + None + } + }) } - fn 
load_fast_ops_for_var(code: &CodeObject, name: &str) -> Vec { - let var_idx = varname_index(code, name); - non_cache_instructions(code) - .filter_map(|unit| match unit.op { - Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } => { - let var_num = var_num.get(OpArg::new(u32::from(u8::from(unit.arg)))); - (usize::from(var_num) == var_idx).then_some(unit.op) - } - _ => None, - }) - .collect() + #[test] + fn test_annotated_multiline_function_body_keeps_def_firstlineno_like_cpython() { + let code = compile_exec( + r#" +a = 1 +def f( + x: a, +): ... +"#, + ); + let f = find_code(&code, "f").expect("missing f code"); + // CPython 3.14 codegen_function() computes firstlineno from the + // FunctionDef before compiling annotations, then passes it to + // codegen_function_body(). + assert_eq!(f.linetable.as_ref(), &[0x80, 0x00, 0xe1, 0x03, 0x06]); } - fn load_fast_pair_ops_for_vars( - code: &CodeObject, - left_name: &str, - right_name: &str, - ) -> Vec { - let left_idx = varname_index(code, left_name); - let right_idx = varname_index(code, right_name); - non_cache_instructions(code) - .filter_map(|unit| { - let var_nums = match unit.op { - Instruction::LoadFastLoadFast { var_nums } - | Instruction::LoadFastBorrowLoadFastBorrow { var_nums } => var_nums, - _ => return None, - }; - let (left, right) = var_nums - .get(OpArg::new(u32::from(u8::from(unit.arg)))) - .indexes(); - (usize::from(left) == left_idx && usize::from(right) == right_idx) - .then_some(unit.op) - }) - .collect() + #[test] + fn test_annotation_scope_return_uses_function_location_like_cpython() { + let code = compile_exec( + r#" +def g(): + def f(x: not (int is int), /): ... 
+"#, + ); + let g = find_code(&code, "g").expect("missing g code"); + let annotate = find_code(g, "__annotate__").expect("missing annotation code"); + // CPython 3.14 codegen_function_annotations() receives LOC(function) + // and uses it for the annotation closure's BUILD_MAP/RETURN_VALUE and + // for the parent MAKE_FUNCTION annotate sequence. + assert_eq!(g.linetable.as_ref(), &[0x80, 0x00, 0xdf, 0x04, 0x26]); + assert_eq!( + annotate.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd7, 0x04, 0x26, 0xd1, 0x04, 0x26, 0x94, 0x23, 0x9c, 0x13, 0xd0, 0x0d, + 0x1d, 0xd1, 0x04, 0x26, + ], + ); } - fn count_strong_loads_for_vars(code: &CodeObject, names: &[&str]) -> usize { - let var_indices = names + #[test] + fn test_module_deferred_annotations_use_start_location_like_cpython() { + let code = compile_exec( + "\ +import os +X: int +Y: str +", + ); + let annotate = find_code(&code, "__annotate__").expect("missing __annotate__ code"); + + // CPython 3.14 compile.c::start_location() passes the first module + // statement location into _PyCodegen_Module(), and + // codegen_process_deferred_annotations() uses that loc for annotation + // scope setup, BUILD_MAP, STORE_SUBSCR, and RETURN_VALUE. 
+ assert_eq!( + annotate.linetable.as_ref(), + &[ + 0x80, 0x00, 0x87, 0x09, 0x81, 0x09, 0xdf, 0x00, 0x06, 0x82, 0x06, 0x84, 0x33, 0x81, + 0x06, 0xf1, 0x03, 0x00, 0x01, 0x0a, 0xe7, 0x00, 0x06, 0x82, 0x06, 0x84, 0x33, 0x81, + 0x06, 0xf2, 0x05, 0x00, 0x01, 0x0a, + ] + ); + } + + #[test] + fn test_super_method_call_kw_names_use_attribute_location_like_cpython() { + let code = compile_exec( + "\ +class C: + def f(self, x, y): + super().__init__( + x=x, + y=y) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let call_kw_index = f + .instructions .iter() - .map(|name| varname_index(code, name)) - .collect::>(); - non_cache_instructions(code) - .filter(|unit| match unit.op { - Instruction::LoadFast { var_num } => { - let var_num = var_num.get(OpArg::new(u32::from(u8::from(unit.arg)))); - var_indices.contains(&usize::from(var_num)) - } - _ => false, - }) - .count() + .position(|unit| matches!(unit.op, Instruction::CallKw { .. })) + .expect("missing CALL_KW"); + let (kw_names, (location, end_location)) = f + .instructions + .iter() + .zip(&f.locations) + .take(call_kw_index) + .rev() + .find(|(unit, _)| matches!(unit.op, Instruction::LoadConst { .. })) + .expect("missing CALL_KW names tuple"); + + assert!( + matches!(kw_names.op, Instruction::LoadConst { .. }), + "expected keyword names tuple before CALL_KW" + ); + assert_eq!( + (location.line.get(), end_location.line.get()), + (3, 3), + "CPython maybe_optimize_method_call() passes the updated method-attribute loc into codegen_call_simple_kw_helper()" + ); } - fn count_strong_loads(code: &CodeObject) -> usize { - non_cache_instructions(code) - .filter(|unit| matches!(unit.op, Instruction::LoadFast { .. 
})) - .count() + #[test] + fn test_lambda_return_uses_body_location_like_cpython() { + let code = compile_exec( + "\ +def outer(): + return lambda x: x if x else 1 +", + ); + let lambda = find_code(&code, "").expect("missing lambda code"); + let return_positions: Vec<_> = lambda + .instructions + .iter() + .zip(&lambda.locations) + .filter_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::ReturnValue).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert_eq!( + return_positions, + vec![(2, 22, 2, 35), (2, 22, 2, 35)], + "CPython codegen_lambda() emits RETURN_VALUE at LOC(lambda body)" + ); } #[test] - fn test_match_or_default_block_keeps_load_fast_strong() { + fn test_not_compare_uses_unary_location_like_cpython() { let code = compile_exec( - r#" -def f(format, other): - match format: - case 1 | 2: - return other - case _: - raise NotImplementedError(other) -"#, + "\ +def f(self, other): + return not self == other +", ); - let function = find_code(&code, "f").expect("missing function code"); - let loads = load_fast_ops_for_var(function, "other"); - assert!( - matches!( - loads.as_slice(), - [ - Instruction::LoadFastBorrow { .. }, - Instruction::LoadFastBorrow { .. }, - Instruction::LoadFast { .. }, - ] - ), - "CPython optimize_load_fast() keeps trailing OR-pattern default loads strong, got {loads:?}", + let f = find_code(&code, "f").expect("missing f code"); + + // CPython 3.14 parses the Compare inside UnaryOp(Not) with the + // UnaryOp start location, so codegen_compare() emits COMPARE_OP at + // the full "not self == other" range before flowgraph folds TO_BOOL. 
+ assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x0f, 0x13, 0xd2, 0x0b, 0x1c, 0xd0, 0x04, 0x1c, + ] ); } #[test] - fn test_match_nested_or_default_block_keeps_load_fast_strong() { + fn test_not_chained_compare_keeps_compare_location_like_cpython() { let code = compile_exec( - r#" -def f(format, other): - match format: - case [1 | 2, value]: - return other - case _: - raise NotImplementedError(other) -"#, + "\ +def f(c): + return not (b\" \" <= c <= b\"~\") +", ); - let function = find_code(&code, "f").expect("missing function code"); - let loads = load_fast_ops_for_var(function, "other"); - assert!( - loads - .iter() - .any(|op| matches!(op, Instruction::LoadFast { .. })), - "CPython optimize_load_fast() keeps trailing nested OR-pattern default loads strong, got {loads:?}", + let f = find_code(&code, "f").expect("missing f code"); + + // CPython's single Compare under UnaryOp(Not) includes "not" in the + // Compare range, but chained comparisons keep their inner range for + // compare scaffolding and only use the UnaryOp range for TO_BOOL and + // UNARY_NOT. + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x10, 0x14, 0x98, 0x01, 0xd7, 0x10, 0x21, 0xd4, 0x10, 0x21, 0x98, + 0x54, 0xd1, 0x10, 0x21, 0xd4, 0x0b, 0x22, 0xd0, 0x04, 0x22, 0xd1, 0x10, 0x21, 0xd4, + 0x0b, 0x22, 0xd0, 0x04, 0x22, + ] ); } #[test] - fn test_match_success_next_location_preserves_pass_nop() { + fn test_type_param_scopes_use_cpython_locations() { + let code = compile_exec("type BoundGenericAlias[X: int] = set[X]\n"); + let type_params = find_code(&code, "") + .expect("missing generic parameters code"); + let bound = find_direct_child_code(type_params, "X").expect("missing X bound code"); + let alias = + find_direct_child_code(type_params, "BoundGenericAlias").expect("missing alias code"); + + // CPython 3.14 codegen_type_params() emits type-parameter ops at + // LOC(typeparam), bound/default evaluator ops at LOC(e), and type alias + // body plumbing at LOC(s). 
+ assert_eq!( + type_params.linetable.as_ref(), + &[ + 0xf8, 0x80, 0x00, 0xd0, 0x00, 0x27, 0x90, 0x76, 0x9b, 0x23, 0x93, 0x76, 0xd7, 0x00, + 0x27, 0xd1, 0x00, 0x27, + ], + ); + assert_eq!( + bound.linetable.as_ref(), + &[0x80, 0x00, 0x9f, 0x23, 0x9e, 0x23] + ); + assert_eq!( + alias.linetable.as_ref(), + &[ + 0xf8, 0x80, 0x00, 0xd7, 0x00, 0x27, 0xd0, 0x00, 0x27, 0xa4, 0x13, 0xa0, 0x51, 0xa5, + 0x16, 0xd0, 0x00, 0x27, + ], + ); + } + + #[test] + fn test_generic_function_annotation_scope_uses_function_location_like_cpython() { + let code = compile_exec("def f[T](x: int): ...\n"); + let type_params = + find_code(&code, "").expect("missing type params code"); + let annotate = + find_direct_child_code(type_params, "__annotate__").expect("missing annotation code"); + + // CPython 3.14 passes LOC(function) into codegen_function_annotations(), + // even when the annotation closure is emitted inside the generic + // parameters scope after codegen_type_params(). + assert_eq!( + annotate.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd7, 0x00, 0x15, 0xd1, 0x00, 0x15, 0x8c, 0x43, 0xd1, 0x00, 0x15, + ], + ); + } + + #[test] + fn test_generic_class_type_params_store_uses_class_location_like_cpython() { let code = compile_exec( - r#" -def f(command): - match command: - case "": - pass - case _ as unknown: - sink(unknown) - return False -"#, + "\ +def outer(): + class X[T]: ... +", ); - let function = find_code(&code, "f").expect("missing function code"); - let ops = non_cache_instructions(function) - .map(|unit| unit.op) - .collect::>(); - assert!( - ops.windows(3).any(|window| matches!( - window, - [ - Instruction::PopTop, - Instruction::Nop, - Instruction::LoadConst { .. 
}, - ] - )), - "CPython NEXT_LOCATION keeps the pass NOP after match subject POP_TOP, got {ops:?}", + let type_params = + find_code(&code, "").expect("missing type params code"); + + // CPython 3.14 codegen_class() calls codegen_type_params(), then stores + // the resulting .type_params cell with codegen_nameop(c, LOC(class), ...). + assert_eq!( + type_params.linetable.as_ref(), + &[ + 0xf8, 0x80, 0x00, 0x8c, 0x41, 0x87, 0x4f, 0x87, 0x4f, 0x80, 0x4f, + ] ); } #[test] - fn test_while_try_body_layout_keeps_false_jump_to_anchor() { + fn test_generic_class_wrapper_ops_use_class_location_like_cpython() { let code = compile_exec( - r#" -def f(stack, itstack, node_to_stack_index): - while True: - while stack: - try: - node = itstack[-1]() - break - except StopIteration: - del node_to_stack_index[stack.pop()] - itstack.pop() - else: - break -"#, + "\ +def f(): + class X[T](tuple): + pass +", ); - let function = find_code(&code, "f").expect("missing function code"); - let ops = non_cache_instructions(function) - .map(|unit| unit.op) - .collect::>(); - let stack_test = ops - .windows(5) - .find(|window| { - matches!( - window, - [ - Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, - Instruction::ToBool, - Instruction::PopJumpIfFalse { .. }, - Instruction::NotTaken, - Instruction::Nop, - ] + let f = find_code(&code, "f").expect("missing function code"); + let wrapper_positions: Vec<_> = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Resume { .. })) + .take(4) + .zip(f.locations.iter().filter(|_| true).skip(1)) + .map(|(unit, (location, end_location))| { + ( + unit.op, + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), ) }) - .unwrap_or_else(|| { - panic!("expected CPython-style while/try false jump to anchor, got {ops:?}") - }); - assert!(matches!(stack_test[2], Instruction::PopJumpIfFalse { .. 
})); + .collect(); + assert_eq!( + wrapper_positions + .iter() + .map(|(_, line, col, end_line, end_col)| (*line, *col, *end_line, *end_col)) + .collect::>(), + vec![(2, 5, 3, 13); 4], + "CPython codegen_class() emits type-params wrapper closure, PUSH_NULL, and CALL at LOC(class)" + ); + + let type_params = + find_code(f, "").expect("missing generic parameters code"); + let generic_base_position = type_params + .instructions + .iter() + .zip(&type_params.locations) + .find_map(|(unit, (location, end_location))| { + let Instruction::LoadFastBorrow { var_num } = unit.op else { + return None; + }; + let idx = var_num.get(OpArg::new(u32::from(u8::from(unit.arg)))); + let localsplus = type_params + .varnames + .iter() + .chain(type_params.cellvars.iter()) + .chain(type_params.freevars.iter()) + .collect::>(); + localsplus + .get(usize::from(idx)) + .is_some_and(|name| name.as_str() == ".generic_base") + .then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing .generic_base load"); + assert_eq!( + generic_base_position, + (2, 5, 3, 13), + "CPython codegen_class() injects .generic_base with LOC(class)" + ); } #[test] - fn test_while_if_not_break_keeps_body_call() { + fn test_class_deferred_annotations_use_class_body_location_like_cpython() { let code = compile_exec( r#" -def f(waiters): - while waiters: - waiter = waiters.popleft() - if not waiter.done(): - waiter.set_result(None) - break +class C: + "doc" + x: int "#, ); - let function = find_code(&code, "f").expect("missing function code"); - let ops = non_cache_instructions(function) - .map(|unit| unit.op) - .collect::>(); - assert!( - ops.windows(4).any(|window| matches!( - window, - [ - Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, - Instruction::LoadAttr { .. }, - Instruction::LoadConst { .. }, - Instruction::Call { .. 
}, - ] - )), - "CPython keeps waiter.set_result(None) before the break, got {ops:?}", + let class_code = find_code(&code, "C").expect("missing class code"); + + // CPython 3.14 calls codegen_body(c, loc, ...) from codegen_class_body() + // with LOCATION(firstlineno, firstlineno, 0, 0). Deferred annotation + // closure setup and following artificial class tail inherit that class + // body location, not the annotation expression location. + assert_eq!( + class_code.linetable.as_ref(), + &[ + 0xf8, 0x87, 0x00, 0x80, 0x00, 0xd9, 0x04, 0x09, 0xf7, 0x03, 0x00, 0x01, 0x01, 0x83, + 0x00, + ], + ); + } + + #[test] + fn test_future_annotation_string_uses_annotation_location_like_cpython() { + let code = compile_exec("from __future__ import annotations\nclass Bar:\n foo: Foo\n"); + let class_code = find_code(&code, "Bar").expect("missing class code"); + + // CPython 3.14 codegen_annassign() calls codegen_visit_annexpr(), + // which emits the stringized annotation at LOC(annotation), then emits + // the __annotations__ store sequence at LOC(AnnAssign). + assert_eq!( + class_code.linetable.as_ref(), + &[0x87, 0x00, 0xd8, 0x09, 0x0c, 0x87, 0x48] + ); + } + + #[test] + fn test_lambda_dict_literal_ops_use_dict_location_like_cpython() { + let code = compile_exec( + "\ +f = lambda data: {'x': data} +g = lambda i: {**i} +", + ); + let f = find_code(&code, "").expect("missing f lambda code"); + let g = code + .constants + .iter() + .filter_map(|constant| { + if let ConstantData::Code { code } = constant { + (code.obj_name == "").then_some(code.as_ref()) + } else { + None + } + }) + .nth(1) + .expect("missing g lambda code"); + + // CPython 3.14 codegen_dict()/codegen_subdict() uses LOC(dict) for + // BUILD_MAP, MAP_ADD, and DICT_UPDATE, so the lambda RETURN_VALUE + // inherits the full dict literal location after compiling its body. 
+ assert_eq!( + f.linetable.as_ref(), + &[0x80, 0x00, 0x90, 0x23, 0x90, 0x74, 0x91, 0x1b] + ); + assert_eq!( + g.linetable.as_ref(), + &[0x80, 0x00, 0x88, 0x65, 0x90, 0x11, 0x89, 0x65] + ); + } + + #[test] + fn test_class_function_like_scopes_set_method_flag_like_cpython() { + let code = compile_exec_with_options( + r#" +class C: + def m(self): + pass + + async def am(self): + pass + + f = lambda self: self + y = (i for i in ()) + +def f(): + pass +"#, + CompileOpts::default(), + ); + let class_code = find_code(&code, "C").expect("missing class code"); + let method = find_code(class_code, "m").expect("missing method code"); + let async_method = find_code(class_code, "am").expect("missing async method code"); + let lambda = find_code(class_code, "").expect("missing lambda code"); + let genexpr = find_code(class_code, "").expect("missing genexpr code"); + let module_function = find_code(&code, "f").expect("missing module function code"); + + for code in [method, async_method, lambda, genexpr] { + assert!( + code.flags.contains(bytecode::CodeFlags::METHOD), + "class-scope function-like code should carry CO_METHOD like CPython 3.14, got {:?}", + code.flags + ); + } + assert!( + !module_function.flags.contains(bytecode::CodeFlags::METHOD), + "module-scope function must not carry CO_METHOD" + ); + } + + #[test] + fn test_inlined_comprehension_lambda_in_class_is_not_method_like_cpython() { + let code = compile_exec( + "\ +class C: + def method(self): + super() + return __class__ + items = [(lambda: i) for i in range(5)] +", + ); + let class_code = find_code(&code, "C").expect("missing class code"); + let lambda = find_code(class_code, "").expect("missing lambda code"); + assert!( + lambda.flags.contains(bytecode::CodeFlags::NESTED), + "lambda under inlined class comprehension should stay nested" + ); + assert!( + !lambda.flags.contains(bytecode::CodeFlags::METHOD), + "CPython creates this lambda while the current symtable block is the comprehension, not the class" + ); + 
} + + #[test] + fn test_genexpr_implicit_iterator_is_not_posonly_like_cpython() { + let code = compile_exec("x = (i for i in ())"); + let genexpr = find_code(&code, "").expect("missing genexpr code"); + + assert_eq!(genexpr.arg_count, 1); + assert_eq!( + genexpr.posonlyarg_count, 0, + "CPython codegen_comprehension() sets u_argcount=1 and leaves u_posonlyargcount=0" + ); + } + + #[test] + fn test_async_generator_uses_cpython_async_generator_flag() { + let code = compile_exec_with_options( + r#" +def g(): + yield 1 + +async def c(): + return 1 + +async def ag(): + yield 1 +"#, + CompileOpts::default(), + ); + let generator = find_code(&code, "g").expect("missing generator code"); + let coroutine = find_code(&code, "c").expect("missing coroutine code"); + let async_generator = find_code(&code, "ag").expect("missing async generator code"); + + assert!(generator.flags.contains(bytecode::CodeFlags::GENERATOR)); + assert!(!generator.flags.contains(bytecode::CodeFlags::COROUTINE)); + assert!( + !generator + .flags + .contains(bytecode::CodeFlags::ASYNC_GENERATOR) + ); + + assert!(coroutine.flags.contains(bytecode::CodeFlags::COROUTINE)); + assert!(!coroutine.flags.contains(bytecode::CodeFlags::GENERATOR)); + assert!( + !coroutine + .flags + .contains(bytecode::CodeFlags::ASYNC_GENERATOR) + ); + + assert!( + async_generator + .flags + .contains(bytecode::CodeFlags::ASYNC_GENERATOR) + ); + assert!( + !async_generator + .flags + .contains(bytecode::CodeFlags::GENERATOR) + ); + assert!( + !async_generator + .flags + .contains(bytecode::CodeFlags::COROUTINE) + ); + } + + #[test] + fn test_is_none_jump_preserves_cpython_const_order() { + let code = compile_exec_with_options( + r#" +def f(self, payload): + "doc" + if self.x is None: + self.x = [payload] + else: + raise TypeError("bad") +"#, + CompileOpts::default(), + ); + let function = find_code(&code, "f").expect("missing function code"); + assert!( + matches!( + function.constants.as_ref(), + [ + ConstantData::Str { value: 
doc }, + ConstantData::None, + ConstantData::Str { value: message }, + ] if doc.as_ref() == "doc" && message.as_ref() == "bad" + ), + "CPython registers None from the pre-folded `is None` comparison before the else-body string" + ); + } + + #[test] + fn test_stop_iteration_handler_starts_at_scope_start_resume_like_cpython() { + let code = compile_exec_with_options( + r#" +def g(): + yield 1 + +async def c(): + return 1 + +x = (i for i in ()) +"#, + CompileOpts::default(), + ); + + fn assert_stop_iteration_table_starts_at_resume(code: &CodeObject) { + let resume_idx = u32::try_from( + code.instructions + .iter() + .position(|unit| { + matches!( + unit.op, + Instruction::Resume { context } + if matches!( + context + .get(OpArg::new(u32::from(u8::from(unit.arg)))) + .location(), + oparg::ResumeLocation::AtFuncStart + ) + ) + }) + .expect("missing function-start RESUME"), + ) + .unwrap(); + let entries = bytecode::decode_exception_table(&code.exceptiontable); + assert!( + entries.iter().any(|entry| entry.start == resume_idx), + "CPython codegen_wrap_in_stopiteration_handler() inserts SETUP_CLEANUP before RESUME so the StopIteration table starts at RESUME; resume_idx={resume_idx}, entries={entries:?}, instructions={:?}", + code.instructions + ); + } + + assert_stop_iteration_table_starts_at_resume(find_code(&code, "g").expect("missing g")); + assert_stop_iteration_table_starts_at_resume(find_code(&code, "c").expect("missing c")); + assert_stop_iteration_table_starts_at_resume( + find_code(&code, "").expect("missing genexpr"), + ); + } + + #[test] + fn test_inlined_comprehension_cleanup_starts_at_result_build_like_cpython() { + let code = compile_exec_with_options( + r#" +def f(self): + return [k for k, v in self._headers] +"#, + CompileOpts::default(), + ); + let f = find_code(&code, "f").expect("missing f"); + let build_list_idx = u32::try_from( + f.instructions + .iter() + .position(|unit| matches!(unit.op, Instruction::BuildList { .. 
})) + .expect("missing BUILD_LIST"), + ) + .unwrap(); + let entries = bytecode::decode_exception_table(&f.exceptiontable); + assert!( + entries.iter().any(|entry| { + entry.start == build_list_idx && entry.depth == 3 && !entry.push_lasti + }), + "CPython codegen_push_inlined_comprehension_locals() emits SETUP_FINALLY before BUILD_LIST, so the virtual cleanup table starts at BUILD_LIST with saved locals depth; build_list_idx={build_list_idx}, entries={entries:?}, instructions={:?}", + f.instructions + ); + } + + #[test] + fn test_or_return_not_taken_before_jump_target_splits_exception_table_like_cpython() { + let code = compile_exec_with_options( + r#" +def f(self, maintype): + if maintype != "multipart" or not self.is_multipart(): + return + yield 1 +"#, + CompileOpts::default(), + ); + let f = find_code(&code, "f").expect("missing f"); + let not_taken_before_return = u32::try_from( + f.instructions + .windows(3) + .position(|window| { + matches!( + window, + [ + CodeUnit { + op: Instruction::NotTaken, + .. + }, + CodeUnit { + op: Instruction::LoadConst { .. }, + .. + }, + CodeUnit { + op: Instruction::ReturnValue, + .. 
+ }, + ] + ) + }) + .expect("missing NOT_TAKEN before return"), + ) + .unwrap(); + let return_load = not_taken_before_return + 1; + let entries = bytecode::decode_exception_table(&f.exceptiontable); + + assert!( + entries.iter().all(|entry| { + not_taken_before_return < entry.start || not_taken_before_return >= entry.end + }), + "CPython normalize_jumps() can leave a NOT_TAKEN before a separately labelled jump target outside the generator StopIteration range; entries={entries:?}, instructions={:?}", + f.instructions + ); + assert!( + entries + .iter() + .any(|entry| entry.start <= return_load && return_load < entry.end), + "the return block after that NOT_TAKEN is still protected by the StopIteration handler; entries={entries:?}, instructions={:?}", + f.instructions + ); + } + + #[test] + fn test_loop_break_condition_splits_exception_table_like_cpython() { + let code = compile_exec_with_options( + r#" +def f(start, items): + if start: + for x in items: + if x == start: + break + yield 1 +"#, + CompileOpts::default(), + ); + let f = find_code(&code, "f").expect("missing f"); + let break_jump = u32::try_from( + f.instructions + .windows(3) + .position(|window| { + matches!( + window, + [ + CodeUnit { + op: Instruction::PopJumpIfTrue { .. }, + .. + }, + CodeUnit { + op: Instruction::Cache, + .. + }, + CodeUnit { + op: Instruction::NotTaken, + .. + }, + ] + ) || matches!( + window, + [ + CodeUnit { + op: Instruction::PopJumpIfTrue { .. }, + .. + }, + CodeUnit { + op: Instruction::NotTaken, + .. + }, + CodeUnit { + op: Instruction::JumpBackward { .. }, + .. 
+ }, + ] + ) + }) + .expect("missing loop break conditional jump"), + ) + .unwrap(); + let entries = bytecode::decode_exception_table(&f.exceptiontable); + + assert!( + entries + .iter() + .all(|entry| break_jump < entry.start || break_jump >= entry.end), + "CPython normalize_jumps() leaves the loop-break conditional before the synthetic NOT_TAKEN/JUMP_BACKWARD block outside the StopIteration table; break_jump={break_jump}, entries={entries:?}, instructions={:?}", + f.instructions + ); + } + + #[test] + fn test_nested_ifexp_not_taken_splits_exception_table_like_cpython() { + let code = compile_exec_with_options( + r#" +def f(flag, subparts): + if flag: + candidate = subparts[0] if subparts else None + yield 1 +"#, + CompileOpts::default(), + ); + let f = find_code(&code, "f").expect("missing f"); + let conditional_expr_not_taken = u32::try_from( + f.instructions + .iter() + .enumerate() + .find_map(|(idx, unit)| { + if !matches!(unit.op, Instruction::NotTaken) { + return None; + } + let prev = f.instructions[..idx] + .iter() + .rev() + .find(|unit| !matches!(unit.op, Instruction::Cache))?; + let mut following = f.instructions[idx + 1..] + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)); + let next = following.next()?; + let after_next = following.next()?; + (matches!(prev.op, Instruction::PopJumpIfFalse { .. }) + && matches!(next.op, Instruction::LoadFastBorrow { .. }) + && matches!(after_next.op, Instruction::LoadSmallInt { .. 
})) + .then_some(idx) + }) + .expect("missing conditional expression NOT_TAKEN"), + ) + .unwrap(); + let body_start = conditional_expr_not_taken + 1; + let entries = bytecode::decode_exception_table(&f.exceptiontable); + + assert!( + entries.iter().all(|entry| { + conditional_expr_not_taken < entry.start || conditional_expr_not_taken >= entry.end + }), + "CPython codegen_ifexp() uses a separate orelse label inside conditional statements, leaving the normalize_jumps NOT_TAKEN outside the StopIteration table; not_taken={conditional_expr_not_taken}, entries={entries:?}, instructions={:?}", + f.instructions + ); + assert!( + entries + .iter() + .any(|entry| entry.start <= body_start && body_start < entry.end), + "the conditional-expression body after that NOT_TAKEN remains protected; body_start={body_start}, entries={entries:?}, instructions={:?}", + f.instructions + ); + } + + #[test] + fn test_bool_not_taken_after_conditional_yield_splits_like_cpython() { + let code = compile_exec_with_options( + r#" +def f(a, b, c): + if a: + yield 1 + if b: + x = 2 + if c: + x = 3 + yield 4 +"#, + CompileOpts::default(), + ); + let f = find_code(&code, "f").expect("missing f"); + let split_not_taken = f + .instructions + .iter() + .enumerate() + .filter_map(|(idx, unit)| { + if !matches!(unit.op, Instruction::NotTaken) { + return None; + } + let prev = f.instructions[..idx] + .iter() + .rev() + .find(|unit| !matches!(unit.op, Instruction::Cache))?; + matches!( + prev.op, + Instruction::PopJumpIfFalse { .. } | Instruction::PopJumpIfTrue { .. 
} + ) + .then(|| u32::try_from(idx).unwrap()) + }) + .nth(1) + .expect("missing second bool conditional NOT_TAKEN"); + let entries = bytecode::decode_exception_table(&f.exceptiontable); + + assert!( + entries + .iter() + .all(|entry| split_not_taken < entry.start || split_not_taken >= entry.end), + "CPython labels exception targets before normalize_jumps(), so the general bool-jump NOT_TAKEN after a conditional yield is outside the StopIteration table; not_taken={split_not_taken}, entries={entries:?}, instructions={:?}", + f.instructions + ); + } + + fn non_cache_instructions(code: &CodeObject) -> impl Iterator { + code.instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + } + + fn varname_index(code: &CodeObject, name: &str) -> usize { + code.varnames + .iter() + .position(|varname| varname.as_str() == name) + .unwrap_or_else(|| panic!("missing {name} local")) + } + + fn load_fast_ops_for_var(code: &CodeObject, name: &str) -> Vec { + let var_idx = varname_index(code, name); + non_cache_instructions(code) + .filter_map(|unit| match unit.op { + Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } => { + let var_num = var_num.get(OpArg::new(u32::from(u8::from(unit.arg)))); + (usize::from(var_num) == var_idx).then_some(unit.op) + } + _ => None, + }) + .collect() + } + + fn load_fast_pair_ops_for_vars( + code: &CodeObject, + left_name: &str, + right_name: &str, + ) -> Vec { + let left_idx = varname_index(code, left_name); + let right_idx = varname_index(code, right_name); + non_cache_instructions(code) + .filter_map(|unit| { + let var_nums = match unit.op { + Instruction::LoadFastLoadFast { var_nums } + | Instruction::LoadFastBorrowLoadFastBorrow { var_nums } => var_nums, + _ => return None, + }; + let (left, right) = var_nums + .get(OpArg::new(u32::from(u8::from(unit.arg)))) + .indexes(); + (usize::from(left) == left_idx && usize::from(right) == right_idx) + .then_some(unit.op) + }) + .collect() + } + + fn 
count_strong_loads_for_vars(code: &CodeObject, names: &[&str]) -> usize { + let var_indices = names + .iter() + .map(|name| varname_index(code, name)) + .collect::>(); + non_cache_instructions(code) + .filter(|unit| match unit.op { + Instruction::LoadFast { var_num } => { + let var_num = var_num.get(OpArg::new(u32::from(u8::from(unit.arg)))); + var_indices.contains(&usize::from(var_num)) + } + _ => false, + }) + .count() + } + + fn count_strong_loads(code: &CodeObject) -> usize { + non_cache_instructions(code) + .filter(|unit| matches!(unit.op, Instruction::LoadFast { .. })) + .count() + } + + #[test] + fn test_match_or_default_block_keeps_load_fast_strong() { + let code = compile_exec( + r#" +def f(format, other): + match format: + case 1 | 2: + return other + case _: + raise NotImplementedError(other) +"#, + ); + let function = find_code(&code, "f").expect("missing function code"); + let loads = load_fast_ops_for_var(function, "other"); + assert!( + matches!( + loads.as_slice(), + [ + Instruction::LoadFastBorrow { .. }, + Instruction::LoadFastBorrow { .. }, + Instruction::LoadFast { .. }, + ] + ), + "CPython optimize_load_fast() keeps trailing OR-pattern default loads strong, got {loads:?}", + ); + } + + #[test] + fn test_match_nested_or_default_block_keeps_load_fast_strong() { + let code = compile_exec( + r#" +def f(format, other): + match format: + case [1 | 2, value]: + return other + case _: + raise NotImplementedError(other) +"#, + ); + let function = find_code(&code, "f").expect("missing function code"); + let loads = load_fast_ops_for_var(function, "other"); + assert!( + loads + .iter() + .any(|op| matches!(op, Instruction::LoadFast { .. 
})), + "CPython optimize_load_fast() keeps trailing nested OR-pattern default loads strong, got {loads:?}", + ); + } + + #[test] + fn test_match_success_next_location_preserves_pass_nop() { + let code = compile_exec( + r#" +def f(command): + match command: + case "": + pass + case _ as unknown: + sink(unknown) + return False +"#, + ); + let function = find_code(&code, "f").expect("missing function code"); + let ops = non_cache_instructions(function) + .map(|unit| unit.op) + .collect::>(); + assert!( + ops.windows(3).any(|window| matches!( + window, + [ + Instruction::PopTop, + Instruction::Nop, + Instruction::LoadConst { .. }, + ] + )), + "CPython NEXT_LOCATION keeps the pass NOP after match subject POP_TOP, got {ops:?}", + ); + } + + #[test] + fn test_match_subject_copy_uses_case_pattern_location_like_cpython() { + let code = compile_exec( + "\ +def f(x): + match x: + case 1: + return True + case 2: + return False +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let copy_line = f + .instructions + .iter() + .zip(&f.locations) + .find_map(|(unit, (location, _))| { + let Instruction::Copy { i } = unit.op else { + return None; + }; + let arg = OpArg::new(u32::from(u8::from(unit.arg))); + (i.get(arg) == 1).then_some(location.line.get()) + }) + .expect("missing match subject COPY"); + assert_eq!( + copy_line, 3, + "CPython codegen_match_inner() emits ADDOP_I(c, LOC(m->pattern), COPY, 1)" + ); + } + + #[test] + fn test_match_or_alternative_copies_use_alternative_locations_like_cpython() { + let code = compile_exec( + "\ +def f(): + x = False + match 0: + case 0 | 1 | 2 | 3: + x = True + return x +", + ); + let f = find_code(&code, "f").expect("missing f code"); + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x08, 0x0d, 0x80, 0x41, 0xd8, 0x0a, 0x0b, 0xdf, 0x0d, 0x0e, 0x97, + 0x11, 0x97, 0x51, 0x9f, 0x11, 0x88, 0x5d, 0xe0, 0x0b, 0x0c, 0x80, 0x48, 0xf0, 0x05, + 0x00, 0x0e, 0x1b, 0xd8, 0x10, 0x14, 0x88, 0x41, 0xd8, 0x0b, 0x0c, 0x80, 0x48, 
+ ], + "CPython codegen_pattern_or() emits each alternative COPY with LOC(alt)" + ); + } + + #[test] + fn test_match_success_jump_uses_no_location_like_cpython() { + let code = compile_exec( + "\ +def f(self): + match 0: + case 0: + x = True + case 0: + x = False + self.assertIs(x, True) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x0a, 0x0b, 0xde, 0x0d, 0x0e, 0xd9, 0x10, 0x14, 0x89, 0x41, 0xdd, + 0x0d, 0x0e, 0xd8, 0x10, 0x15, 0x88, 0x41, 0xd8, 0x04, 0x08, 0x87, 0x4d, 0x81, 0x4d, + 0x90, 0x21, 0x90, 0x54, 0xd6, 0x04, 0x1a, + ], + "CPython codegen_match_inner() emits the success jump with NO_LOCATION" + ); + } + + #[test] + fn test_match_mapping_keys_scaffolding_uses_mapping_location_like_cpython() { + let code = compile_exec( + "\ +def f(self): + x = {} + y = None + match x: + case {0: 0}: + y = 0 + self.assertIs(y, None) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x08, 0x0a, 0x80, 0x41, 0xd8, 0x08, 0x0c, 0x80, 0x41, 0xd8, 0x0a, + 0x0b, 0xdf, 0x0d, 0x13, 0x8f, 0x56, 0x8a, 0x56, 0x95, 0x11, 0x89, 0x56, 0xd8, 0x10, + 0x11, 0x89, 0x41, 0xf2, 0x03, 0x00, 0x0e, 0x14, 0xe0, 0x04, 0x08, 0x87, 0x4d, 0x81, + 0x4d, 0x90, 0x21, 0x90, 0x54, 0xd6, 0x04, 0x1a, + ], + "CPython codegen_pattern_mapping() returns to LOC(p) for BUILD_TUPLE/MATCH_KEYS scaffolding" + ); + } + + #[test] + fn test_match_class_scaffolding_uses_class_pattern_location_like_cpython() { + let code = compile_exec( + "\ +def f(x): + match x: + case bool(z): + y = 0 + return y, z +", + ); + let f = find_code(&code, "f").expect("missing f code"); + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x0a, 0x0b, 0xdc, 0x0d, 0x11, 0x8f, 0x57, 0x88, 0x57, 0xd8, 0x10, + 0x11, 0x88, 0x41, 0xd8, 0x0b, 0x0c, 0x88, 0x34, 0x80, 0x4b, 0xf0, 0x05, 0x00, 0x0e, + 0x15, 0xe0, 0x0b, 0x0c, 0x88, 0x61, 0x88, 0x34, 0x80, 0x4b, + ], + "CPython 
codegen_pattern_class() returns to LOC(p) after VISIT(cls)" + ); + } + + #[test] + fn test_while_try_body_layout_keeps_false_jump_to_anchor() { + let code = compile_exec( + r#" +def f(stack, itstack, node_to_stack_index): + while True: + while stack: + try: + node = itstack[-1]() + break + except StopIteration: + del node_to_stack_index[stack.pop()] + itstack.pop() + else: + break +"#, + ); + let function = find_code(&code, "f").expect("missing function code"); + let ops = non_cache_instructions(function) + .map(|unit| unit.op) + .collect::>(); + let stack_test = ops + .windows(5) + .find(|window| { + matches!( + window, + [ + Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, + Instruction::ToBool, + Instruction::PopJumpIfFalse { .. }, + Instruction::NotTaken, + Instruction::Nop, + ] + ) + }) + .unwrap_or_else(|| { + panic!("expected CPython-style while/try false jump to anchor, got {ops:?}") + }); + assert!(matches!(stack_test[2], Instruction::PopJumpIfFalse { .. })); + } + + #[test] + fn test_while_if_not_break_keeps_body_call() { + let code = compile_exec( + r#" +def f(waiters): + while waiters: + waiter = waiters.popleft() + if not waiter.done(): + waiter.set_result(None) + break +"#, + ); + let function = find_code(&code, "f").expect("missing function code"); + let ops = non_cache_instructions(function) + .map(|unit| unit.op) + .collect::>(); + assert!( + ops.windows(4).any(|window| matches!( + window, + [ + Instruction::LoadFastBorrow { .. } | Instruction::LoadFast { .. }, + Instruction::LoadAttr { .. }, + Instruction::LoadConst { .. }, + Instruction::Call { .. 
}, + ] + )), + "CPython keeps waiter.set_result(None) before the break, got {ops:?}", + ); + } + + fn localsplus_name(code: &CodeObject, idx: usize) -> Option<&str> { + if idx < code.varnames.len() { + return Some(code.varnames[idx].as_str()); + } + + let mut extra_idx = idx - code.varnames.len(); + for cellvar in &code.cellvars { + if !code.varnames.iter().any(|varname| varname == cellvar) { + if extra_idx == 0 { + return Some(cellvar.as_str()); + } + extra_idx -= 1; + } + } + code.freevars.get(extra_idx).map(|name| name.as_str()) + } + + fn has_common_constant(code: &CodeObject, expected: bytecode::CommonConstant) -> bool { + code.instructions.iter().any(|unit| match unit.op { + Instruction::LoadCommonConstant { idx } => { + idx.get(OpArg::new(u32::from(u8::from(unit.arg)))) == expected + } + _ => false, + }) + } + + fn has_intrinsic_1(code: &CodeObject, expected: IntrinsicFunction1) -> bool { + code.instructions.iter().any(|unit| match unit.op { + Instruction::CallIntrinsic1 { func } => { + func.get(OpArg::new(u32::from(u8::from(unit.arg)))) == expected + } + _ => false, + }) + } + + #[test] + fn test_trace_assert_true_try_pair() { + let trace = compile_exec_late_cfg_trace( + "\ +try: + assert True +except AssertionError as e: + fail() +try: + assert True, 'msg' +except AssertionError as e: + fail() +", + ); + for (stage, dump) in trace { + eprintln!("=== {stage} ===\n{dump}"); + } + } + + #[test] + fn test_trace_for_unpack_list_literal() { + let trace = compile_exec_late_cfg_trace( + "\ +result = [] +for x, in [(1,), (2,), (3,)]: + result.append(x) +", + ); + for (stage, dump) in trace { + eprintln!("=== {stage} ===\n{dump}"); + } + } + + #[test] + fn test_trace_break_in_finally_function() { + let trace = compile_single_function_late_cfg_trace( + "\ +def f(self): + count = 0 + while count < 2: + count += 1 + try: + pass + finally: + break + self.assertEqual(count, 1) +", + "f", + ); + for (stage, dump) in trace { + eprintln!("=== {stage} ===\n{dump}"); + } + } 
+ + #[test] + fn test_import_originated_name_disables_method_call_optimization_even_with_local_import() { + let code = compile_exec( + "\ +import warnings + +def f(ch): + import warnings + warnings.warn( + '\"\\\\%c\" is an invalid escape sequence' % ch + if 0x20 <= ch < 0x7F + else '\"\\\\x%02x\" is an invalid escape sequence' % ch, + DeprecationWarning, + stacklevel=2, + ) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let ops: Vec<_> = f.instructions.iter().map(|unit| unit.op).collect(); + let warn_attr = ops + .iter() + .position(|op| matches!(op, Instruction::LoadAttr { .. })) + .expect("missing LOAD_ATTR for warnings.warn"); + let push_null = ops[warn_attr + 10..] + .iter() + .position(|op| matches!(op, Instruction::PushNull)) + .map(|idx| warn_attr + 10 + idx) + .expect("expected PUSH_NULL after plain LOAD_ATTR"); + + let load_attr = match f.instructions[warn_attr].op { + Instruction::LoadAttr { namei } => namei.get(OpArg::new(u32::from(u8::from( + f.instructions[warn_attr].arg, + )))), + _ => unreachable!(), + }; + assert!( + !load_attr.is_method(), + "import-originated names should use plain LOAD_ATTR" + ); + assert!( + matches!(ops[push_null + 1], Instruction::LoadSmallInt { .. 
}), + "expected warning message expression to start after PUSH_NULL, got ops={ops:?}" + ); + } + + #[test] + fn test_trace_constant_false_elif_chain() { + let trace = compile_exec_late_cfg_trace( + "\ +if 0: pass +elif 0: pass +elif 0: pass +elif 0: pass +else: pass +", + ); + for (stage, dump) in trace { + eprintln!("=== {stage} ===\n{dump}"); + } + } + + #[test] + fn test_trace_multi_pass_suite() { + let trace = compile_exec_late_cfg_trace( + "\ +if 1: + # + # + # + pass + pass + # + pass + # +", + ); + for (stage, dump) in trace { + eprintln!("=== {stage} ===\n{dump}"); + } + } + + #[test] + fn test_trace_single_compare_if() { + let trace = compile_exec_late_cfg_trace( + "\ +if 1 == 1: + pass +", + ); + for (stage, dump) in trace { + eprintln!("=== {stage} ===\n{dump}"); + } + } + + #[test] + fn test_trace_comparison_suite() { + let trace = compile_exec_late_cfg_trace( + "\ +if 1: pass +x = (1 == 1) +if 1 == 1: pass +if 1 != 1: pass +if 1 < 1: pass +if 1 > 1: pass +if 1 <= 1: pass +if 1 >= 1: pass +if x is x: pass +if x is not x: pass +if 1 in (): pass +if 1 not in (): pass +", + ); + for (stage, dump) in trace { + eprintln!("=== {stage} ===\n{dump}"); + } + } + + #[test] + fn test_trace_if_for_except_layout() { + let trace = compile_exec_late_cfg_trace( + "\ +from sys import maxsize +if maxsize == 2147483647: + for s in ('2147483648', '0o40000000000', '0x100000000', '0b10000000000000000000000000000000'): + try: + x = eval(s) + except OverflowError: + fail(\"OverflowError on huge integer literal %r\" % s) +elif maxsize == 9223372036854775807: + pass +", + ); + for (stage, dump) in trace { + eprintln!("=== {stage} ===\n{dump}"); + } + } + + #[test] + fn test_break_in_finally_tail_loads_borrow_through_empty_fallthrough_block() { + let code = compile_exec( + "\ +def f(self): + count = 0 + while count < 2: + count += 1 + try: + pass + finally: + break + self.assertEqual(count, 1) +", + ); + let code = find_code(&code, "f").unwrap(); + let ops: Vec<_> = code + 
.instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + assert!( + ops.windows(5).any(|window| { + matches!( + window, + [ + Instruction::LoadFastBorrow { .. }, + Instruction::LoadAttr { .. }, + Instruction::LoadFastBorrow { .. }, + Instruction::LoadSmallInt { .. }, + Instruction::Call { .. } + ] + ) + }), + "{:?}", + code.instructions + .iter() + .map(|unit| unit.op) + .collect::>() + ); + } + + #[test] + fn test_plain_constant_bool_op_folds_to_selected_operand() { + let code = compile_exec( + "\ +x = 1 or 2 or 3 +", + ); + let ops: Vec<_> = code + .instructions + .iter() + .map(|unit| unit.op) + .filter(|op| !matches!(op, Instruction::Cache)) + .collect(); + let folded_small_int = code.instructions.iter().any(|unit| { + matches!( + unit.op, + Instruction::LoadSmallInt { i } + if i.get(OpArg::new(u32::from(u8::from(unit.arg)))) == 1 + ) + }); + let folded_const_one = code + .instructions + .iter() + .find_map(|unit| match unit.op { + Instruction::LoadConst { .. } => code.constants.get(usize::from(u8::from(unit.arg))), + _ => None, + }) + .is_some_and(|constant| { + matches!(constant, ConstantData::Integer { value } if *value == BigInt::from(1)) + }); + + assert!( + folded_small_int || folded_const_one, + "expected folded constant 1, got ops={ops:?}" + ); + assert!( + !ops.iter().any(|op| { + matches!( + op, + Instruction::Copy { .. } + | Instruction::ToBool + | Instruction::PopJumpIfTrue { .. } + | Instruction::PopJumpIfFalse { .. 
} + ) + }), + "plain constant BoolOp should not leave short-circuit scaffolding, got ops={ops:?}" + ); + } + + #[test] + fn test_taken_constant_boolop_load_const_uses_literal_location_like_cpython() { + let code = compile_exec( + "\ +def and_false(x): + return False and x + +def or_true(x): + return True or x +", + ); + let and_false = find_code(&code, "and_false").expect("missing and_false code"); + let or_true = find_code(&code, "or_true").expect("missing or_true code"); + + // CPython 3.14 codegen_boolop() VISITs the selected literal before the + // short-circuit jump is optimized away, so the surviving LOAD_CONST + // keeps the literal range rather than the whole BoolOp range. + assert_eq!( + and_false.linetable.as_ref(), + &[0x80, 0x00, 0xd8, 0x0b, 0x10, 0xd0, 0x04, 0x16] + ); + assert_eq!( + or_true.linetable.as_ref(), + &[0x80, 0x00, 0xd8, 0x0b, 0x0f, 0xd0, 0x04, 0x14] + ); + } + + #[test] + fn test_assert_false_message_call_uses_assert_location_like_cpython() { + let code = compile_exec( + "\ +def f(): + assert False, \"x\" +", + ); + let f = find_code(&code, "f").expect("missing f code"); + + // CPython 3.14 codegen_assert() emits LOAD_COMMON_CONSTANT and CALL + // at LOC(assert statement), then RAISE_VARARGS at LOC(test). 
+ assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x04, 0x15, 0x90, 0x23, 0xd3, 0x04, 0x15, 0x88, 0x35, + ] ); } - fn localsplus_name(code: &CodeObject, idx: usize) -> Option<&str> { - if idx < code.varnames.len() { - return Some(code.varnames[idx].as_str()); - } + #[test] + fn test_static_swap_implicit_return_keeps_preswap_store_location_like_cpython() { + let code = compile_exec( + "\ +def f(a, b): + a, b = a, b + b, a = a, b +", + ); + let f = find_code(&code, "f").expect("missing f code"); - let mut extra_idx = idx - code.varnames.len(); - for cellvar in &code.cellvars { - if !code.varnames.iter().any(|varname| varname == cellvar) { - if extra_idx == 0 { - return Some(cellvar.as_str()); - } - extra_idx -= 1; - } - } - code.freevars.get(extra_idx).map(|name| name.as_str()) + // CPython 3.14 flowgraph.c resolves line numbers before + // optimize_basic_block() turns BUILD_TUPLE/UNPACK_SEQUENCE into SWAP + // and apply_static_swaps() reorders the STORE_FAST pair. The + // synthetic return epilogue therefore keeps the pre-swap final store + // location. 
+ assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x0b, 0x0c, 0x80, 0x71, 0xd8, 0x0b, 0x0c, 0x82, 0x71, + ] + ); } - fn has_common_constant(code: &CodeObject, expected: bytecode::CommonConstant) -> bool { - code.instructions.iter().any(|unit| match unit.op { - Instruction::LoadCommonConstant { idx } => { - idx.get(OpArg::new(u32::from(u8::from(unit.arg)))) == expected - } - _ => false, - }) - } + #[test] + fn test_unpack_store_pair_jump_uses_second_target_location_like_cpython() { + let code = compile_exec( + "\ +def f(value): + if value.startswith('=?'): + try: + token, value = get_encoded_word(value) + except E: + token, value = get_atext(value) + else: + token, value = get_atext(value) + atom.append(token) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let jump_position = f + .instructions + .iter() + .zip(&f.locations) + .find_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::JumpForward { .. }).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing post-try JUMP_FORWARD"); - fn has_intrinsic_1(code: &CodeObject, expected: IntrinsicFunction1) -> bool { - code.instructions.iter().any(|unit| match unit.op { - Instruction::CallIntrinsic1 { func } => { - func.get(OpArg::new(u32::from(u8::from(unit.arg)))) == expected - } - _ => false, - }) + // CPython 3.14 flowgraph.c turns the second STORE_FAST into a NOP + // during STORE_FAST_STORE_FAST fusion, then NOP removal copies that + // second target location onto the following no-location jump. 
+ assert_eq!(jump_position, (4, 20, 4, 25)); } #[test] - fn test_trace_assert_true_try_pair() { - let trace = compile_exec_late_cfg_trace( + fn test_chained_store_pair_jump_keeps_copy_target_location_like_cpython() { + let code = compile_exec( "\ -try: - assert True -except AssertionError as e: - fail() -try: - assert True, 'msg' -except AssertionError as e: - fail() +def f(flag): + if flag: + a = b = True + else: + a = False + b = False + g(a, b) + return a ", ); - for (stage, dump) in trace { - eprintln!("=== {stage} ===\n{dump}"); - } + let f = find_code(&code, "f").expect("missing f code"); + let jump_position = f + .instructions + .windows(2) + .zip(f.locations.windows(2)) + .find_map(|(units, locations)| { + matches!(units[0].op, Instruction::StoreFastStoreFast { .. }) + .then(|| { + matches!(units[1].op, Instruction::JumpForward { .. }).then_some(( + locations[1].0.line.get(), + locations[1].0.character_offset.get(), + locations[1].1.line.get(), + locations[1].1.character_offset.get(), + )) + }) + .flatten() + }) + .expect("missing jump after chained STORE_FAST_STORE_FAST"); + + // CPython 3.14 flowgraph.c preserves the second chained-assignment + // target location on the jump that skips the else body. + assert_eq!(jump_position, (3, 13, 3, 14)); } #[test] - fn test_trace_for_unpack_list_literal() { - let trace = compile_exec_late_cfg_trace( + fn test_tuple_store_pair_jump_keeps_fused_store_location_like_cpython() { + let code = compile_exec( "\ -result = [] -for x, in [(1,), (2,), (3,)]: - result.append(x) +def f(flag, n, exp): + if flag: + n, d = n * 10**exp, 1 + else: + d = -exp + g(n, d) + return n ", ); - for (stage, dump) in trace { - eprintln!("=== {stage} ===\n{dump}"); - } + let f = find_code(&code, "f").expect("missing f code"); + let jump_position = f + .instructions + .windows(2) + .zip(f.locations.windows(2)) + .find_map(|(units, locations)| { + matches!(units[0].op, Instruction::StoreFastStoreFast { .. 
}) + .then(|| { + matches!(units[1].op, Instruction::JumpForward { .. }).then_some(( + locations[1].0.line.get(), + locations[1].0.character_offset.get(), + locations[1].1.line.get(), + locations[1].1.character_offset.get(), + )) + }) + .flatten() + }) + .expect("missing jump after tuple STORE_FAST_STORE_FAST"); + + // Without COPY before the fused stores, CPython keeps the fused + // STORE_FAST_STORE_FAST location on the following jump. + assert_eq!(jump_position, (3, 12, 3, 13)); } #[test] - fn test_trace_break_in_finally_function() { - let trace = compile_single_function_late_cfg_trace( + fn test_genexpr_make_closure_and_call_use_genexpr_location_like_cpython() { + let code = compile_exec( "\ -def f(self): - count = 0 - while count < 2: - count += 1 - try: - pass - finally: - break - self.assertEqual(count, 1) +def f(parameters): + return ((p, type(p)) for p in parameters) ", - "f", ); - for (stage, dump) in trace { - eprintln!("=== {stage} ===\n{dump}"); - } + let f = find_code(&code, "f").expect("missing f code"); + let genexpr = find_code(f, "").expect("missing genexpr code"); + + // CPython 3.14 codegen_comprehension() uses LOC(e) for + // codegen_make_closure(), the outer CALL, and the implicit .0 load + // in codegen_sync_comprehension_generator(). 
+ assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd9, 0x0b, 0x2d, 0xa1, 0x2a, 0xd3, 0x0b, 0x2d, 0xd0, 0x04, 0x2d, + ] + ); + assert_eq!( + genexpr.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x0b, 0x2d, 0xa1, 0x2a, 0x98, 0x51, 0x94, 0x04, 0x90, + 0x51, 0x93, 0x07, 0x8d, 0x4c, 0xa3, 0x2a, 0xf9, + ] + ); } #[test] - fn test_import_originated_name_disables_method_call_optimization_even_with_local_import() { + fn test_implicit_call_genexpr_range_includes_call_parens_like_cpython() { let code = compile_exec( "\ -import warnings +def implicit(): + return list(x for x in range(10)) -def f(ch): - import warnings - warnings.warn( - '\"\\\\%c\" is an invalid escape sequence' % ch - if 0x20 <= ch < 0x7F - else '\"\\\\x%02x\" is an invalid escape sequence' % ch, - DeprecationWarning, - stacklevel=2, - ) +def explicit(): + return list((x for x in range(10))) ", ); - let f = find_code(&code, "f").expect("missing f code"); - let ops: Vec<_> = f.instructions.iter().map(|unit| unit.op).collect(); - let warn_attr = ops - .iter() - .position(|op| matches!(op, Instruction::LoadAttr { .. })) - .expect("missing LOAD_ATTR for warnings.warn"); - let push_null = ops[warn_attr + 10..] 
- .iter() - .position(|op| matches!(op, Instruction::PushNull)) - .map(|idx| warn_attr + 10 + idx) - .expect("expected PUSH_NULL after plain LOAD_ATTR"); + let implicit = find_code(&code, "implicit").expect("missing implicit code"); + let implicit_gen = find_code(implicit, "").expect("missing implicit genexpr code"); + let explicit = find_code(&code, "explicit").expect("missing explicit code"); + let explicit_gen = find_code(explicit, "").expect("missing explicit genexpr code"); - let load_attr = match f.instructions[warn_attr].op { - Instruction::LoadAttr { namei } => namei.get(OpArg::new(u32::from(u8::from( - f.instructions[warn_attr].arg, - )))), - _ => unreachable!(), - }; - assert!( - !load_attr.is_method(), - "import-originated names should use plain LOAD_ATTR" + // CPython's parser gives an unparenthesized sole GeneratorExp call + // argument the call-parenthesized range, and codegen_comprehension() + // uses LOC(e) for MAKE_FUNCTION, the outer CALL, and the implicit .0 + // LOAD_FAST. Explicitly parenthesized genexprs already carry their own + // parentheses and must not be widened again. + assert_eq!( + implicit.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdc, 0x0b, 0x0f, 0xd1, 0x0f, 0x25, 0x9c, 0x35, 0xa0, 0x12, 0x9c, 0x39, + 0xd3, 0x0f, 0x25, 0xd3, 0x0b, 0x25, 0xd0, 0x04, 0x25, + ] ); - assert!( - matches!(ops[push_null + 1], Instruction::LoadSmallInt { .. 
}), - "expected warning message expression to start after PUSH_NULL, got ops={ops:?}" + assert_eq!( + implicit_gen.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x0f, 0x25, 0x99, 0x39, 0x90, 0x61, 0x94, 0x01, 0x9b, + 0x39, 0xf9, + ] + ); + assert_eq!( + explicit_gen.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x10, 0x26, 0x99, 0x49, 0x90, 0x71, 0x94, 0x11, 0x9b, + 0x49, 0xf9, + ] ); } #[test] - fn test_trace_constant_false_elif_chain() { - let trace = compile_exec_late_cfg_trace( + fn test_implicit_call_genexpr_parenthesized_element_range_like_cpython() { + let code = compile_exec( "\ -if 0: pass -elif 0: pass -elif 0: pass -elif 0: pass -else: pass +def bytes_binop(): + return bytes((x ^ 0x5C) for x in range(256)) + +def dict_tuple(d): + return dict((v, k) for (k, v) in d.items()) + +def plain_tuple_elt(xs): + return list((x, y) for x, y in xs) + +def explicit_gen(xs): + return list(((x, y) for x, y in xs)) ", ); - for (stage, dump) in trace { - eprintln!("=== {stage} ===\n{dump}"); - } + let bytes_binop = find_code(&code, "bytes_binop").expect("missing bytes_binop code"); + let bytes_gen = find_code(bytes_binop, "").expect("missing bytes genexpr code"); + let dict_tuple = find_code(&code, "dict_tuple").expect("missing dict_tuple code"); + let dict_gen = find_code(dict_tuple, "").expect("missing dict genexpr code"); + let plain_tuple_elt = + find_code(&code, "plain_tuple_elt").expect("missing plain_tuple_elt code"); + let plain_gen = + find_code(plain_tuple_elt, "").expect("missing plain genexpr code"); + let explicit_gen = find_code(&code, "explicit_gen").expect("missing explicit_gen code"); + let explicit_inner = + find_code(explicit_gen, "").expect("missing explicit genexpr code"); + + // CPython 3.14's parser includes the call argument parentheses in + // LOC(GeneratorExp) for implicit sole-argument generator expressions, + // even when the element expression itself starts with parentheses. 
+ assert_eq!( + bytes_binop.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdc, 0x0b, 0x10, 0xd1, 0x10, 0x30, 0xa4, 0x55, 0xa8, 0x33, 0xa4, 0x5a, + 0xd3, 0x10, 0x30, 0xd3, 0x0b, 0x30, 0xd0, 0x04, 0x30, + ] + ); + assert_eq!( + bytes_gen.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x10, 0x30, 0xa1, 0x5a, 0xa0, 0x01, 0x90, 0x64, 0x97, + 0x28, 0x92, 0x28, 0xa3, 0x5a, 0xf9, + ] + ); + assert_eq!( + dict_tuple.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdc, 0x0b, 0x0f, 0xd1, 0x0f, 0x2f, 0xa0, 0x51, 0xa7, 0x57, 0xa1, 0x57, + 0xa4, 0x59, 0xd3, 0x0f, 0x2f, 0xd3, 0x0b, 0x2f, 0xd0, 0x04, 0x2f, + ] + ); + assert_eq!( + dict_gen.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x0f, 0x2f, 0xa1, 0x59, 0x99, 0x36, 0x98, 0x41, 0x90, + 0x11, 0x95, 0x06, 0xa3, 0x59, 0xf9, + ] + ); + assert_eq!( + plain_tuple_elt.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdc, 0x0b, 0x0f, 0xd1, 0x0f, 0x26, 0xa1, 0x32, 0xd3, 0x0f, 0x26, 0xd3, + 0x0b, 0x26, 0xd0, 0x04, 0x26, + ] + ); + assert_eq!( + plain_gen.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x0f, 0x26, 0xa1, 0x32, 0x99, 0x34, 0x98, 0x31, 0x90, + 0x11, 0x95, 0x06, 0xa3, 0x32, 0xf9, + ] + ); + assert_eq!( + explicit_gen.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdc, 0x0b, 0x0f, 0xd1, 0x10, 0x27, 0xa1, 0x42, 0xd3, 0x10, 0x27, 0xd3, + 0x0b, 0x28, 0xd0, 0x04, 0x28, + ] + ); + assert_eq!( + explicit_inner.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x10, 0x27, 0xa1, 0x42, 0x99, 0x44, 0x98, 0x41, 0x90, + 0x21, 0x95, 0x16, 0xa3, 0x42, 0xf9, + ] + ); } #[test] - fn test_trace_multi_pass_suite() { - let trace = compile_exec_late_cfg_trace( + fn test_genexpr_filter_cleanup_jumps_use_element_location_like_cpython() { + let code = compile_exec( "\ -if 1: - # - # - # - pass - pass - # - pass - # +def simple(names): + return (x for x in names if not _ishidden(x)) + +def boolop(fields): + return (f for f in fields if f.init and not f.kw_only) ", ); - for (stage, dump) in trace { - eprintln!("=== {stage} ===\n{dump}"); - } + 
let simple = find_code(&code, "simple").expect("missing simple code"); + let simple_gen = find_code(simple, "").expect("missing simple genexpr code"); + let boolop = find_code(&code, "boolop").expect("missing boolop code"); + let boolop_gen = find_code(boolop, "").expect("missing boolop genexpr code"); + + // CPython 3.14 codegen_sync_comprehension_generator() emits the + // comprehension guard jump to if_cleanup, then emits the if_cleanup + // backedge with elt_loc. flowgraph.c::jump_thread() copies that target + // jump location to the threaded POP_JUMP/NOT_TAKEN cleanup path. + assert_eq!( + simple_gen.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x0b, 0x31, 0x91, 0x75, 0x90, 0x21, 0xa4, 0x49, 0xa8, + 0x61, 0xa7, 0x4c, 0x8f, 0x41, 0x8a, 0x41, 0x93, 0x75, 0xf9, + ] + ); + assert_eq!( + boolop_gen.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd0, 0x0b, 0x3a, 0x91, 0x76, 0x90, 0x21, 0xa7, 0x16, 0xa5, + 0x16, 0x8c, 0x41, 0xb0, 0x01, 0xb7, 0x09, 0xb5, 0x09, 0x8f, 0x41, 0x8a, 0x41, 0x93, + 0x76, 0xf9, + ] + ); } #[test] - fn test_trace_single_compare_if() { - let trace = compile_exec_late_cfg_trace( + fn test_try_finally_exception_scaffolding_uses_no_location_like_cpython() { + let code = compile_exec( "\ -if 1 == 1: - pass +def f(self, node): + self.flag = True + try: + self.body(node) + finally: + self.flag = False ", ); - for (stage, dump) in trace { - eprintln!("=== {stage} ===\n{dump}"); - } + let f = find_code(&code, "f").expect("missing f code"); + + // CPython 3.14 codegen_try_finally() emits the exception path + // SETUP_CLEANUP/PUSH_EXC_INFO and POP_EXCEPT_AND_RERAISE with + // NO_LOCATION; flowgraph line propagation then gives only the + // finalbody's direct RERAISE the finalbody location. 
+ assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x10, 0x14, 0x80, 0x44, 0x84, 0x49, 0xf0, 0x02, 0x03, 0x05, 0x1a, + 0xd8, 0x08, 0x0c, 0x8f, 0x09, 0x89, 0x09, 0x90, 0x24, 0x8c, 0x0f, 0xe0, 0x14, 0x19, + 0x88, 0x04, 0x8e, 0x09, 0xf8, 0x90, 0x45, 0x88, 0x04, 0x8d, 0x09, 0xfa, + ] + ); } #[test] - fn test_trace_comparison_suite() { - let trace = compile_exec_late_cfg_trace( + fn test_adjacent_no_location_entries_merge_like_cpython() { + let code = compile_exec( "\ -if 1: pass -x = (1 == 1) -if 1 == 1: pass -if 1 != 1: pass -if 1 < 1: pass -if 1 > 1: pass -if 1 <= 1: pass -if 1 >= 1: pass -if x is x: pass -if x is not x: pass -if 1 in (): pass -if 1 not in (): pass +def f(file): + if sys.platform == \"win32\": + try: + import nt + if not nt._supports_virtual_terminal(): + return False + except (ImportError, AttributeError): + return False + try: + return os.isatty(file.fileno()) + except OSError: + return hasattr(file, \"isatty\") and file.isatty() +", + ); + let f = find_code(&code, "f").expect("missing f code"); + + // CPython's NO_LOCATION is {-1, -1, -1, -1}, and + // assemble.c::assemble_location_info() merges adjacent instructions + // with the same NO_LOCATION into one linetable entry. 
+ assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdc, 0x07, 0x0a, 0x87, 0x7c, 0x81, 0x7c, 0x90, 0x77, 0xd4, 0x07, 0x1e, + 0xf0, 0x02, 0x05, 0x09, 0x19, 0xdb, 0x0c, 0x15, 0xd8, 0x13, 0x15, 0xd7, 0x13, 0x30, + 0xd1, 0x13, 0x30, 0xd7, 0x13, 0x32, 0xd2, 0x13, 0x32, 0xd9, 0x17, 0x1c, 0xf0, 0x03, + 0x00, 0x14, 0x33, 0xf0, 0x08, 0x03, 0x05, 0x39, 0xdc, 0x0f, 0x11, 0x8f, 0x79, 0x89, + 0x79, 0x98, 0x14, 0x9f, 0x1b, 0x99, 0x1b, 0x9b, 0x1d, 0xd3, 0x0f, 0x27, 0xd0, 0x08, + 0x27, 0xf8, 0xf4, 0x07, 0x00, 0x11, 0x1c, 0x9c, 0x5e, 0xd0, 0x0f, 0x2c, 0xf4, 0x00, + 0x01, 0x09, 0x19, 0xda, 0x13, 0x18, 0xf0, 0x03, 0x01, 0x09, 0x19, 0xfb, 0xf4, 0x08, + 0x00, 0x0c, 0x13, 0xf4, 0x00, 0x01, 0x05, 0x39, 0xdc, 0x0f, 0x16, 0x90, 0x74, 0x98, + 0x58, 0xd3, 0x0f, 0x26, 0xd7, 0x0f, 0x38, 0xd0, 0x0f, 0x38, 0xa8, 0x34, 0xaf, 0x3b, + 0xa9, 0x3b, 0xab, 0x3d, 0xd2, 0x08, 0x38, 0xf0, 0x03, 0x01, 0x05, 0x39, 0xfa, + ] + ); + } + + #[test] + fn test_fstring_format_ops_use_formatted_value_location_like_cpython() { + let code = compile_exec( + "\ +def simple(self): + return f'{self.value}' + +def spec(x): + return f'{x!r:>3}' ", ); - for (stage, dump) in trace { - eprintln!("=== {stage} ===\n{dump}"); - } + let simple = find_code(&code, "simple").expect("missing simple code"); + let spec = find_code(&code, "spec").expect("missing spec code"); + + // CPython 3.14 codegen_formatted_value() VISITs the inner expression + // first, then emits CONVERT_VALUE / FORMAT_* at LOC(FormattedValue). 
+ assert_eq!( + simple.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x0e, 0x12, 0x8f, 0x6a, 0x89, 0x6a, 0x88, 0x5c, 0xd0, 0x04, 0x1a, + ] + ); + assert_eq!( + spec.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x0e, 0x0f, 0x88, 0x58, 0x90, 0x22, 0x88, 0x58, 0xd0, 0x04, 0x16, + ] + ); } #[test] - fn test_trace_if_for_except_layout() { - let trace = compile_exec_late_cfg_trace( + fn test_debug_fstring_literal_location_like_cpython() { + fn string_load_position(code: &CodeObject, expected: &str) -> (usize, usize, usize, usize) { + code.instructions + .iter() + .zip(&code.locations) + .find_map(|(unit, (location, end_location))| { + let Instruction::LoadConst { consti } = unit.op else { + return None; + }; + let constant = + &code.constants[consti.get(OpArg::new(u32::from(u8::from(unit.arg))))]; + matches!(constant, ConstantData::Str { value } if value.to_string() == expected) + .then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing debug f-string literal") + } + + let code = compile_exec( "\ -from sys import maxsize -if maxsize == 2147483647: - for s in ('2147483648', '0o40000000000', '0x100000000', '0b10000000000000000000000000000000'): - try: - x = eval(s) - except OverflowError: - fail(\"OverflowError on huge integer literal %r\" % s) -elif maxsize == 9223372036854775807: - pass +def simple(x): + return f'{x=}' + +def prefixed(x): + return f'a {x=} b' ", ); - for (stage, dump) in trace { - eprintln!("=== {stage} ===\n{dump}"); - } + let simple = find_code(&code, "simple").expect("missing simple code"); + let prefixed = find_code(&code, "prefixed").expect("missing prefixed code"); + + assert_eq!( + string_load_position(simple, "x="), + (2, 15, 2, 17), + "CPython represents f'{{x=}}' debug text as a literal at the expression/debug-text location" + ); + assert_eq!( + string_load_position(prefixed, "a x="), + (5, 14, 5, 19), + "CPython extends a pending 
f-string literal through the debug text range" + ); } #[test] - fn test_break_in_finally_tail_loads_borrow_through_empty_fallthrough_block() { + fn test_fstring_format_spec_build_string_location_like_cpython() { let code = compile_exec( "\ -def f(self): - count = 0 - while count < 2: - count += 1 - try: - pass - finally: - break - self.assertEqual(count, 1) +def simple(lbl, label_width): + return f'{lbl:>{label_width}}' + +def padded(digits, int_len): + return f'{digits:0>{int_len + 1}d}' ", ); - let code = find_code(&code, "f").unwrap(); - let ops: Vec<_> = code - .instructions - .iter() - .map(|unit| unit.op) - .filter(|op| !matches!(op, Instruction::Cache)) - .collect(); - assert!( - ops.windows(5).any(|window| { - matches!( - window, - [ - Instruction::LoadFastBorrow { .. }, - Instruction::LoadAttr { .. }, - Instruction::LoadFastBorrow { .. }, - Instruction::LoadSmallInt { .. }, - Instruction::Call { .. } - ] - ) - }), - "{:?}", + let simple = find_code(&code, "simple").expect("missing simple code"); + let padded = find_code(&code, "padded").expect("missing padded code"); + + let build_string_position = |code: &CodeObject| { code.instructions .iter() - .map(|unit| unit.op) - .collect::>() + .zip(&code.locations) + .find_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::BuildString { .. 
}).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing format-spec BUILD_STRING") + }; + + assert_eq!( + build_string_position(simple), + (2, 18, 2, 33), + "CPython uses the format-spec JoinedStr location, including the ':' prefix, for BUILD_STRING" + ); + assert_eq!( + build_string_position(padded), + (5, 21, 5, 38), + "CPython format-spec JoinedStr location spans from ':' through the final literal" ); } #[test] - fn test_plain_constant_bool_op_folds_to_selected_operand() { + fn test_joined_string_literals_extend_pending_literal_location_like_cpython() { let code = compile_exec( "\ -x = 1 or 2 or 3 +def f(a): + return ( + 'x' + f'y{a}z' + 'w' + ) ", ); - let ops: Vec<_> = code - .instructions - .iter() - .map(|unit| unit.op) - .filter(|op| !matches!(op, Instruction::Cache)) - .collect(); - let folded_small_int = code.instructions.iter().any(|unit| { - matches!( - unit.op, - Instruction::LoadSmallInt { i } - if i.get(OpArg::new(u32::from(u8::from(unit.arg)))) == 1 - ) - }); - let folded_const_one = code - .instructions - .iter() - .find_map(|unit| match unit.op { - Instruction::LoadConst { .. } => code.constants.get(usize::from(u8::from(unit.arg))), - _ => None, - }) - .is_some_and(|constant| { - matches!(constant, ConstantData::Integer { value } if *value == BigInt::from(1)) - }); - - assert!( - folded_small_int || folded_const_one, - "expected folded constant 1, got ops={ops:?}" - ); - assert!( - !ops.iter().any(|op| { - matches!( - op, - Instruction::Copy { .. } - | Instruction::ToBool - | Instruction::PopJumpIfTrue { .. } - | Instruction::PopJumpIfFalse { .. 
} - ) - }), - "plain constant BoolOp should not leave short-circuit scaffolding, got ops={ops:?}" + let f = find_code(&code, "f").expect("missing f code"); + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xf0, 0x04, 0x01, 0x09, 0x0c, 0xd8, 0x0c, 0x0d, 0x88, 0x33, 0xf0, 0x00, + 0x01, 0x0f, 0x0c, 0xf0, 0x03, 0x02, 0x09, 0x0c, 0xf0, 0x03, 0x04, 0x05, 0x06, + ], + "CPython parser/codegen represents adjacent f-string literal fragments as Constant ranges spanning the merged fragments" ); } @@ -16667,6 +19313,60 @@ def outer(null): ); } + #[test] + fn test_decorated_definitions_use_cpython_locations() { + let code = compile_exec( + "\ +def dec(f): return f + +class C: + @dec + def f(self): + yield + +@dec +class D: + pass + +class E: + @dec + def g(self, flags: int, /) -> memoryview: + raise NotImplementedError +", + ); + let c = find_code(&code, "C").expect("missing C code"); + let d = find_code(&code, "D").expect("missing D code"); + let e = find_code(&code, "E").expect("missing E code"); + let annotate = find_code(e, "__annotate__").expect("missing annotation code"); + + // CPython 3.14 codegen_function()/codegen_class() evaluate + // decorators first, then use LOC(s) for codegen_make_closure() and + // codegen_nameop(); codegen_apply_decorators() emits CALL at each + // decorator expression's location. 
+ assert_eq!( + c.linetable.as_ref(), + &[ + 0xf8, 0x87, 0x00, 0x80, 0x00, 0xd8, 0x05, 0x08, 0xf1, 0x02, 0x01, 0x05, 0x0e, 0xf3, + 0x03, 0x00, 0x06, 0x09, 0xf6, 0x02, 0x01, 0x05, 0x0e, + ] + ); + assert_eq!(d.linetable.as_ref(), &[0x86, 0x00, 0xe3, 0x04, 0x08]); + assert_eq!( + e.linetable.as_ref(), + &[ + 0xf8, 0x87, 0x00, 0x80, 0x00, 0xd8, 0x05, 0x08, 0xf7, 0x02, 0x01, 0x05, 0x22, 0xf3, + 0x03, 0x00, 0x06, 0x09, 0xf6, 0x02, 0x01, 0x05, 0x22, + ] + ); + assert_eq!( + annotate.linetable.as_ref(), + &[ + 0xf8, 0x80, 0x00, 0xf7, 0x00, 0x01, 0x05, 0x22, 0xf1, 0x00, 0x01, 0x05, 0x22, 0x91, + 0x73, 0xf0, 0x00, 0x01, 0x05, 0x22, 0xa1, 0x2a, 0xf1, 0x00, 0x01, 0x05, 0x22, + ] + ); + } + #[test] fn test_taken_constant_boolop_jump_disables_following_borrows() { for source in [ @@ -17337,6 +20037,25 @@ def f(xs): 1, "fallback call path should remain for shadowed any()" ); + let genexpr_const_count = f + .constants + .iter() + .filter(|constant| { + matches!(constant, ConstantData::Code { code } if code.obj_name == "") + }) + .count(); + assert_eq!( + genexpr_const_count, 1, + "optimized and fallback any(genexpr) paths should share the same CPython-range code const" + ); + assert_eq!( + f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdf, 0x0b, 0x0e, 0x8b, 0x33, 0x89, 0x6f, 0x99, 0x22, 0x8b, 0x6f, 0x8f, + 0x33, 0x8c, 0x33, 0xd0, 0x04, 0x1d, 0x8a, 0x33, 0xd0, 0x04, 0x1d, 0x88, 0x33, 0x89, + 0x6f, 0x99, 0x22, 0x8b, 0x6f, 0xd3, 0x0b, 0x1d, 0xd0, 0x04, 0x1d, + ] + ); } #[test] @@ -17369,6 +20088,14 @@ def set_f(xs): }) .expect("tuple(genexpr) fast path should emit LIST_APPEND"); assert_eq!(tuple_list_append, 2); + assert_eq!( + tuple_f.linetable.as_ref(), + &[ + 0x80, 0x00, 0xdf, 0x0b, 0x10, 0x8c, 0x35, 0x91, 0x0f, 0x99, 0x42, 0x93, 0x0f, 0x8f, + 0x35, 0xd0, 0x04, 0x1f, 0x88, 0x35, 0x91, 0x0f, 0x99, 0x42, 0x93, 0x0f, 0xd3, 0x0b, + 0x1f, 0xd0, 0x04, 0x1f, + ] + ); let list_f = find_code(&code, "list_f").expect("missing list_f code"); assert!( @@ -17787,6 +20514,27 @@ def aug(x, a, b, y): 
); } + #[test] + fn test_augassign_constant_slice_copy_uses_subscript_location_like_cpython() { + let code = compile_exec( + "\ +def aug_const(x, y): + x[1:2] += y +", + ); + let aug_const = find_code(&code, "aug_const").expect("missing aug_const code"); + + // CPython 3.14 codegen_augassign() visits a constant slice, then emits + // COPY/COPY/BINARY_OP NB_SUBSCR at LOC(target), not at LOC(slice). + assert_eq!( + aug_const.linetable.as_ref(), + &[ + 0x80, 0x00, 0xd8, 0x04, 0x05, 0x80, 0x63, 0x87, 0x46, 0x88, 0x61, 0x85, 0x4b, 0x85, + 0x46, + ] + ); + } + #[test] fn test_loop_return_reorders_backedge_before_exit_cleanup() { let code = compile_exec( @@ -19618,6 +22366,38 @@ def f(x): ); } + #[test] + fn test_match_negative_value_const_precedes_implicit_none_like_cpython() { + let code = compile_exec( + "\ +def f(x): + match x: + case -0.0: + y = 0 +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let negative_zero_index = f + .constants + .iter() + .position(|constant| { + matches!( + constant, + ConstantData::Float { value } if *value == 0.0 && value.is_sign_negative() + ) + }) + .expect("missing folded -0.0 match value"); + let none_index = f + .constants + .iter() + .position(|constant| matches!(constant, ConstantData::None)) + .expect("missing implicit None"); + assert!( + negative_zero_index < none_index, + "CPython ast_preprocess.c folds MatchValue constants before codegen registers the implicit None" + ); + } + #[test] fn test_match_or_uses_shared_success_block() { let code = compile_exec( @@ -24046,6 +26826,92 @@ class C: ); } + #[test] + fn test_future_annotations_flag_is_inherited_like_cpython() { + let code = compile_exec( + "\ +from __future__ import annotations + +def f(): + class C: + pass + return C +", + ); + assert!(code.flags.contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS)); + let f = find_code(&code, "f").expect("missing f code"); + assert!(f.flags.contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS)); + let class_code = 
find_code(f, "C").expect("missing C code"); + assert!( + class_code + .flags + .contains(bytecode::CodeFlags::FUTURE_ANNOTATIONS) + ); + } + + #[test] + fn test_annotation_scope_nested_flag_matches_cpython() { + let code = compile_exec( + "\ +class C: + x: int + +def outer(): + class D: + y: int +", + ); + let class_code = find_code(&code, "C").expect("missing C code"); + let class_annotate = + find_code(class_code, "__annotate__").expect("missing class annotation code"); + assert!( + !class_annotate.flags.contains(bytecode::CodeFlags::NESTED), + "module-level class annotation scope should not be nested" + ); + + let outer = find_code(&code, "outer").expect("missing outer code"); + let nested_class = find_code(outer, "D").expect("missing nested class code"); + let nested_annotate = + find_code(nested_class, "__annotate__").expect("missing nested annotation code"); + assert!( + nested_annotate.flags.contains(bytecode::CodeFlags::NESTED), + "annotation scope under a nested class should be nested" + ); + } + + #[test] + fn test_function_like_parent_marks_child_nested_like_cpython() { + let code = compile_exec( + "\ +x = lambda: (lambda: None) +type A[T] = T +", + ); + let outer_lambda = find_code(&code, "").expect("missing outer lambda code"); + assert!( + !outer_lambda.flags.contains(bytecode::CodeFlags::NESTED), + "module-level lambda should not be nested" + ); + let inner_lambda = + find_direct_child_code(outer_lambda, "").expect("missing inner lambda code"); + assert!( + inner_lambda.flags.contains(bytecode::CodeFlags::NESTED), + "lambda inside lambda should be nested" + ); + + let type_params = + find_code(&code, "").expect("missing type params code"); + assert!( + !type_params.flags.contains(bytecode::CodeFlags::NESTED), + "module-level type-parameter scope should not be nested" + ); + let type_alias = find_direct_child_code(type_params, "A").expect("missing type alias code"); + assert!( + type_alias.flags.contains(bytecode::CodeFlags::NESTED), + "type alias 
body inside type-parameter scope should be nested" + ); + } + #[test] fn test_plain_super_call_keeps_class_freevar() { let code = compile_exec( @@ -25657,6 +28523,46 @@ def f(obj): ); } + #[test] + fn test_slice_none_bounds_and_build_slice_use_slice_location_like_cpython() { + let code = compile_exec( + "\ +def f(obj, step): + return obj[::step] +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let slice_positions: Vec<_> = f + .instructions + .iter() + .zip(&f.locations) + .filter_map(|(unit, (location, end_location))| { + let op = match unit.op { + Instruction::LoadConst { .. } => "LOAD_CONST", + Instruction::BuildSlice { .. } => "BUILD_SLICE", + _ => return None, + }; + Some(( + op, + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert_eq!( + slice_positions, + vec![ + ("LOAD_CONST", 2, 16, 2, 22), + ("LOAD_CONST", 2, 16, 2, 22), + ("BUILD_SLICE", 2, 16, 2, 22), + ], + "CPython codegen_slice() emits missing bounds and BUILD_SLICE at LOC(slice)" + ); + } + #[test] fn test_bool_int_binop_constants_fold() { let code = compile_exec( @@ -26629,6 +29535,68 @@ def f(a, b, path): ); } + #[test] + fn test_with_return_value_uses_context_expr_location_like_cpython() { + let code = compile_exec( + "\ +def f(cm, func, args, kwds): + with cm: + return func(*args, **kwds) +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let return_positions: Vec<_> = f + .instructions + .iter() + .zip(&f.locations) + .filter_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::ReturnValue).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert_eq!( + return_positions, + vec![(2, 10, 2, 12), (2, 10, 2, 12)], + "CPython codegen_unwind_fblock(WITH) leaves RETURN_VALUE inheriting the context expression 
location" + ); + } + + #[test] + fn test_async_with_return_value_uses_context_expr_location_like_cpython() { + let code = compile_exec( + "\ +async def f(cm, func, args, kwds): + async with cm: + return await func(*args, **kwds) +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let return_positions: Vec<_> = f + .instructions + .iter() + .zip(&f.locations) + .filter_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::ReturnValue).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert_eq!( + return_positions, + vec![(2, 16, 2, 18), (2, 16, 2, 18)], + "CPython codegen_unwind_fblock(ASYNC_WITH) leaves RETURN_VALUE inheriting the context expression location" + ); + } + #[test] fn test_try_finally_conditional_return_duplicates_finally_exit_return() { let code = compile_exec( @@ -26812,7 +29780,7 @@ def f(cls, proto): } #[test] - fn test_literal_only_fstring_statement_is_optimized_away() { + fn test_literal_only_fstring_statement_keeps_const_like_cpython() { let code = compile_exec( "\ def f(): @@ -26822,17 +29790,11 @@ def f(): let f = find_code(&code, "f").expect("missing function code"); assert!( - !f.instructions - .iter() - .any(|unit| matches!(unit.op, Instruction::PopTop)), - "literal-only f-string statement should be removed" - ); - assert!( - !f.constants.iter().any(|constant| matches!( + f.constants.iter().any(|constant| matches!( constant, ConstantData::Str { value } if value.to_string() == "Not a docstring" )), - "literal-only f-string should not survive in constants" + "constant f-string statement should survive in co_consts like CPython" ); } @@ -27122,6 +30084,29 @@ values = [item for item in [r\"\\\\'a\\\\'\", r\"\\t3\", r\"\\\\\"[0]]]\n", } } + #[test] + fn test_constant_subscript_registers_source_const_before_result_like_cpython() { + let code = compile_exec("value = 'string'[3]\n"); + let 
source_index = code + .constants + .iter() + .position(|constant| { + matches!(constant, ConstantData::Str { value } if value.to_string() == "string") + }) + .expect("missing source string constant"); + let result_index = code + .constants + .iter() + .position(|constant| { + matches!(constant, ConstantData::Str { value } if value.to_string() == "i") + }) + .expect("missing folded subscript result"); + assert!( + source_index < result_index, + "CPython codegen_subscript emits the source constant before flowgraph.c folds NB_SUBSCR" + ); + } + #[test] fn test_constant_slice_subscript_folds_in_load_context() { let code = compile_exec( @@ -27285,6 +30270,29 @@ zero = 0j ** 2 ))); } + #[test] + fn test_folded_nan_constants_are_not_deduplicated_like_cpython() { + let code = compile_exec( + "\ +def f(): + repr(1e300 * 1e300 * 0) + repr(-1e300 * 1e300 * 0) + str(1e300 * 1e300 * 0) + str(-1e300 * 1e300 * 0) +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let nan_count = f + .constants + .iter() + .filter(|constant| matches!(constant, ConstantData::Float { value } if value.is_nan())) + .count(); + assert_eq!( + nan_count, 4, + "CPython _PyCode_ConstantKey keeps folded NaN constants distinct" + ); + } + #[test] fn test_zero_complex_power_exception_constants_do_not_fold() { let code = compile_exec("value = 0j ** (3 - 2j)\n"); @@ -27370,6 +30378,53 @@ class C: assert_eq!(varnames, vec!["format"]); } + #[test] + fn test_non_simple_class_annotation_is_not_deferred_like_cpython() { + let code = compile_exec( + "\ +class C: + x.y: list = [] + z: int +", + ); + let annotate = find_code(&code, "__annotate__").expect("missing __annotate__ code"); + let names = annotate + .names + .iter() + .map(|name| name.as_str()) + .collect::>(); + assert_eq!(names, vec!["int"]); + } + + #[test] + fn test_non_simple_annotation_only_consumes_symbol_table_cursor() { + let code = compile_exec( + "\ +class C: + x.y: (lambda: str) = [] + z: (lambda: int) +", + ); + let 
annotate = find_code(&code, "__annotate__").expect("missing __annotate__ code"); + let lambdas = annotate + .constants + .iter() + .filter_map(|constant| match constant { + ConstantData::Code { code } if code.obj_name == "" => Some(code.as_ref()), + _ => None, + }) + .collect::>(); + assert_eq!(lambdas.len(), 1); + assert_eq!( + lambdas[0] + .names + .iter() + .map(|name| name.as_str()) + .collect::>(), + vec!["int"] + ); + } + #[test] fn test_type_param_evaluator_uses_dot_format_varname() { let code = compile_exec( @@ -27473,6 +30528,27 @@ def func[T](a: T = 'a', *, b: T = 'b'): ); } + #[test] + fn test_generic_function_type_params_varnames_include_defaults_like_cpython() { + let code = compile_exec( + "\ +def func[T](): + pass +", + ); + let type_params = + find_code(&code, "").expect("missing type params code"); + assert_eq!(type_params.arg_count, 0); + assert_eq!( + type_params + .varnames + .iter() + .map(String::as_str) + .collect::>(), + vec![".defaults", "T"] + ); + } + #[test] fn test_class_type_param_bound_prefers_classdict_over_outer_function_local() { let code = compile_exec( @@ -27773,96 +30849,318 @@ def f(): } #[test] - fn test_tuple_not_keeps_to_bool_unary_not_like_cpython() { + fn test_tuple_not_keeps_to_bool_unary_not_like_cpython() { + let code = compile_exec( + "\ +def f(): + return not () +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + + assert!( + ops.windows(3).any(|window| { + matches!(window[0].op, Instruction::LoadConst { consti } + if matches!( + &f.constants[consti.get(OpArg::new(u32::from(u8::from(window[0].arg))))], + ConstantData::Tuple { elements } if elements.is_empty() + )) && matches!(window[1].op, Instruction::ToBool) + && matches!(window[2].op, Instruction::UnaryNot) + }), + "CPython codegen emits TO_BOOL; UNARY_NOT for UnaryOp(Not), while flowgraph.c folds tuple literals only after the 
LOAD_CONST+TO_BOOL pass, got instructions={:?}", + f.instructions + ); + } + + #[test] + fn test_tuple_if_test_keeps_to_bool_jump_like_cpython() { + let code = compile_exec( + "\ +def f(): + if (): + return 1 + return 2 +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let ops = f + .instructions + .iter() + .filter(|unit| !matches!(unit.op, Instruction::Cache)) + .collect::>(); + + assert!( + ops.windows(3).any(|window| { + matches!(window[0].op, Instruction::LoadConst { consti } + if matches!( + &f.constants[consti.get(OpArg::new(u32::from(u8::from(window[0].arg))))], + ConstantData::Tuple { elements } if elements.is_empty() + )) && matches!(window[1].op, Instruction::ToBool) + && matches!(window[2].op, Instruction::PopJumpIfFalse { .. }) + }), + "CPython leaves tuple literal truth tests as LOAD_CONST tuple; TO_BOOL; POP_JUMP_IF_FALSE because tuple folding happens after constant jump folding, got instructions={:?}", + f.instructions + ); + } + + #[test] + fn test_constant_list_iterable_uses_tuple() { + let code = compile_exec( + "\ +def f(): + return {x: y for x, y in [(1, 2), ]} +", + ); + let f = find_code(&code, "f").expect("missing function code"); + + assert!( + !f.instructions + .iter() + .any(|unit| matches!(unit.op, Instruction::BuildList { .. })), + "constant list iterable should avoid BUILD_LIST before GET_ITER" + ); + assert!(f.constants.iter().any(|constant| matches!( + constant, + ConstantData::Tuple { elements } + if matches!( + elements.as_slice(), + [ConstantData::Tuple { elements: inner }] + if matches!( + inner.as_slice(), + [ + ConstantData::Integer { .. }, + ConstantData::Integer { .. 
} + ] + ) + ) + ))); + } + + #[test] + fn test_constant_list_iterable_preserves_cpython_const_order() { + let code = compile_exec( + "\ +def f(): + for x in ['a', 'b', 'c']: + pass +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let constants = f.constants.iter().collect::>(); + + assert!( + matches!(constants[0], ConstantData::Str { value } if value.to_string() == "a"), + "CPython emits list elements as LOAD_CONST before flowgraph folds GET_ITER lists" + ); + assert!(matches!(constants[1], ConstantData::None)); + assert!(matches!( + constants[2], + ConstantData::Tuple { elements } + if matches!( + elements.as_slice(), + [ + ConstantData::Str { value: first }, + ConstantData::Str { value: second }, + ConstantData::Str { value: third }, + ] if first.to_string() == "a" + && second.to_string() == "b" + && third.to_string() == "c" + ) + )); + } + + #[test] + fn test_try_except_folded_tuple_consts_follow_cpython_block_order() { + let code = compile_exec( + "\ +def f(macrelease): + try: + g() + except ValueError: + macrelease = (10, 3) + if macrelease >= (10, 4): + pass +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let constants = f.constants.iter().collect::>(); + + assert!( + constants.windows(2).any(|window| { + matches!( + window, + [ + ConstantData::Tuple { elements: first }, + ConstantData::Tuple { elements: second }, + ] if matches!( + (first.as_slice(), second.as_slice()), + ( + [ + ConstantData::Integer { value: a }, + ConstantData::Integer { value: b }, + ], + [ + ConstantData::Integer { value: c }, + ConstantData::Integer { value: d }, + ], + ) if a == &BigInt::from(10) + && b == &BigInt::from(3) + && c == &BigInt::from(10) + && d == &BigInt::from(4) + ) + ) + }), + "CPython flowgraph.c walks b_next order, so the except-body tuple is folded before the following if-test tuple; got {constants:?}" + ); + } + + #[test] + fn test_small_set_membership_folds_before_later_unary_const_like_cpython() { let code = 
compile_exec( - "\ -def f(): - return not () -", + r#" +def f(method, n): + if method not in {"linear", "ranked"}: + pass + if method == "ranked": + start = (n - 1) / -2 +"#, ); let f = find_code(&code, "f").expect("missing function code"); - let ops = f - .instructions + let constants = f.constants.iter().collect::>(); + let frozenset_index = constants .iter() - .filter(|unit| !matches!(unit.op, Instruction::Cache)) - .collect::>(); + .position(|constant| matches!(constant, ConstantData::Frozenset { .. })) + .expect("missing folded membership frozenset"); + let negative_two_index = constants + .iter() + .position(|constant| { + matches!( + constant, + ConstantData::Integer { value } if value == &BigInt::from(-2) + ) + }) + .expect("missing folded -2 constant"); assert!( - ops.windows(3).any(|window| { - matches!(window[0].op, Instruction::LoadConst { consti } - if matches!( - &f.constants[consti.get(OpArg::new(u32::from(u8::from(window[0].arg))))], - ConstantData::Tuple { elements } if elements.is_empty() - )) && matches!(window[1].op, Instruction::ToBool) - && matches!(window[2].op, Instruction::UnaryNot) - }), - "CPython codegen emits TO_BOOL; UNARY_NOT for UnaryOp(Not), while flowgraph.c folds tuple literals only after the LOAD_CONST+TO_BOOL pass, got instructions={:?}", - f.instructions + frozenset_index < negative_two_index, + "CPython flowgraph.c optimizes BUILD_SET+CONTAINS_OP inline before folding the later unary -2; got {constants:?}" ); } #[test] - fn test_tuple_if_test_keeps_to_bool_jump_like_cpython() { + fn test_boolop_const_order_keeps_cpython_codegen_constants() { let code = compile_exec( "\ -def f(): - if (): - return 1 - return 2 +def or_false(x): + return False or x + +def zero_or_tuple(): + return 0 or (1, -1) + +def tuple_or_tuple(): + return (1, -1) or (-1, 1) ", ); - let f = find_code(&code, "f").expect("missing function code"); - let ops = f - .instructions - .iter() - .filter(|unit| !matches!(unit.op, Instruction::Cache)) - .collect::>(); + 
let or_false = find_code(&code, "or_false").expect("missing or_false code"); + let constants = or_false.constants.iter().collect::>(); + assert_eq!(constants.len(), 1); assert!( - ops.windows(3).any(|window| { - matches!(window[0].op, Instruction::LoadConst { consti } - if matches!( - &f.constants[consti.get(OpArg::new(u32::from(u8::from(window[0].arg))))], - ConstantData::Tuple { elements } if elements.is_empty() - )) && matches!(window[1].op, Instruction::ToBool) - && matches!(window[2].op, Instruction::PopJumpIfFalse { .. }) - }), - "CPython leaves tuple literal truth tests as LOAD_CONST tuple; TO_BOOL; POP_JUMP_IF_FALSE because tuple folding happens after constant jump folding, got instructions={:?}", - f.instructions + matches!(constants[0], ConstantData::Boolean { value: false }), + "CPython registers the skipped boolop literal before flowgraph removes the branch" + ); + + let zero_or_tuple = find_code(&code, "zero_or_tuple").expect("missing zero_or_tuple code"); + let constants = zero_or_tuple.constants.iter().collect::>(); + assert_eq!(constants.len(), 2); + assert!( + matches!( + constants[0], + ConstantData::Integer { value } if value == &BigInt::from(0) + ) && matches!( + constants[1], + ConstantData::Tuple { elements } + if matches!( + elements.as_slice(), + [ + ConstantData::Integer { value: one }, + ConstantData::Integer { value: minus_one }, + ] if one == &BigInt::from(1) && minus_one == &BigInt::from(-1) + ) + ), + "CPython keeps the skipped scalar literal before the folded tuple constant" + ); + + let tuple_or_tuple = + find_code(&code, "tuple_or_tuple").expect("missing tuple_or_tuple code"); + let constants = tuple_or_tuple.constants.iter().collect::>(); + assert_eq!(constants.len(), 3); + assert!( + matches!( + constants[0], + ConstantData::Integer { value } if value == &BigInt::from(1) + ) && matches!( + constants[1], + ConstantData::Tuple { elements } + if matches!( + elements.as_slice(), + [ + ConstantData::Integer { value: one }, + 
ConstantData::Integer { value: minus_one }, + ] if one == &BigInt::from(1) && minus_one == &BigInt::from(-1) + ) + ) && matches!( + constants[2], + ConstantData::Tuple { elements } + if matches!( + elements.as_slice(), + [ + ConstantData::Integer { value: minus_one }, + ConstantData::Integer { value: one }, + ] if minus_one == &BigInt::from(-1) && one == &BigInt::from(1) + ) + ), + "CPython compiles boolop tuple heads before flowgraph folds them" ); } #[test] - fn test_constant_list_iterable_uses_tuple() { + fn test_lambda_without_body_constants_keeps_none_like_cpython() { + let code = compile_exec("f = lambda x: x"); + let lambda = find_code(&code, "").expect("missing lambda code"); + let constants = lambda.constants.iter().collect::>(); + assert_eq!(constants.len(), 1); + + assert!( + matches!(constants[0], ConstantData::None), + "CPython AddReturnAtEnd registers None for constant-free lambdas" + ); + } + + #[test] + fn test_call_function_ex_empty_args_tuple_is_folded_late_like_cpython() { let code = compile_exec( "\ -def f(): - return {x: y for x, y in [(1, 2), ]} +def f(g, kwargs, ns): + g(**kwargs) + ns['T'] ", ); let f = find_code(&code, "f").expect("missing function code"); + let constants = f.constants.iter().collect::>(); + assert_eq!(constants.len(), 3); assert!( - !f.instructions - .iter() - .any(|unit| matches!(unit.op, Instruction::BuildList { .. 
})), - "constant list iterable should avoid BUILD_LIST before GET_ITER" + matches!(constants[0], ConstantData::Str { value } if value.to_string() == "T") + && matches!(constants[1], ConstantData::None) + && matches!(constants[2], ConstantData::Tuple { elements } if elements.is_empty()), + "CPython emits BUILD_TUPLE 0 for CALL_FUNCTION_EX args and folds it after earlier constants" ); - assert!(f.constants.iter().any(|constant| matches!( - constant, - ConstantData::Tuple { elements } - if matches!( - elements.as_slice(), - [ConstantData::Tuple { elements: inner }] - if matches!( - inner.as_slice(), - [ - ConstantData::Integer { .. }, - ConstantData::Integer { .. } - ] - ) - ) - ))); } #[test] @@ -28086,6 +31384,220 @@ def g(): ); } + #[test] + fn test_comprehension_list_iterable_build_uses_iter_location_like_cpython() { + let code = compile_exec( + "\ +async def f(i): + return i + +async def run_list(): + return [await c for c in [f(1), f(41)]] +", + ); + let run_list = find_code(&code, "run_list").expect("missing run_list code"); + assert_eq!( + run_list.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xdc, 0x1e, 0x1f, 0xa0, 0x01, 0x9b, 0x64, 0xa4, 0x41, 0xa0, + 0x62, 0xa3, 0x45, 0x99, 0x5d, 0xd3, 0x0b, 0x2b, 0x99, 0x5d, 0x98, 0x01, 0x8f, 0x47, + 0x8a, 0x47, 0x99, 0x5d, 0xd1, 0x0b, 0x2b, 0xd0, 0x04, 0x2b, 0x89, 0x47, 0xf9, 0xd2, + 0x0b, 0x2b, 0xf9, + ], + "CPython codegen_comprehension_iter() emits GET_ITER at LOC(comp->iter)" + ); + } + + #[test] + fn test_comprehension_boolop_iter_get_iter_uses_iter_location_like_cpython() { + let code = compile_exec( + "\ +def f(self): + return any(not w.cancelled() for w in (self._waiters or ())) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + let get_iter_positions: Vec<_> = f + .instructions + .iter() + .zip(&f.locations) + .filter_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::GetIter).then_some(( + location.line.get(), + location.character_offset.get(), + 
end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert!( + get_iter_positions.contains(&(2, 44, 2, 63)), + "CPython codegen_comprehension_iter() emits GET_ITER at LOC(comp->iter), got {get_iter_positions:?}" + ); + } + + #[test] + fn test_inlined_comprehension_backedges_use_element_location_like_cpython() { + let code = compile_exec( + "\ +async def f(i): + return i + +async def run_list(): + return [s for c in [f(''), f('abc')] for s in await c] +", + ); + let run_list = find_code(&code, "run_list").expect("missing run_list code"); + assert_eq!( + run_list.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xdc, 0x18, 0x19, 0x98, 0x22, 0x9b, 0x05, 0x9c, 0x71, 0xa0, + 0x15, 0x9b, 0x78, 0xd1, 0x17, 0x28, 0xd4, 0x0b, 0x3a, 0xd1, 0x17, 0x28, 0x90, 0x21, + 0xb7, 0x27, 0xb2, 0x27, 0xa8, 0x51, 0x8a, 0x41, 0xb1, 0x27, 0x89, 0x41, 0xd1, 0x17, + 0x28, 0xd2, 0x0b, 0x3a, 0xd0, 0x04, 0x3a, 0xb1, 0x27, 0xf9, 0xd3, 0x0b, 0x3a, 0xf9, + ], + "CPython codegen_sync_comprehension_generator() emits comprehension backedges at elt_loc" + ); + } + + #[test] + fn test_nested_dict_comprehension_outer_backedge_uses_key_location_like_cpython() { + let code = compile_exec( + "\ +def f(items): + return {op: i for i, ops in items for op in ops} +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let backedge_positions: Vec<_> = f + .instructions + .iter() + .zip(&f.locations) + .filter_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::JumpBackward { .. 
}).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .collect(); + + assert!( + backedge_positions.contains(&(2, 13, 2, 18)), + "CPython extends only the terminal dict-comprehension MAP_ADD/backedge location from key through value, got {backedge_positions:?}" + ); + assert!( + backedge_positions.contains(&(2, 13, 2, 15)), + "CPython keeps outer dict-comprehension generator backedges at LOC(key), got {backedge_positions:?}" + ); + } + + #[test] + fn test_inlined_comprehension_filter_jump_uses_element_location_like_cpython() { + let code = compile_exec( + "\ +def f(self): + return [action for action in self._actions if action.option_strings] +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let filter_jump_position = f + .instructions + .iter() + .zip(&f.locations) + .find_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::PopJumpIfTrue { .. }).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing optimized filter jump"); + assert_eq!( + filter_jump_position, + (2, 13, 2, 19), + "CPython inlined comprehension filter jump inherits the element/backedge location after CFG cleanup" + ); + } + + #[test] + fn test_inlined_comprehension_ifexp_guard_jump_uses_body_location_like_cpython() { + let code = compile_exec( + "\ +def f(fields): + return [f for f in fields if (f.compare if f.hash is None else f.hash)] +", + ); + let f = find_code(&code, "f").expect("missing function code"); + let jump_forward_position = f + .instructions + .iter() + .zip(&f.locations) + .find_map(|(unit, (location, end_location))| { + matches!(unit.op, Instruction::JumpForward { .. 
}).then_some(( + location.line.get(), + location.character_offset.get(), + end_location.line.get(), + end_location.character_offset.get(), + )) + }) + .expect("missing if-expression body jump"); + assert_eq!( + jump_forward_position, + (2, 35, 2, 44), + "CPython flowgraph.c::propagate_line_numbers() copies the if-expression body location onto the NO_LOCATION jump" + ); + } + + #[test] + fn test_inlined_async_comprehension_end_async_for_uses_comprehension_location_like_cpython() { + let code = compile_exec( + "\ +async def f(it): + for i in it: + yield i + +async def run_list(): + return [i + 1 async for i in f([10, 20])] +", + ); + let run_list = find_code(&code, "run_list").expect("missing run_list code"); + assert_eq!( + run_list.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xdc, 0x21, 0x22, 0xa0, 0x42, 0xa8, 0x02, 0xa0, 0x38, 0xa4, + 0x1b, 0xd7, 0x0b, 0x2d, 0xd3, 0x0b, 0x2d, 0x98, 0x41, 0x90, 0x01, 0x8f, 0x45, 0x88, + 0x45, 0xd4, 0x0b, 0x2d, 0xd0, 0x04, 0x2d, 0xf9, 0xd2, 0x0b, 0x2d, 0xf9, + ], + "CPython codegen_async_comprehension_generator() emits END_ASYNC_FOR at comprehension loc" + ); + } + + #[test] + fn test_async_for_anext_sequence_uses_statement_location_like_cpython() { + let code = compile_exec( + "\ +async def f(source, buffer): + async for i1, i2 in source(): + buffer.append(i1 + i2) +", + ); + let f = find_code(&code, "f").expect("missing f code"); + assert_eq!( + f.linetable.as_ref(), + &[ + 0xe9, 0x00, 0x80, 0x00, 0xd9, 0x18, 0x1e, 0x9c, 0x08, 0xf7, 0x00, 0x01, 0x05, 0x1f, + 0xf0, 0x00, 0x01, 0x05, 0x1f, 0x89, 0x66, 0x88, 0x62, 0xd8, 0x08, 0x0e, 0x8f, 0x0d, + 0x89, 0x0d, 0x90, 0x62, 0x95, 0x67, 0xd6, 0x08, 0x1e, 0xf1, 0x03, 0x01, 0x05, 0x1f, + 0x9a, 0x08, 0xf9, + ], + "CPython codegen_async_for() emits GET_ANEXT/yield-from scaffolding at LOC(s)" + ); + } + #[test] fn test_nested_comprehension_list_iterable_uses_tuple() { let code = compile_exec( @@ -28684,6 +32196,25 @@ def f(seq, emit): ); } + #[test] + fn 
test_inlined_comprehension_namedexpr_varnames_match_cpython_order() { + let code = compile_exec( + "\ +def f(): + def spam(a): + return a + input_data = [1, 2, 3] + res = [(x, y, x / y) for x in input_data if (y := spam(x)) > 0] + return res +", + ); + let f = find_code(&code, "f").expect("missing f code"); + assert_eq!( + f.varnames.iter().map(String::as_str).collect::>(), + vec!["spam", "input_data", "x", "y", "res"] + ); + } + #[test] fn test_global_namedexpr_in_inlined_comprehension_saves_fast_slot() { let code = compile_exec( @@ -28712,6 +32243,27 @@ def f(seq, value): })); } + #[test] + fn test_namedexpr_copy_uses_namedexpr_location_like_cpython() { + let code = compile_exec( + "\ +def outer(): + a = 10 + def spam(): + nonlocal a + (a := 20) +", + ); + let spam = find_code(&code, "spam").expect("missing spam code"); + + // CPython 3.14 NamedExpr_kind emits COPY at LOC(named expression), + // between visiting the value and visiting the target. + assert_eq!( + spam.linetable.as_ref(), + &[0xf8, 0x80, 0x00, 0xe0, 0x0e, 0x10, 0x88, 0x17, 0x8b, 0x11,] + ); + } + #[test] fn test_genexpr_namedexpr_target_is_cell_not_fast_local() { let code = compile_exec( @@ -28730,6 +32282,31 @@ def f(seq): ); } + #[test] + fn test_public_cellvars_follow_cpython_localsplus_order() { + let code = compile_exec( + "\ +def f(): + x = 10 + t = False + g = ((i, j) for i in range(x) if t for j in range(x)) + [x for x in range(3)] + return g +", + ); + let f = find_code(&code, "f").expect("missing f code"); + + assert_eq!( + f.varnames.iter().map(String::as_str).collect::>(), + ["g", "x"] + ); + assert_eq!( + f.cellvars.iter().map(String::as_str).collect::>(), + ["x", "t"], + "CPython assemble.c exposes co_cellvars in localsplus order: merged local cells before non-local cells" + ); + } + #[test] fn test_inlined_comprehension_restore_does_not_form_store_fast_load_fast() { let code = compile_exec( diff --git a/crates/codegen/src/ir.rs b/crates/codegen/src/ir.rs index 47be101f0a..db1119873b 
100644 --- a/crates/codegen/src/ir.rs +++ b/crates/codegen/src/ir.rs @@ -28,6 +28,8 @@ struct LineTableLocation { end_col: i32, } +pub(crate) const LINE_ONLY_LOCATION_OVERRIDE: i32 = -4; + const MAX_INT_SIZE_BITS: u64 = 128; const MAX_COLLECTION_SIZE: usize = 256; const MAX_TOTAL_ITEMS: isize = 1024; @@ -35,13 +37,92 @@ const MAX_STR_SIZE: usize = 4096; const MIN_CONST_SEQUENCE_SIZE: usize = 3; const STACK_USE_GUIDELINE: usize = 30; +#[derive(Clone, Debug, Default)] +pub struct ConstantPool { + constants: Vec, +} + +impl ConstantPool { + fn constant_contains_nan(constant: &ConstantData) -> bool { + match constant { + ConstantData::Float { value } => value.is_nan(), + ConstantData::Complex { value } => value.re.is_nan() || value.im.is_nan(), + ConstantData::Tuple { elements } | ConstantData::Frozenset { elements } => { + elements.iter().any(Self::constant_contains_nan) + } + ConstantData::Slice { elements } => elements.iter().any(Self::constant_contains_nan), + _ => false, + } + } + + pub fn insert_full(&mut self, constant: ConstantData) -> (usize, bool) { + // CPython's _PyCode_ConstantKey() keeps NaN-bearing constants distinct + // because Python-level NaN keys do not compare equal. 
+ if !Self::constant_contains_nan(&constant) + && let Some(idx) = self + .constants + .iter() + .position(|existing| existing == &constant) + { + return (idx, false); + } + let idx = self.constants.len(); + self.constants.push(constant); + (idx, true) + } + + pub fn insert(&mut self, constant: ConstantData) -> bool { + self.insert_full(constant).1 + } + + #[must_use] + pub fn get_index(&self, idx: usize) -> Option<&ConstantData> { + self.constants.get(idx) + } + + pub fn iter(&self) -> core::slice::Iter<'_, ConstantData> { + self.constants.iter() + } + + #[must_use] + pub fn len(&self) -> usize { + self.constants.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.constants.is_empty() + } + + pub fn clear(&mut self) { + self.constants.clear(); + } +} + +impl ops::Index for ConstantPool { + type Output = ConstantData; + + fn index(&self, idx: usize) -> &Self::Output { + &self.constants[idx] + } +} + +impl IntoIterator for ConstantPool { + type Item = ConstantData; + type IntoIter = alloc::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.constants.into_iter() + } +} + /// Metadata for a code unit // = _PyCompile_CodeUnitMetadata #[derive(Clone, Debug)] pub struct CodeUnitMetadata { pub name: String, // u_name (obj_name) pub qualname: Option, // u_qualname - pub consts: IndexSet, // u_consts + pub consts: ConstantPool, // u_consts pub names: IndexSet, // u_names pub varnames: IndexSet, // u_varnames pub cellvars: IndexSet, // u_cellvars @@ -140,6 +221,12 @@ pub struct InstructionInfo { /// This is the final jump emitted by codegen_break() after unwinding the /// iterator for a for-loop break. pub for_loop_break_cleanup_jump: bool, + /// Keep this conditional jump's own location when the preceding TO_BOOL + /// normally propagates its condition location into the jump. + pub preserve_tobool_jump_location: bool, + /// Keep the jump location copied from the second STORE_FAST NOP created + /// by STORE_FAST_STORE_FAST fusion. 
+ pub preserve_store_fast_store_fast_jump_location: bool, } /// Exception handler information for an instruction. @@ -167,6 +254,8 @@ fn set_to_nop(info: &mut InstructionInfo) { info.match_success_jump = false; info.break_continue_cleanup_jump = false; info.for_loop_break_cleanup_jump = false; + info.preserve_tobool_jump_location = false; + info.preserve_store_fast_store_fast_jump_location = false; } fn nop_out_no_location(info: &mut InstructionInfo) { @@ -204,6 +293,8 @@ pub struct Block { pub load_fast_passthrough: bool, /// Continuation label that CPython attaches to a preceding empty block. pub load_fast_label_reuse_passthrough: bool, + /// If-expression orelse label emitted inside another conditional statement. + pub conditional_ifexp_orelse_entry: bool, } impl Default for Block { @@ -221,6 +312,7 @@ impl Default for Block { load_fast_barrier: false, load_fast_passthrough: false, load_fast_label_reuse_passthrough: false, + conditional_ifexp_orelse_entry: false, } } } @@ -295,17 +387,9 @@ impl CodeInfo { self.fold_set_constants(); self.optimize_lists_and_sets(); self.convert_to_load_small_int(); - self.remove_unused_consts(); // DCE always runs (removes dead code after terminal instructions) self.dce(); - // BUILD_TUPLE n + UNPACK_SEQUENCE n → NOP + SWAP (n=2,3) or NOP+NOP (n=1) - self.optimize_build_tuple_unpack(); - // Dead store elimination for duplicate STORE_FAST targets - // (apply_static_swaps in CPython's flowgraph.c) - self.eliminate_dead_stores(); - // apply_static_swaps: reorder stores to eliminate SWAPs - self.apply_static_swaps(); // Peephole optimizer handles constant and compare folding. self.peephole_optimize(); // Per-block walker first to preserve CPython-style instruction-order @@ -318,7 +402,6 @@ impl CodeInfo { self.fold_set_constants(); self.optimize_lists_and_sets(); self.convert_to_load_small_int(); - self.remove_unused_consts(); // CPython's CFG builder starts a new basic block after a terminator. 
// Peephole constant-jump folding can create new terminators, so split // before DCE clears unreachable successor instructions; otherwise the @@ -346,6 +429,14 @@ impl CodeInfo { // superinstruction insertion, so fusion decisions see propagated // source locations. resolve_line_numbers(&mut self.blocks); + // CPython flowgraph.c::optimize_cfg() runs optimize_basic_block() + // after the first resolve_line_numbers(). Keep tuple-unpack SWAP + // creation, duplicate STORE_FAST cleanup, and apply_static_swaps() + // here so synthetic no-location exits inherit the same pre-swap + // source locations as CPython. + self.optimize_build_tuple_unpack(); + self.eliminate_dead_stores(); + self.apply_static_swaps(); self.remove_nops(); self.add_checks_for_loads_of_uninitialized_variables(); // CPython inserts superinstructions in _PyCfg_OptimizeCodeUnit, before @@ -526,7 +617,7 @@ impl CodeInfo { if next_lineno == lineno { remove = true; } else if next_lineno < 0 { - src_instructions[src + 1].lineno_override = Some(lineno); + copy_instruction_location(instr, &mut src_instructions[src + 1]); remove = true; } } @@ -578,6 +669,8 @@ impl CodeInfo { } resolve_next_location_overrides(&mut blocks); + propagate_store_fast_store_fast_jump_locations(&mut blocks); + propagate_tobool_conditional_jump_locations(&mut blocks); // Pre-compute cache_entries for real (non-pseudo) instructions for block in &mut blocks { @@ -747,13 +840,31 @@ impl CodeInfo { info.arg.instr_size() + cache_count, )); // Collect linetable locations with lineno_override support - let lt_loc = LineTableLocation { - line: info - .lineno_override - .unwrap_or_else(|| info.location.line.get() as i32), - end_line: info.end_location.line.get() as i32, - col: info.location.character_offset.to_zero_indexed() as i32, - end_col: info.end_location.character_offset.to_zero_indexed() as i32, + let lt_loc = match info.lineno_override { + Some(-1) => LineTableLocation { + line: -1, + end_line: -1, + col: -1, + end_col: -1, + }, + 
Some(LINE_ONLY_LOCATION_OVERRIDE) => LineTableLocation { + line: info.location.line.get() as i32, + end_line: info.end_location.line.get() as i32, + col: -1, + end_col: -1, + }, + Some(lineno) => LineTableLocation { + line: lineno, + end_line: info.end_location.line.get() as i32, + col: info.location.character_offset.to_zero_indexed() as i32, + end_col: info.end_location.character_offset.to_zero_indexed() as i32, + }, + None => LineTableLocation { + line: info.location.line.get() as i32, + end_line: info.end_location.line.get() as i32, + col: info.location.character_offset.to_zero_indexed() as i32, + end_col: info.end_location.character_offset.to_zero_indexed() as i32, + }, }; linetable_locations.extend(core::iter::repeat_n(lt_loc, info.arg.instr_size())); // CACHE entries inherit parent instruction's location @@ -801,6 +912,20 @@ impl CodeInfo { // Generate exception table before moving source_path let exceptiontable = generate_exception_table(&blocks, &block_to_index); + // CPython builds u_cellvars in dictbytype() order, but the public + // co_cellvars tuple follows localsplus order from assemble.c: + // cell locals already present in varnames first, then remaining cells. 
+ let final_cellvars = varname_cache + .iter() + .filter(|name| cellvar_cache.contains(name.as_str())) + .chain( + cellvar_cache + .iter() + .filter(|name| !varname_cache.contains(name.as_str())), + ) + .cloned() + .collect::>(); + // Build localspluskinds with cell-local merging let nlocals = varname_cache.len(); let ncells = cellvar_cache.len(); @@ -854,7 +979,7 @@ impl CodeInfo { constants: constants.into_iter().collect(), names: name_cache.into_iter().collect(), varnames: varname_cache.into_iter().collect(), - cellvars: cellvar_cache.into_iter().collect(), + cellvars: final_cellvars.into_boxed_slice(), freevars: freevar_cache.into_iter().collect(), localspluskinds: localspluskinds.into_boxed_slice(), linetable, @@ -1060,6 +1185,27 @@ impl CodeInfo { } } + fn instr_make_load_const( + metadata: &mut CodeUnitMetadata, + instr: &mut InstructionInfo, + constant: ConstantData, + ) { + if let ConstantData::Integer { value } = &constant + && let Some(small) = value.to_i32().filter(|v| (0..=255).contains(v)) + { + instr.instr = Opcode::LoadSmallInt.into(); + instr.arg = OpArg::new(small as u32); + return; + } + + let (const_idx, _) = metadata.consts.insert_full(constant); + instr.instr = Instruction::LoadConst { + consti: Arg::marker(), + } + .into(); + instr.arg = OpArg::new(const_idx as u32); + } + /// Try to fold a single unary instruction at position `i` in `block`. /// Returns true if folded. Mirrors CPython fold_const_unaryop(). 
fn fold_unary_constant_at( @@ -1100,7 +1246,6 @@ impl CodeInfo { let Some(folded_const) = Self::eval_unary_constant(&operand, op, intrinsic) else { return false; }; - let (const_idx, _) = metadata.consts.insert_full(folded_const); nop_out_no_location(&mut block.instructions[operand_index]); let mut prev = operand_index; while let Some(idx) = prev.checked_sub(1) { @@ -1111,18 +1256,15 @@ impl CodeInfo { block.instructions[idx].end_location = block.instructions[i].end_location; prev = idx; } - block.instructions[i].instr = Instruction::LoadConst { - consti: Arg::marker(), - } - .into(); - block.instructions[i].arg = OpArg::new(const_idx as u32); + Self::instr_make_load_const(metadata, &mut block.instructions[i], folded_const); block.instructions[i].folded_from_nonliteral_expr = false; true } /// Fold constant unary operations following CPython fold_const_unaryop(). fn fold_unary_constants(&mut self) { - for block in &mut self.blocks { + for block_idx in self.block_next_order() { + let block = &mut self.blocks[block_idx]; let mut i = 0; while i < block.instructions.len() { if Self::fold_unary_constant_at(&mut self.metadata, block, i) { @@ -1194,6 +1336,16 @@ impl CodeInfo { None } + fn block_next_order(&self) -> Vec { + let mut order = Vec::new(); + let mut current = BlockIdx(0); + while current != BlockIdx::NULL { + order.push(current); + current = self.blocks[current.idx()].next; + } + order + } + /// Try to fold a single BINARY_OP instruction at position `i` in `block`. /// Returns true if folded. Mirrors CPython fold_const_binop(). 
fn fold_binop_constant_at( @@ -1224,18 +1376,13 @@ impl CodeInfo { let Some(result_const) = Self::eval_binop(&left_val, &right_val, op) else { return false; }; - let (const_idx, _) = metadata.consts.insert_full(result_const); let folded_from_nonliteral_expr = operand_indices .iter() .any(|&idx| block.instructions[idx].folded_from_nonliteral_expr); for &idx in &operand_indices { nop_out_no_location(&mut block.instructions[idx]); } - block.instructions[i].instr = Instruction::LoadConst { - consti: Arg::marker(), - } - .into(); - block.instructions[i].arg = OpArg::new(const_idx as u32); + Self::instr_make_load_const(metadata, &mut block.instructions[i], result_const); block.instructions[i].folded_from_nonliteral_expr = folded_from_nonliteral_expr; true } @@ -1244,7 +1391,8 @@ impl CodeInfo { /// into a single LOAD_CONST when the result is computable at compile time. /// = fold_binops_on_constants in CPython flowgraph.c fn fold_binop_constants(&mut self) { - for block in &mut self.blocks { + for block_idx in self.block_next_order() { + let block = &mut self.blocks[block_idx]; let mut i = 0; while i < block.instructions.len() { if Self::fold_binop_constant_at(&mut self.metadata, block, i) { @@ -1260,27 +1408,26 @@ impl CodeInfo { /// unary, and binop constant folding. Mirrors optimize_basic_block() in /// flowgraph.c so constants are registered in co_consts in instruction /// order rather than in the order separate global passes would discover - /// them. Iterates per block to a fixed point so an inner fold can enable - /// a surrounding outer fold within the same block. + /// them. CPython runs optimize_basic_block() once per basic block, with a + /// single forward scan, so this deliberately does not iterate to a fixed + /// point. 
fn fold_constants_per_block(&mut self) { - for block in &mut self.blocks { - loop { - let mut changed = false; - let mut i = 0; - while i < block.instructions.len() { - let folded = Self::fold_tuple_constant_at(&mut self.metadata, block, i) - || Self::fold_list_constant_at(&mut self.metadata, block, i) - || Self::fold_set_constant_at(&mut self.metadata, block, i) - || Self::fold_unary_constant_at(&mut self.metadata, block, i) - || Self::fold_binop_constant_at(&mut self.metadata, block, i); - if folded { - changed = true; - } - i += 1; - } - if !changed { - break; - } + for block_idx in self.block_next_order() { + let block = &mut self.blocks[block_idx]; + let mut i = 0; + while i < block.instructions.len() { + let _ = Self::fold_tuple_constant_at(&mut self.metadata, block, i) + || Self::optimize_iterable_or_contains_collection_at( + &mut self.metadata, + block, + i, + ) + || Self::fold_list_constant_at(&mut self.metadata, block, i) + || Self::fold_set_constant_at(&mut self.metadata, block, i) + || Self::fold_constant_intrinsic_list_to_tuple_at(&mut self.metadata, block, i) + || Self::fold_unary_constant_at(&mut self.metadata, block, i) + || Self::fold_binop_constant_at(&mut self.metadata, block, i); + i += 1; } } } @@ -1588,6 +1735,46 @@ impl CodeInfo { return eval_const_subscript(left, right); } + fn constant_as_int(value: &ConstantData) -> Option<(BigInt, bool)> { + match value { + ConstantData::Boolean { value } => Some((BigInt::from(u8::from(*value)), true)), + ConstantData::Integer { value } => Some((value.clone(), false)), + _ => None, + } + } + + if let (Some((left_int, left_is_bool)), Some((right_int, right_is_bool))) = + (constant_as_int(left), constant_as_int(right)) + && (left_is_bool || right_is_bool) + { + if left_is_bool && right_is_bool { + match op { + BinOp::And => { + return Some(ConstantData::Boolean { + value: !left_int.is_zero() & !right_int.is_zero(), + }); + } + BinOp::Or => { + return Some(ConstantData::Boolean { + value: !left_int.is_zero() 
| !right_int.is_zero(), + }); + } + BinOp::Xor => { + return Some(ConstantData::Boolean { + value: !left_int.is_zero() ^ !right_int.is_zero(), + }); + } + _ => {} + } + } + + return Self::eval_binop( + &ConstantData::Integer { value: left_int }, + &ConstantData::Integer { value: right_int }, + op, + ); + } + match (left, right) { (ConstantData::Integer { value: l }, ConstantData::Integer { value: r }) => { let result = match op { @@ -1902,6 +2089,14 @@ impl CodeInfo { if func.get(block.instructions[i].arg) != IntrinsicFunction1::ListToTuple { return false; } + if block + .instructions + .get(i + 1) + .and_then(|instr| instr.instr.real()) + .is_some_and(|instr| matches!(instr, Instruction::GetIter)) + { + return false; + } let mut consts_found = 0usize; let mut expect_append = true; @@ -2036,11 +2231,80 @@ impl CodeInfo { true } + /// CPython's optimize_basic_block() calls optimize_lists_and_sets() in + /// place for BUILD_LIST/BUILD_SET. This handles the GET_ITER/CONTAINS_OP + /// subset there so small membership collections are folded before later + /// unary/binop constants in the same block. + fn optimize_iterable_or_contains_collection_at( + metadata: &mut CodeUnitMetadata, + block: &mut Block, + i: usize, + ) -> bool { + let Some(instr) = block.instructions[i].instr.real() else { + return false; + }; + let is_list = matches!(instr, Instruction::BuildList { .. }); + let is_set = matches!(instr, Instruction::BuildSet { .. }); + if !is_list && !is_set { + return false; + } + + let next_is_iter_or_contains = block + .instructions + .get(i + 1) + .and_then(|next| next.instr.real()) + .is_some_and(|next| { + matches!(next, Instruction::GetIter | Instruction::ContainsOp { .. 
}) + }); + if !next_is_iter_or_contains { + return false; + } + + let seq_size = u32::from(block.instructions[i].arg) as usize; + if seq_size > STACK_USE_GUIDELINE { + return false; + } + + let Some((operand_indices, elements)) = + Self::get_const_sequence(metadata, block, i, seq_size) + else { + if is_list { + block.instructions[i].instr = Opcode::BuildTuple.into(); + return true; + } + return false; + }; + + let const_data = if is_set { + ConstantData::Frozenset { elements } + } else { + ConstantData::Tuple { elements } + }; + let (const_idx, _) = metadata.consts.insert_full(const_data); + let folded_loc = block.instructions[i].location; + let end_loc = block.instructions[i].end_location; + let eh = block.instructions[i].except_handler; + + for &j in &operand_indices { + set_to_nop(&mut block.instructions[j]); + block.instructions[j].location = folded_loc; + block.instructions[j].end_location = end_loc; + } + + block.instructions[i].instr = Opcode::LoadConst.into(); + block.instructions[i].arg = OpArg::new(const_idx as u32); + block.instructions[i].location = folded_loc; + block.instructions[i].end_location = end_loc; + block.instructions[i].except_handler = eh; + true + } + /// Constant folding: fold LOAD_CONST/LOAD_SMALL_INT + BUILD_TUPLE into LOAD_CONST tuple /// fold_tuple_of_constants. This also folds constant list/set literals /// in block order to match CPython's optimize_basic_block() const-table order. 
fn fold_tuple_constants(&mut self) { - for block in &mut self.blocks { + for block_idx in self.block_next_order() { + let block = &mut self.blocks[block_idx]; let mut i = 0; while i < block.instructions.len() { if Self::fold_tuple_constant_at(&mut self.metadata, block, i) @@ -2058,7 +2322,8 @@ impl CodeInfo { /// Fold constant list literals: LOAD_CONST* + BUILD_LIST N → /// BUILD_LIST 0 + LOAD_CONST (tuple) + LIST_EXTEND 1 fn fold_list_constants(&mut self) { - for block in &mut self.blocks { + for block_idx in self.block_next_order() { + let block = &mut self.blocks[block_idx]; let mut i = 0; while i < block.instructions.len() { let instr = &block.instructions[i]; @@ -2135,7 +2400,8 @@ impl CodeInfo { /// - Previously folded BUILD_LIST 0 + LOAD_CONST + LIST_EXTEND and /// BUILD_SET 0 + LOAD_CONST + SET_UPDATE collapse back to LOAD_CONST. fn optimize_lists_and_sets(&mut self) { - for block in &mut self.blocks { + for block_idx in self.block_next_order() { + let block = &mut self.blocks[block_idx]; let mut i = 0; while i + 1 < block.instructions.len() { if matches!( @@ -2359,7 +2625,8 @@ impl CodeInfo { /// Fold constant set literals: LOAD_CONST* + BUILD_SET N → /// BUILD_SET 0 + LOAD_CONST (frozenset-as-tuple) + SET_UPDATE 1 fn fold_set_constants(&mut self) { - for block in &mut self.blocks { + for block_idx in self.block_next_order() { + let block = &mut self.blocks[block_idx]; let mut i = 0; while i < block.instructions.len() { let instr = &block.instructions[i]; @@ -3103,7 +3370,8 @@ impl CodeInfo { /// Convert LOAD_CONST for small integers to LOAD_SMALL_INT /// maybe_instr_make_load_smallint fn convert_to_load_small_int(&mut self) { - for block in &mut self.blocks { + for block_idx in self.block_next_order() { + let block = &mut self.blocks[block_idx]; for instr in &mut block.instructions { // Check if it's a LOAD_CONST instruction let Some(Instruction::LoadConst { .. 
}) = instr.instr.real() else { @@ -3145,7 +3413,8 @@ impl CodeInfo { let mut used = vec![false; nconsts]; used[0] = true; - for block in &self.blocks { + for block_idx in self.block_next_order() { + let block = &self.blocks[block_idx]; for instr in &block.instructions { if let Some(Instruction::LoadConst { .. }) = instr.instr.real() { let idx = u32::from(instr.arg) as usize; @@ -3182,7 +3451,8 @@ impl CodeInfo { } // Update LOAD_CONST instruction arguments - for block in &mut self.blocks { + for block_idx in self.block_next_order() { + let block = &mut self.blocks[block_idx]; for instr in &mut block.instructions { if let Some(Instruction::LoadConst { .. }) = instr.instr.real() { let old_idx = u32::from(instr.arg) as usize; @@ -3338,6 +3608,28 @@ impl CodeInfo { i += 1; continue; } + let first_store_location = curr.location; + let second_store_location = next.location; + let second_store_end_location = next.end_location; + let second_store_lineno_override = next.lineno_override; + let mut after_idx = i + 2; + while after_idx < block.instructions.len() { + let after = &mut block.instructions[after_idx]; + if after.instr.is_unconditional_jump() { + if instruction_lineno(after) < 0 + || after.location == first_store_location + { + after.location = second_store_location; + after.end_location = second_store_end_location; + after.lineno_override = second_store_lineno_override; + } + break; + } + if instruction_lineno(after) >= 0 { + break; + } + after_idx += 1; + } let packed = (idx1 << 4) | idx2; block.instructions[i].instr = Instruction::StoreFastStoreFast { var_nums: Arg::marker(), @@ -3345,6 +3637,8 @@ impl CodeInfo { .into(); block.instructions[i].arg = OpArg::new(packed); set_to_nop(&mut block.instructions[i + 1]); + block.instructions[i + 1].preserve_store_fast_store_fast_jump_location = + true; i += 1; } _ => i += 1, @@ -4335,6 +4629,7 @@ impl CodeInfo { self.debug_block_dump(), )); } + self.fold_constants_per_block(); self.fold_binop_constants(); 
self.fold_unary_constants(); self.fold_binop_constants(); @@ -4345,23 +4640,19 @@ impl CodeInfo { self.fold_set_constants(); self.optimize_lists_and_sets(); self.convert_to_load_small_int(); - self.remove_unused_consts(); self.dce(); - self.optimize_build_tuple_unpack(); - self.eliminate_dead_stores(); - self.apply_static_swaps(); self.peephole_optimize(); trace.push(( "after_peephole_optimize".to_owned(), self.debug_block_dump(), )); + self.fold_constants_per_block(); self.fold_tuple_constants(); self.fold_binop_constants(); self.fold_list_constants(); self.fold_set_constants(); self.optimize_lists_and_sets(); self.convert_to_load_small_int(); - self.remove_unused_consts(); split_blocks_at_jumps(&mut self.blocks); trace.push(( "after_split_blocks_at_jumps".to_owned(), @@ -4384,6 +4675,9 @@ impl CodeInfo { trace.push(("after_jump_threading".to_owned(), self.debug_block_dump())); self.eliminate_unreachable_blocks(); resolve_line_numbers(&mut self.blocks); + self.optimize_build_tuple_unpack(); + self.eliminate_dead_stores(); + self.apply_static_swaps(); trace.push(( "after_first_resolve_line_numbers".to_owned(), self.debug_block_dump(), @@ -4753,6 +5047,18 @@ fn generate_linetable( // Get column information (only when debug_ranges is enabled) let col = loc.col; let end_col = loc.end_col; + if (col < 0 || end_col < 0) && end_line == line { + linetable.push( + 0x80 | ((PyCodeLocationInfoKind::NoColumns as u8) << 3) + | ((entry_length - 1) as u8), + ); + write_signed_varint(&mut linetable, line_delta); + + prev_line = line; + length -= entry_length; + i += entry_length; + continue; + } // Choose the appropriate encoding based on line delta and column info if line_delta == 0 && end_line_delta == 0 { @@ -4807,8 +5113,11 @@ fn generate_linetable( ); write_signed_varint(&mut linetable, line_delta); write_varint(&mut linetable, end_line_delta as u32); - write_varint(&mut linetable, (col as u32) + 1); - write_varint(&mut linetable, (end_col as u32) + 1); + write_varint(&mut 
linetable, if col < 0 { 0 } else { (col as u32) + 1 }); + write_varint( + &mut linetable, + if end_col < 0 { 0 } else { (end_col as u32) + 1 }, + ); } prev_line = line; @@ -4825,34 +5134,100 @@ fn generate_exception_table(blocks: &[Block], block_to_index: &[u32]) -> Box<[u8 let mut entries: Vec = Vec::new(); let mut current_entry: Option<(ExceptHandlerInfo, u32)> = None; // (handler_info, start_index) let mut instr_index = 0u32; - let instructions: Vec<&InstructionInfo> = iter_blocks(blocks) - .flat_map(|(_, block)| block.instructions.iter()) + let instructions: Vec<(BlockIdx, usize, &InstructionInfo)> = iter_blocks(blocks) + .flat_map(|(idx, block)| { + block + .instructions + .iter() + .enumerate() + .map(move |(instr_idx, instr)| (idx, instr_idx, instr)) + }) .collect(); + let mut jump_targets = vec![false; blocks.len()]; + for (_, block) in iter_blocks(blocks) { + for instr in &block.instructions { + if instr.target != BlockIdx::NULL { + jump_targets[instr.target.idx()] = true; + } + } + } let same_handler = |left: ExceptHandlerInfo, right: ExceptHandlerInfo| { block_to_index[left.handler_block.idx()] == block_to_index[right.handler_block.idx()] && left.stack_depth == right.stack_depth && left.preserve_lasti == right.preserve_lasti }; + let mut conditional_jumps_since_exit = 0usize; // Iterate through all instructions in block order // instr_index is the index into the final instructions array (including EXTENDED_ARG) // This matches how frame.rs uses lasti - for (pos, instr) in instructions.iter().enumerate() { + for (pos, &(block_idx, instr_idx, instr)) in instructions.iter().enumerate() { // CPython's final exception table is keyed by bytecode offsets after // empty cleanup labels have been resolved. RustPython can still have // distinct block ids for those labels here, so compare handler offsets. 
- let effective_except_handler = if instr.except_handler.is_none() - && matches!(instr.instr.real(), Some(Instruction::NotTaken)) - && let Some((current_handler, _)) = current_entry - && let Some(next) = instructions.get(pos + 1) - && let Some(next_handler) = next.except_handler - && same_handler(current_handler, next_handler) - && !next.instr.is_scope_exit() - { - Some(current_handler) - } else { - instr.except_handler - }; + let next = instructions.get(pos + 1).copied(); + let next_is_jump_target_block = next.is_some_and(|(next_block, _, _)| { + next_block != block_idx + && instr_idx + 1 == blocks[block_idx.idx()].instructions.len() + && jump_targets[next_block.idx()] + }); + let next_is_normalized_backward_jump = next.is_some_and(|(next_block, _, _)| { + next_block != block_idx + && instr_idx + 1 == blocks[block_idx.idx()].instructions.len() + && matches!( + blocks[next_block.idx()].instructions.as_slice(), + [not_taken, jump] + if matches!(not_taken.instr.real(), Some(Instruction::NotTaken)) + && jump.instr.is_unconditional_jump() + && jump.target != BlockIdx::NULL + && comes_before(blocks, jump.target, next_block) + ) + }); + let previous_is_conditional_ifexp_jump = pos.checked_sub(1).is_some_and(|prev_pos| { + let (_, _, previous) = instructions[prev_pos]; + previous.target != BlockIdx::NULL + && is_conditional_jump(&previous.instr) + && blocks[previous.target.idx()].conditional_ifexp_orelse_entry + }); + let previous_is_general_bool_conditional_jump = + pos.checked_sub(1).is_some_and(|prev_pos| { + let (_, _, previous) = instructions[prev_pos]; + matches!( + previous.instr.real(), + Some(Instruction::PopJumpIfFalse { .. } | Instruction::PopJumpIfTrue { .. }) + ) + }); + let previous_jump_uses_to_bool = pos.checked_sub(1).is_some_and(|prev_pos| { + let (_, _, previous) = instructions[prev_pos]; + matches!( + previous.instr.real(), + Some(Instruction::PopJumpIfFalse { .. } | Instruction::PopJumpIfTrue { .. 
}) + ) && instructions[..prev_pos] + .iter() + .rev() + .find(|(_, _, info)| !matches!(info.instr.real(), Some(Instruction::Cache))) + .is_some_and(|(_, _, info)| matches!(info.instr.real(), Some(Instruction::ToBool))) + }); + let effective_except_handler = + if is_conditional_jump(&instr.instr) && next_is_normalized_backward_jump { + None + } else if instr.except_handler.is_none() + && matches!(instr.instr.real(), Some(Instruction::NotTaken)) + && let Some((current_handler, _)) = current_entry + && let Some((_, _, next)) = next + && let Some(next_handler) = next.except_handler + && same_handler(current_handler, next_handler) + && !next.instr.is_scope_exit() + && !(next_is_jump_target_block && previous_jump_uses_to_bool) + && !previous_is_conditional_ifexp_jump + && !(conditional_jumps_since_exit > 1 + && previous_jump_uses_to_bool + && previous_is_general_bool_conditional_jump) + { + Some(current_handler) + } else { + instr.except_handler + }; // instr_size includes EXTENDED_ARG and CACHE entries let instr_size = instr.arg.instr_size() as u32 + instr.cache_entries; @@ -4864,6 +5239,7 @@ fn generate_exception_table(blocks: &[Block], block_to_index: &[u32]) -> Box<[u8 // No current entry, handler starts - begin new entry (None, Some(handler)) => { current_entry = Some((handler, instr_index)); + conditional_jumps_since_exit = 0; } // Current entry exists, same handler - continue @@ -4880,6 +5256,7 @@ fn generate_exception_table(blocks: &[Block], block_to_index: &[u32]) -> Box<[u8 curr_handler.preserve_lasti, )); current_entry = Some((handler, instr_index)); + conditional_jumps_since_exit = 0; } // Current entry exists, no handler - finish current entry @@ -4896,6 +5273,13 @@ fn generate_exception_table(blocks: &[Block], block_to_index: &[u32]) -> Box<[u8 } } + if effective_except_handler.is_some() && is_conditional_jump(&instr.instr) { + conditional_jumps_since_exit += 1; + } + if instr.instr.is_scope_exit() { + conditional_jumps_since_exit = 0; + } + instr_index += 
instr_size; // Account for EXTENDED_ARG instructions } @@ -5069,6 +5453,8 @@ fn push_cold_blocks_to_end(blocks: &mut Vec) { match_success_jump: false, break_continue_cleanup_jump: false, for_loop_break_cleanup_jump: false, + preserve_tobool_jump_location: false, + preserve_store_fast_store_fast_jump_location: false, }); jump_block.next = blocks[cold_idx.idx()].next; blocks[cold_idx.idx()].next = jump_block_idx; @@ -5207,7 +5593,7 @@ fn retarget_assert_conditional_jumps_to_empty_predecessor(blocks: &mut [Block]) let assertion_lines: Vec> = blocks.iter().map(assertion_failure_start_line).collect(); - for block in blocks { + for block in &mut *blocks { for instr in &mut block.instructions { if instr.target == BlockIdx::NULL || !is_conditional_jump(&instr.instr) { continue; @@ -6146,6 +6532,7 @@ fn jump_threading_impl(blocks: &mut [Block], include_conditional: bool) { && target_ins.target != BlockIdx::NULL && target_ins.target != target { + let conditional = is_conditional_jump(&ins.instr); if !include_conditional && blocks[target.idx()] .instructions @@ -6190,7 +6577,6 @@ fn jump_threading_impl(blocks: &mut [Block], include_conditional: bool) { .get(final_target.idx()) .copied() .unwrap_or(u32::MAX); - let conditional = is_conditional_jump(&ins.instr); if !include_conditional && source_pos < target_pos && final_target_pos < target_pos @@ -6254,6 +6640,7 @@ fn jump_threading_impl(blocks: &mut [Block], include_conditional: bool) { threaded.target = final_target; threaded.location = target_ins.location; threaded.end_location = target_ins.end_location; + threaded.lineno_override = target_ins.lineno_override; threaded.cache_entries = 0; blocks[bi].instructions.push(threaded); changed = true; @@ -6334,7 +6721,7 @@ fn normalize_jumps(blocks: &mut Vec) { // has no i_except edge. 
except_handler: None, folded_from_nonliteral_expr: false, - lineno_override: None, + lineno_override: last_ins.lineno_override, cache_entries: 0, preserve_redundant_jump_as_nop: false, remove_no_location_nop: false, @@ -6344,6 +6731,8 @@ fn normalize_jumps(blocks: &mut Vec) { match_success_jump: false, break_continue_cleanup_jump: false, for_loop_break_cleanup_jump: false, + preserve_tobool_jump_location: false, + preserve_store_fast_store_fast_jump_location: false, }; blocks[idx].instructions.push(not_taken); } else { @@ -6376,7 +6765,7 @@ fn normalize_jumps(blocks: &mut Vec) { // after exception targets were labelled. except_handler: None, folded_from_nonliteral_expr: false, - lineno_override: None, + lineno_override: last_ins.lineno_override, cache_entries: 0, preserve_redundant_jump_as_nop: false, remove_no_location_nop: false, @@ -6386,6 +6775,8 @@ fn normalize_jumps(blocks: &mut Vec) { match_success_jump: false, break_continue_cleanup_jump: false, for_loop_break_cleanup_jump: false, + preserve_tobool_jump_location: false, + preserve_store_fast_store_fast_jump_location: false, }); new_block.instructions.push(InstructionInfo { instr: PseudoOpcode::Jump.into(), @@ -6397,7 +6788,7 @@ fn normalize_jumps(blocks: &mut Vec) { // an exception-table range. 
except_handler: None, folded_from_nonliteral_expr: false, - lineno_override: None, + lineno_override: last_ins.lineno_override, cache_entries: 0, preserve_redundant_jump_as_nop: false, remove_no_location_nop: false, @@ -6407,6 +6798,8 @@ fn normalize_jumps(blocks: &mut Vec) { match_success_jump: false, break_continue_cleanup_jump: false, for_loop_break_cleanup_jump: false, + preserve_tobool_jump_location: false, + preserve_store_fast_store_fast_jump_location: false, }); new_block.next = old_next; @@ -6768,8 +7161,10 @@ fn remove_redundant_nops_in_blocks(blocks: &mut [Block]) -> usize { && src_instructions[src + 1].target != block_idx { let next_lineno = instruction_lineno(&src_instructions[src + 1]); - if next_lineno == lineno || next_lineno < 0 { - src_instructions[src + 1].lineno_override = Some(lineno); + if next_lineno < 0 { + copy_instruction_location(instr, &mut src_instructions[src + 1]); + remove = true; + } else if next_lineno == lineno { remove = true; } } else if src_instructions[src + 1].folded_from_nonliteral_expr { @@ -6779,7 +7174,7 @@ fn remove_redundant_nops_in_blocks(blocks: &mut [Block]) -> usize { if next_lineno == lineno { remove = true; } else if next_lineno < 0 { - src_instructions[src + 1].lineno_override = Some(lineno); + copy_instruction_location(instr, &mut src_instructions[src + 1]); remove = true; } } @@ -7464,7 +7859,12 @@ fn materialize_empty_conditional_exit_targets(blocks: &mut [Block]) { continue; } if let Some(first) = blocks[target.idx()].instructions.first_mut() { - overwrite_location(first, source.location, source.end_location); + overwrite_location( + first, + source.location, + source.end_location, + source.lineno_override, + ); } } @@ -7476,7 +7876,12 @@ fn materialize_empty_conditional_exit_targets(blocks: &mut [Block]) { continue; }; let mut cloned = blocks[next.idx()].instructions[0]; - overwrite_location(&mut cloned, last.location, last.end_location); + overwrite_location( + &mut cloned, + last.location, + 
last.end_location, + last.lineno_override, + ); blocks[target.idx()].instructions.push(cloned); } @@ -7492,7 +7897,7 @@ fn materialize_empty_conditional_exit_targets(blocks: &mut [Block]) { end_location: source.end_location, except_handler: None, folded_from_nonliteral_expr: false, - lineno_override: None, + lineno_override: source.lineno_override, cache_entries: 0, preserve_redundant_jump_as_nop: false, remove_no_location_nop: false, @@ -7502,6 +7907,8 @@ fn materialize_empty_conditional_exit_targets(blocks: &mut [Block]) { match_success_jump: false, break_continue_cleanup_jump: false, for_loop_break_cleanup_jump: false, + preserve_tobool_jump_location: false, + preserve_store_fast_store_fast_jump_location: false, }); } @@ -7527,7 +7934,7 @@ fn materialize_empty_conditional_exit_targets(blocks: &mut [Block]) { end_location: source.end_location, except_handler: None, folded_from_nonliteral_expr: false, - lineno_override: None, + lineno_override: source.lineno_override, cache_entries: 0, preserve_redundant_jump_as_nop: false, remove_no_location_nop: false, @@ -7537,6 +7944,8 @@ fn materialize_empty_conditional_exit_targets(blocks: &mut [Block]) { match_success_jump: false, break_continue_cleanup_jump: false, for_loop_break_cleanup_jump: false, + preserve_tobool_jump_location: false, + preserve_store_fast_store_fast_jump_location: false, }, ); } @@ -7607,17 +8016,32 @@ fn block_tail_starts_with_async_with_normal_exit(instructions: &[InstructionInfo } fn instruction_lineno(instr: &InstructionInfo) -> i32 { - instr - .lineno_override - .unwrap_or_else(|| instr.location.line.get() as i32) + match instr.lineno_override { + Some(LINE_ONLY_LOCATION_OVERRIDE) | None => instr.location.line.get() as i32, + Some(lineno) => lineno, + } } fn instruction_has_lineno(instr: &InstructionInfo) -> bool { - instruction_lineno(instr) > 0 + instruction_lineno(instr) >= 0 +} + +fn copy_instruction_location(source: InstructionInfo, target: &mut InstructionInfo) { + target.location = 
source.location; + target.end_location = source.end_location; + target.lineno_override = source.lineno_override; + target.preserve_store_fast_store_fast_jump_location = + source.preserve_store_fast_store_fast_jump_location; } -fn propagation_location(instr: &InstructionInfo) -> Option<(SourceLocation, SourceLocation)> { - instruction_has_lineno(instr).then_some((instr.location, instr.end_location)) +fn propagation_location( + instr: &InstructionInfo, +) -> Option<(SourceLocation, SourceLocation, Option)> { + instruction_has_lineno(instr).then_some(( + instr.location, + instr.end_location, + instr.lineno_override, + )) } fn block_has_fallthrough(block: &Block) -> bool { @@ -7631,6 +8055,21 @@ fn is_jump_instruction(instr: &InstructionInfo) -> bool { instr.instr.is_unconditional_jump() || is_conditional_jump(&instr.instr) } +fn last_jump_for_line_propagation(block: &Block) -> Option { + let last = block.instructions.last().copied()?; + if matches!(last.instr.real(), Some(Instruction::NotTaken)) { + block + .instructions + .iter() + .rev() + .copied() + .find(|instr| !matches!(instr.instr.real(), Some(Instruction::NotTaken))) + .filter(is_jump_instruction) + } else { + is_jump_instruction(&last).then_some(last) + } +} + fn is_exit_without_lineno(blocks: &[Block], block_idx: BlockIdx) -> bool { let block = &blocks[block_idx.idx()]; let Some(first) = block.instructions.first() else { @@ -8795,7 +9234,7 @@ fn reorder_conditional_chain_and_jump_back_blocks(blocks: &mut Vec) { .is_some_and(|info| matches!(info.lineno_override, Some(line) if line < 0)); if is_generic_false_path_reorder && jump_targets_for_iter(blocks, jump_block) - && is_for_break_cleanup_block(blocks, chain_start) + && is_for_break_cleanup_block(blocks, next_nonempty_block(blocks, chain_start)) { current = next; continue; @@ -8922,7 +9361,12 @@ fn reorder_conditional_scope_exit_and_jump_back_blocks( if jump_block == BlockIdx::NULL { return false; } - let Some(info) = 
blocks[jump_block.idx()].instructions.first() else { + let Some(info) = blocks[jump_block.idx()].instructions.iter().find(|info| { + !matches!( + info.instr.real(), + Some(Instruction::Nop | Instruction::NotTaken) + ) + }) else { return false; }; matches!( @@ -8940,7 +9384,12 @@ fn reorder_conditional_scope_exit_and_jump_back_blocks( if !is_explicit_continue_to_for_iter(blocks, jump_block) { return false; } - let Some(info) = blocks[jump_block.idx()].instructions.first() else { + let Some(info) = blocks[jump_block.idx()].instructions.iter().find(|info| { + !matches!( + info.instr.real(), + Some(Instruction::Nop | Instruction::NotTaken) + ) + }) else { return false; }; instruction_lineno(info) > instruction_lineno(&cond) @@ -9059,9 +9508,10 @@ fn reorder_conditional_scope_exit_and_jump_back_blocks( // a fallthrough backward jump. This Rust layout pass must not // undo that normalized shape. || (is_jump_back_only_block(blocks, jump_block) - && next_nonempty_block(blocks, blocks[jump_block.idx()].next) == exit_block) - || (jump_targets_for_iter(blocks, jump_block) - && !is_explicit_continue_after_conditional(blocks, jump_block, cond)) + && next_nonempty_block(blocks, blocks[jump_block.idx()].next) == exit_block + && !(jump_targets_for_iter(blocks, jump_block) + && !is_explicit_continue_after_conditional(blocks, jump_block, cond) + && !block_is_protected(&blocks[idx]))) || next_nonempty_block(blocks, blocks[jump_block.idx()].next) != exit_block || !comes_before( blocks, @@ -10179,11 +10629,12 @@ fn maybe_propagate_location( instr: &mut InstructionInfo, location: SourceLocation, end_location: SourceLocation, + lineno_override: Option, ) { if instr.lineno_override != Some(-2) && !instruction_has_lineno(instr) { instr.location = location; instr.end_location = end_location; - instr.lineno_override = None; + instr.lineno_override = lineno_override; } } @@ -10191,10 +10642,11 @@ fn overwrite_location( instr: &mut InstructionInfo, location: SourceLocation, end_location: 
SourceLocation, + lineno_override: Option, ) { instr.location = location; instr.end_location = end_location; - instr.lineno_override = None; + instr.lineno_override = lineno_override; } fn compute_reachable_blocks(blocks: &[Block]) -> Vec { @@ -10344,9 +10796,9 @@ fn duplicate_exits_without_lineno(blocks: &mut Vec, predecessors: &mut Ve let new_idx = BlockIdx(blocks.len() as u32); let mut new_block = blocks[target.idx()].clone(); if let Some(first) = new_block.instructions.first_mut() - && let Some((location, end_location)) = propagation_location(last) + && let Some((location, end_location, lineno_override)) = propagation_location(last) { - overwrite_location(first, location, end_location); + overwrite_location(first, location, end_location, lineno_override); } let old_next = blocks[target.idx()].next; new_block.next = old_next; @@ -10381,10 +10833,10 @@ fn duplicate_exits_without_lineno(blocks: &mut Vec, predecessors: &mut Ve )) && (is_exit_without_lineno(blocks, target) || is_eval_break_without_lineno(blocks, target)) - && let Some((location, end_location)) = propagation_location(last) + && let Some((location, end_location, lineno_override)) = propagation_location(last) && let Some(first) = blocks[target.idx()].instructions.first_mut() { - maybe_propagate_location(first, location, end_location); + maybe_propagate_location(first, location, end_location, lineno_override); } } current = blocks[current.idx()].next; @@ -10406,14 +10858,14 @@ fn propagate_line_numbers(blocks: &mut [Block], predecessors: &[u32]) { let block = &mut blocks[current.idx()]; let mut prev_location = None; for instr in &mut block.instructions { - if let Some((location, end_location)) = prev_location { - maybe_propagate_location(instr, location, end_location); + if let Some((location, end_location, lineno_override)) = prev_location { + maybe_propagate_location(instr, location, end_location, lineno_override); } prev_location = propagation_location(instr); } prev_location }; - let last = 
blocks[current.idx()].instructions.last().copied().unwrap(); + let last_jump = last_jump_for_line_propagation(&blocks[current.idx()]); if has_fallthrough { let target = next_nonempty_block(blocks, next_block); @@ -10426,15 +10878,15 @@ fn propagate_line_numbers(blocks: &mut [Block], predecessors: &[u32]) { current, target, )) - && let Some((location, end_location)) = prev_location + && let Some((location, end_location, lineno_override)) = prev_location && let Some(first) = blocks[target.idx()].instructions.first_mut() { - maybe_propagate_location(first, location, end_location); + maybe_propagate_location(first, location, end_location, lineno_override); } } - if is_jump_instruction(&last) { - let mut target = next_nonempty_block(blocks, last.target); + if let Some(last_jump) = last_jump { + let mut target = next_nonempty_block(blocks, last_jump.target); while target != BlockIdx::NULL && blocks[target.idx()].instructions.is_empty() && predecessors[target.idx()] == 1 @@ -10450,10 +10902,10 @@ fn propagate_line_numbers(blocks: &mut [Block], predecessors: &[u32]) { current, target, )) - && let Some((location, end_location)) = prev_location + && let Some((location, end_location, lineno_override)) = prev_location && let Some(first) = blocks[target.idx()].instructions.first_mut() { - maybe_propagate_location(first, location, end_location); + maybe_propagate_location(first, location, end_location, lineno_override); } } } @@ -10504,6 +10956,125 @@ fn resolve_next_location_overrides(blocks: &mut [Block]) { } } +fn propagate_store_fast_store_fast_jump_locations(blocks: &mut [Block]) { + for block in blocks.iter_mut() { + for i in 1..block.instructions.len() { + let previous = block.instructions[i - 1]; + let follows_copy = i >= 2 + && matches!( + block.instructions[i - 2].instr.real(), + Some(Instruction::Copy { .. }) + ); + if !matches!( + previous.instr.real(), + Some(Instruction::StoreFastStoreFast { .. 
}) + ) || !block.instructions[i].instr.is_unconditional_jump() + || block.instructions[i].preserve_store_fast_store_fast_jump_location + || (follows_copy + && instruction_lineno(&block.instructions[i]) == instruction_lineno(&previous) + && block.instructions[i].location != previous.location) + { + continue; + } + let follows_unpack = i >= 2 + && matches!( + block.instructions[i - 2].instr.real(), + Some(Instruction::UnpackSequence { .. } | Instruction::UnpackEx { .. }) + ); + if follows_unpack && instruction_lineno(&block.instructions[i]) >= 0 { + continue; + } + block.instructions[i].location = previous.location; + block.instructions[i].end_location = previous.end_location; + block.instructions[i].lineno_override = previous.lineno_override; + } + } +} + +fn propagate_tobool_conditional_jump_locations(blocks: &mut [Block]) { + for block in blocks.iter_mut() { + let mut i = 1; + while i < block.instructions.len() { + if !matches!( + block.instructions[i - 1].instr.real(), + Some(Instruction::ToBool) + ) || !is_conditional_jump(&block.instructions[i].instr) + { + i += 1; + continue; + } + + let (location, end_location, lineno_override) = + if block.instructions[i].preserve_tobool_jump_location { + ( + block.instructions[i].location, + block.instructions[i].end_location, + block.instructions[i].lineno_override, + ) + } else { + ( + block.instructions[i - 1].location, + block.instructions[i - 1].end_location, + block.instructions[i - 1].lineno_override, + ) + }; + block.instructions[i].location = location; + block.instructions[i].end_location = end_location; + block.instructions[i].lineno_override = lineno_override; + + let mut j = i + 1; + if j < block.instructions.len() + && matches!( + block.instructions[j].instr.real(), + Some(Instruction::NotTaken) + ) + { + block.instructions[j].location = location; + block.instructions[j].end_location = end_location; + block.instructions[j].lineno_override = lineno_override; + j += 1; + } + if j < block.instructions.len() && 
block.instructions[j].instr.is_unconditional_jump() { + block.instructions[j].location = location; + block.instructions[j].end_location = end_location; + block.instructions[j].lineno_override = lineno_override; + } + + i = j; + } + } + + for idx in 0..blocks.len() { + let Some(last) = blocks[idx].instructions.last().copied() else { + continue; + }; + if !is_conditional_jump(&last.instr) { + continue; + } + let next = blocks[idx].next; + if next == BlockIdx::NULL { + continue; + } + let next_block = &mut blocks[next.idx()]; + if !next_block + .instructions + .first() + .is_some_and(|instr| matches!(instr.instr.real(), Some(Instruction::NotTaken))) + { + continue; + } + for instr in next_block.instructions.iter_mut().take(2) { + if matches!(instr.instr.real(), Some(Instruction::NotTaken)) + || instr.instr.is_unconditional_jump() + { + instr.location = last.location; + instr.end_location = last.end_location; + instr.lineno_override = last.lineno_override; + } + } + } +} + fn find_layout_predecessor(blocks: &[Block], target: BlockIdx) -> BlockIdx { if target == BlockIdx::NULL { return BlockIdx::NULL; @@ -10799,7 +11370,12 @@ fn duplicate_shared_jump_back_targets(blocks: &mut Vec) { let mut cloned = blocks[target.idx()].clone(); if let Some(first) = cloned.instructions.first_mut() { - overwrite_location(first, jump.location, jump.end_location); + overwrite_location( + first, + jump.location, + jump.end_location, + jump.lineno_override, + ); } let new_idx = BlockIdx(blocks.len() as u32); cloned.next = target; @@ -10812,7 +11388,12 @@ fn duplicate_shared_jump_back_targets(blocks: &mut Vec) { let jump = blocks[block_idx.idx()].instructions[instr_idx]; let mut cloned = blocks[target.idx()].clone(); if let Some(first) = cloned.instructions.first_mut() { - overwrite_location(first, jump.location, jump.end_location); + overwrite_location( + first, + jump.location, + jump.end_location, + jump.lineno_override, + ); } let new_idx = BlockIdx(blocks.len() as u32); @@ -10995,7 
+11576,12 @@ fn duplicate_fallthrough_jump_back_targets(blocks: &mut Vec) { let new_idx = BlockIdx(blocks.len() as u32); let mut cloned = blocks[target.idx()].clone(); if let Some(first) = cloned.instructions.first_mut() { - overwrite_location(first, last.location, last.end_location); + overwrite_location( + first, + last.location, + last.end_location, + last.lineno_override, + ); } cloned.next = blocks[layout_pred.idx()].next; blocks.push(cloned); @@ -11200,13 +11786,13 @@ fn duplicate_end_returns(blocks: &mut Vec, metadata: &CodeUnitMetadata) { let propagated_location = blocks[block_idx.idx()] .instructions .last() - .map(|instr| (instr.location, instr.end_location)); + .map(|instr| (instr.location, instr.end_location, instr.lineno_override)); let mut cloned_return = return_insts.clone(); if !instruction_has_lineno(&cloned_return[0]) - && let Some((location, end_location)) = propagated_location + && let Some((location, end_location, lineno_override)) = propagated_location { for instr in &mut cloned_return { - overwrite_location(instr, location, end_location); + overwrite_location(instr, location, end_location, lineno_override); } } blocks[block_idx.idx()].instructions.extend(cloned_return); @@ -11218,7 +11804,12 @@ fn duplicate_end_returns(blocks: &mut Vec, metadata: &CodeUnitMetadata) { let jump = blocks[block_idx.idx()].instructions[instr_idx]; let mut cloned_return = return_insts.clone(); if let Some(first) = cloned_return.first_mut() { - overwrite_location(first, jump.location, jump.end_location); + overwrite_location( + first, + jump.location, + jump.end_location, + jump.lineno_override, + ); } let new_idx = BlockIdx(blocks.len() as u32); let is_conditional = is_conditional_jump(&jump.instr); @@ -11327,7 +11918,12 @@ fn inline_with_suppress_return_blocks(blocks: &mut [Block]) { let mut cloned_return = blocks[target.idx()].instructions.clone(); for instr in &mut cloned_return { - overwrite_location(instr, jump.location, jump.end_location); + 
overwrite_location( + instr, + jump.location, + jump.end_location, + jump.lineno_override, + ); } blocks[block_idx].instructions.pop(); blocks[block_idx].instructions.extend(cloned_return); @@ -11416,7 +12012,12 @@ fn duplicate_named_except_cleanup_returns(blocks: &mut Vec, metadata: &Co let jump = blocks[block_idx.idx()].instructions[instr_idx]; let mut cloned = blocks[target.idx()].instructions.clone(); if let Some(first) = cloned.first_mut() { - overwrite_location(first, jump.location, jump.end_location); + overwrite_location( + first, + jump.location, + jump.end_location, + jump.lineno_override, + ); } let new_idx = BlockIdx(blocks.len() as u32); @@ -11480,7 +12081,12 @@ fn inline_pop_except_return_blocks(blocks: &mut [Block]) { let mut cloned_return = blocks[target.idx()].instructions.clone(); for instr in &mut cloned_return { - overwrite_location(instr, jump.location, jump.end_location); + overwrite_location( + instr, + jump.location, + jump.end_location, + jump.lineno_override, + ); } blocks[block_idx].instructions.pop(); blocks[block_idx].instructions.extend(cloned_return); @@ -11890,6 +12496,8 @@ mod tests { match_success_jump: false, break_continue_cleanup_jump: false, for_loop_break_cleanup_jump: false, + preserve_tobool_jump_location: false, + preserve_store_fast_store_fast_jump_location: false, } } diff --git a/crates/codegen/src/symboltable.rs b/crates/codegen/src/symboltable.rs index 06eeaf520a..28133b1008 100644 --- a/crates/codegen/src/symboltable.rs +++ b/crates/codegen/src/symboltable.rs @@ -32,6 +32,9 @@ pub struct SymbolTable { // Return True if the block is a nested class or function pub is_nested: bool, + /// Whether this function-like scope was created directly in a class block. + pub is_method: bool, + /// A set of symbols present on this scope level. 
pub symbols: IndexMap, @@ -90,6 +93,7 @@ impl SymbolTable { typ, line_number, is_nested, + is_method: false, symbols: IndexMap::default(), sub_tables: vec![], next_sub_table: 0, @@ -1103,6 +1107,7 @@ impl SymbolTableBuilder { | CompilerScope::Lambda | CompilerScope::Comprehension | CompilerScope::Annotation + | CompilerScope::TypeParams ) } @@ -1118,11 +1123,17 @@ impl SymbolTableBuilder { } fn enter_scope(&mut self, name: &str, typ: CompilerScope, line_number: u32) { - let is_nested = self.tables.last().is_some_and(|table| { - table.is_nested - || matches!( - table.typ, - CompilerScope::Function | CompilerScope::AsyncFunction + let parent = self.tables.last(); + let is_nested = + parent.is_some_and(|table| table.is_nested || Self::is_function_like_scope(table.typ)); + let is_method = parent.is_some_and(|table| { + table.typ == CompilerScope::Class + && matches!( + typ, + CompilerScope::Function + | CompilerScope::AsyncFunction + | CompilerScope::Lambda + | CompilerScope::Comprehension ) }); // Inherit mangled_names from parent for non-class scopes @@ -1132,6 +1143,7 @@ impl SymbolTableBuilder { .and_then(|t| t.mangled_names.clone()) .filter(|_| typ != CompilerScope::Class); let mut table = SymbolTable::new(name.to_owned(), typ, line_number, is_nested); + table.is_method = is_method; table.future_annotations = self.future_annotations; table.mangled_names = inherited_mangled_names; self.tables.push(table); @@ -1145,6 +1157,8 @@ impl SymbolTableBuilder { name: &str, line_number: u32, for_class: bool, + has_defaults: bool, + has_kwdefaults: bool, ) -> SymbolTableResult { // Check if we're in a class scope let in_class = self @@ -1174,6 +1188,12 @@ impl SymbolTableBuilder { if for_class { self.register_name(".generic_base", SymbolUsage::Assigned, TextRange::default())?; } + if has_defaults { + self.register_name(".defaults", SymbolUsage::Parameter, TextRange::default())?; + } + if has_kwdefaults { + self.register_name(".kwdefaults", SymbolUsage::Parameter, 
TextRange::default())?; + } Ok(()) } @@ -1195,6 +1215,7 @@ impl SymbolTableBuilder { let can_see_class_scope = current.typ == CompilerScope::Class || current.can_see_class_scope; let has_conditional = current.has_conditional_annotations; + let is_nested = current.is_nested || Self::is_function_like_scope(current.typ); // Create annotation block if not exists if current.annotation_block.is_none() { @@ -1202,7 +1223,7 @@ impl SymbolTableBuilder { "__annotate__".to_owned(), CompilerScope::Annotation, line_number, - true, // is_nested + is_nested, ); // Annotation scope in class can see class scope annotation_table.can_see_class_scope = can_see_class_scope; @@ -1488,6 +1509,8 @@ impl SymbolTableBuilder { &format!("", name.as_str()), self.line_index_start(type_params.range), false, + true, + Self::has_kwonlydefaults(parameters), )?; self.scan_type_params(type_params)?; } @@ -1536,6 +1559,8 @@ impl SymbolTableBuilder { &format!("", name.as_str()), self.line_index_start(type_params.range), true, // for_class: enable selective mangling + false, + false, )?; // Set class_name for mangling in type param scope self.class_name = Some(name.to_string()); @@ -1847,6 +1872,8 @@ impl SymbolTableBuilder { &format!(""), self.line_index_start(type_params.range), false, + false, + false, )?; self.scan_type_params(type_params)?; } @@ -2583,6 +2610,13 @@ impl SymbolTableBuilder { Ok(()) } + fn has_kwonlydefaults(parameters: &ast::Parameters) -> bool { + parameters + .kwonlyargs + .iter() + .any(|arg| arg.default.is_some()) + } + fn enter_scope_with_parameters( &mut self, name: &str, @@ -2704,17 +2738,6 @@ impl SymbolTableBuilder { Ok(()) } - fn add_varname_to_scope(&mut self, table_idx: usize, name: &str) { - let varnames = if table_idx + 1 == self.tables.len() { - &mut self.current_varnames - } else { - &mut self.varnames_stack[table_idx + 1] - }; - if !varnames.iter().any(|existing| existing == name) { - varnames.push(name.to_owned()); - } - } - // Mirrors CPython 
symtable_extend_namedexpr_scope(): assignment expressions // inside comprehensions bind in the nearest function/module-like scope, not // in the synthetic comprehension scope itself. @@ -2752,9 +2775,6 @@ impl SymbolTableBuilder { match table_type { CompilerScope::Function | CompilerScope::AsyncFunction | CompilerScope::Lambda => { - let current_comp_inlined = self.tables.last().is_some_and(|table| { - table.typ == CompilerScope::Comprehension && table.comp_inlined - }); let parent_is_global = self.tables[table_idx] .symbols .get(mangled.as_str()) @@ -2777,9 +2797,6 @@ impl SymbolTableBuilder { .entry(mangled.clone()) .or_insert_with(|| Symbol::new(mangled.as_str())); symbol.flags.insert(SymbolFlags::ASSIGNED); - if !parent_is_global && current_comp_inlined { - self.add_varname_to_scope(table_idx, mangled.as_str()); - } return Ok(()); } CompilerScope::Module => { diff --git a/crates/compiler-core/src/bytecode.rs b/crates/compiler-core/src/bytecode.rs index 86723f4002..6f931bbd11 100644 --- a/crates/compiler-core/src/bytecode.rs +++ b/crates/compiler-core/src/bytecode.rs @@ -460,9 +460,12 @@ bitflags! { const GENERATOR = 0x0020; const COROUTINE = 0x0080; const ITERABLE_COROUTINE = 0x0100; + const ASYNC_GENERATOR = 0x0200; + const FUTURE_ANNOTATIONS = 0x1000000; /// If a code object represents a function and has a docstring, /// this bit is set and the first item in co_consts is the docstring. 
const HAS_DOCSTRING = 0x4000000; + const METHOD = 0x8000000; } } @@ -906,8 +909,6 @@ impl PartialEq for ConstantData { match (self, other) { (Integer { value: a }, Integer { value: b }) => a == b, - // we want to compare floats *by actual value* - if we have the *exact same* float - // already in a constant cache, we want to use that (Float { value: a }, Float { value: b }) => a.to_bits() == b.to_bits(), (Complex { value: a }, Complex { value: b }) => { a.re.to_bits() == b.re.to_bits() && a.im.to_bits() == b.im.to_bits() diff --git a/crates/literal/src/float.rs b/crates/literal/src/float.rs index 0856f646b2..79caca0592 100644 --- a/crates/literal/src/float.rs +++ b/crates/literal/src/float.rs @@ -3,7 +3,7 @@ use alloc::borrow::ToOwned; use alloc::format; use alloc::string::{String, ToString}; use core::f64; -use num_traits::{Float, Zero}; +use num_traits::Zero; pub fn parse_str(literal: &str) -> Option { parse_inner(literal.trim().as_bytes()) @@ -209,6 +209,111 @@ pub fn format_general( } } +fn prefer_cpython_tie_repr(s: String, value: f64) -> String { + let Some(exponent_pos) = s.find('e') else { + return s; + }; + let Some(digit_pos) = s[..exponent_pos].bytes().rposition(|b| b.is_ascii_digit()) else { + return s; + }; + + let digit = s.as_bytes()[digit_pos]; + if digit == b'0' { + return s; + } + let decremented = digit - 1; + if !(decremented - b'0').is_multiple_of(2) { + return s; + } + + let mut candidate = s.clone(); + candidate.replace_range( + digit_pos..digit_pos + 1, + core::str::from_utf8(&[decremented]).unwrap(), + ); + if parse_str(&candidate).is_none_or(|parsed| parsed.to_bits() != value.to_bits()) { + return s; + } + + let Some(current_distance) = decimal_distance_to_f64(&s, value) else { + return s; + }; + let Some(candidate_distance) = decimal_distance_to_f64(&candidate, value) else { + return s; + }; + + if candidate_distance <= current_distance { + candidate + } else { + s + } +} + +fn checked_pow_u128(base: u128, exp: u32) -> Option { + let mut 
result = 1u128; + for _ in 0..exp { + result = result.checked_mul(base)?; + } + Some(result) +} + +fn parse_decimal_rational(s: &str) -> Option<(u128, u32)> { + let exponent_pos = s.find('e')?; + let exponent = s[exponent_pos + 1..].parse::().ok()?; + let significand = s[..exponent_pos] + .strip_prefix('-') + .unwrap_or(&s[..exponent_pos]); + let dot_pos = significand.find('.'); + let frac_digits = dot_pos + .map(|pos| significand.len().saturating_sub(pos + 1)) + .unwrap_or(0); + let mut digits = String::with_capacity(significand.len()); + for ch in significand.chars() { + if ch != '.' { + digits.push(ch); + } + } + let mut int = digits.parse::().ok()?; + let mut scale = i32::try_from(frac_digits).ok()? - exponent; + if scale < 0 { + int = int.checked_mul(checked_pow_u128(10, (-scale) as u32)?)?; + scale = 0; + } + Some((int, scale as u32)) +} + +fn f64_mantissa_exponent(value: f64) -> Option<(u128, i32)> { + let bits = value.abs().to_bits(); + let exponent = ((bits >> 52) & 0x7ff) as i32; + let fraction = bits & ((1u64 << 52) - 1); + if exponent == 0 { + Some((u128::from(fraction), 1 - 1023 - 52)) + } else if exponent < 0x7ff { + Some((u128::from((1u64 << 52) | fraction), exponent - 1023 - 52)) + } else { + None + } +} + +fn decimal_distance_to_f64(s: &str, value: f64) -> Option { + let (decimal_int, decimal_scale) = parse_decimal_rational(s)?; + let (mantissa, binary_exponent) = f64_mantissa_exponent(value)?; + if binary_exponent >= 0 || decimal_scale > 38 { + return None; + } + + let binary_scale = u32::try_from(-binary_exponent).ok()?; + let common_twos = decimal_scale.max(binary_scale); + let decimal_scaled = + decimal_int.checked_mul(checked_pow_u128(2, common_twos - decimal_scale)?)?; + let five_power = checked_pow_u128(5, decimal_scale)?; + let binary_scaled = mantissa + .checked_mul(checked_pow_u128(2, common_twos - binary_scale)?)? 
+ .checked_mul(five_power)?; + + Some(decimal_scaled.abs_diff(binary_scaled)) +} + // TODO: rewrite using format_general pub fn to_string(value: f64) -> String { let lit = format!("{value:e}"); @@ -223,7 +328,7 @@ pub fn to_string(value: f64) -> String { value.to_string() } } else { - format!("{significand}e{exponent:+#03}") + prefer_cpython_tie_repr(format!("{significand}e{exponent:+#03}"), value) } } else { let mut s = value.to_string(); @@ -232,6 +337,22 @@ pub fn to_string(value: f64) -> String { } } +#[cfg(test)] +mod tests { + use super::to_string; + + #[test] + fn repr_uses_cpython_tie_digit_for_power_of_two() { + assert_eq!(to_string(2.0f64.powi(-25)), "2.9802322387695312e-08"); + assert_eq!(to_string((-2.0f64).powi(-25)), "-2.9802322387695312e-08"); + assert_eq!(to_string(2.0f64.powi(-26)), "1.4901161193847656e-08"); + assert_eq!( + to_string(2.0f64.powi(-14) - 2.0f64.powi(-25)), + "6.1005353927612305e-05" + ); + } +} + pub fn from_hex(s: &str) -> Option { if let Ok(f) = hexf_parse::parse_hexf64(s, false) { return Some(f); @@ -281,22 +402,23 @@ pub fn from_hex(s: &str) -> Option { } pub fn to_hex(value: f64) -> String { - let (mantissa, exponent, sign) = value.integer_decode(); - let sign_fmt = if sign < 0 { "-" } else { "" }; + let bits = value.to_bits(); + let sign_fmt = if bits >> 63 != 0 { "-" } else { "" }; match value { value if value.is_zero() => format!("{sign_fmt}0x0.0p+0"), value if value.is_infinite() => format!("{sign_fmt}inf"), value if value.is_nan() => "nan".to_owned(), _ => { - const BITS: i16 = 52; - const FRACT_MASK: u64 = 0xf_ffff_ffff_ffff; - format!( - "{}{:#x}.{:013x}p{:+}", - sign_fmt, - mantissa >> BITS, - mantissa & FRACT_MASK, - exponent + BITS - ) + const FRACT_MASK: u64 = (1u64 << 52) - 1; + const EXP_MASK: u64 = 0x7ff; + let exponent = (bits >> 52) & EXP_MASK; + let fraction = bits & FRACT_MASK; + if exponent == 0 { + format!("{sign_fmt}0x0.{fraction:013x}p-1022") + } else { + let exponent = i32::try_from(exponent).unwrap() - 
1023; + format!("{sign_fmt}0x1.{fraction:013x}p{exponent:+}") + } } } } @@ -304,6 +426,10 @@ pub fn to_hex(value: f64) -> String { #[test] fn test_to_hex() { use rand::Rng; + assert_eq!(to_hex(f64::from_bits(1)), "0x0.0000000000001p-1022"); + assert_eq!(to_hex(f64::from_bits(2)), "0x0.0000000000002p-1022"); + assert_eq!(to_hex(-f64::from_bits(1)), "-0x0.0000000000001p-1022"); + assert_eq!(to_hex(f64::MIN_POSITIVE), "0x1.0000000000000p-1022"); for _ in 0..20000 { let bytes = rand::rng().random::(); let f = f64::from_bits(bytes); diff --git a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__nested_double_async_with.snap b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__nested_double_async_with.snap index a30fa6a78c..34b5ce7e5c 100644 --- a/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__nested_double_async_with.snap +++ b/crates/stdlib/src/snapshots/rustpython_stdlib___opcode__tests__nested_double_async_with.snap @@ -13,9 +13,9 @@ expression: "dis(r#\"\nasync def test():\n for stop_exc in (StopIteration('sp Disassembly of ", line 1>: 1 RETURN_GENERATOR POP_TOP - RESUME 0 + L1: RESUME 0 - 2 L1: LOAD_GLOBAL 1 (StopIteration + NULL) + 2 LOAD_GLOBAL 1 (StopIteration + NULL) LOAD_CONST 0 ('spam') CALL 1 LOAD_GLOBAL 3 (StopAsyncIteration + NULL) @@ -90,10 +90,12 @@ Disassembly of ", line 1>: POP_TOP POP_TOP JUMP_FORWARD 3 (to L25) - L24: COPY 3 + + -- L24: COPY 3 POP_EXCEPT RERAISE 1 - L25: NOP + + 5 L25: NOP 10 L26: LOAD_GLOBAL 4 (self) LOAD_ATTR 13 (fail + NULL|self) @@ -153,11 +155,11 @@ Disassembly of ", line 1>: POP_TOP POP_TOP JUMP_BACKWARD 205 (to L2) - L39: COPY 3 + + -- L39: COPY 3 POP_EXCEPT RERAISE 1 - - -- L40: CALL_INTRINSIC_1 3 (INTRINSIC_STOPITERATION_ERROR) + L40: CALL_INTRINSIC_1 3 (INTRINSIC_STOPITERATION_ERROR) RERAISE 1 ExceptionTable: L1 to L3 -> L40 [0] lasti diff --git a/crates/vm/src/builtins/function.rs b/crates/vm/src/builtins/function.rs index bea8428110..6052e4fe25 100644 --- 
a/crates/vm/src/builtins/function.rs +++ b/crates/vm/src/builtins/function.rs @@ -564,7 +564,8 @@ impl Py { let is_gen = code.flags.contains(bytecode::CodeFlags::GENERATOR); let is_coro = code.flags.contains(bytecode::CodeFlags::COROUTINE); - let use_datastack = !(is_gen || is_coro); + let is_async_gen = code.flags.contains(bytecode::CodeFlags::ASYNC_GENERATOR); + let use_datastack = !(is_gen || is_coro || is_async_gen); // Construct frame: let frame = Frame::new( @@ -579,35 +580,30 @@ impl Py { .into_ref(&vm.ctx); self.fill_locals_from_args(&frame, func_args, vm)?; - match (is_gen, is_coro) { - (true, false) => { - let obj = PyGenerator::new(frame.clone(), self.__name__(), self.__qualname__()) - .into_pyobject(vm); - frame.set_generator(&obj); - Ok(obj) - } - (false, true) => { - let obj = PyCoroutine::new(frame.clone(), self.__name__(), self.__qualname__()) - .into_pyobject(vm); - frame.set_generator(&obj); - Ok(obj) - } - (true, true) => { - let obj = PyAsyncGen::new(frame.clone(), self.__name__(), self.__qualname__()) - .into_pyobject(vm); - frame.set_generator(&obj); - Ok(obj) - } - (false, false) => { - let result = vm.run_frame(frame.clone()); - // Release data stack memory after frame execution completes. - unsafe { - if let Some(base) = frame.materialize_localsplus() { - vm.datastack_pop(base); - } + if is_async_gen { + let obj = PyAsyncGen::new(frame.clone(), self.__name__(), self.__qualname__()) + .into_pyobject(vm); + frame.set_generator(&obj); + Ok(obj) + } else if is_gen { + let obj = PyGenerator::new(frame.clone(), self.__name__(), self.__qualname__()) + .into_pyobject(vm); + frame.set_generator(&obj); + Ok(obj) + } else if is_coro { + let obj = PyCoroutine::new(frame.clone(), self.__name__(), self.__qualname__()) + .into_pyobject(vm); + frame.set_generator(&obj); + Ok(obj) + } else { + let result = vm.run_frame(frame.clone()); + // Release data stack memory after frame execution completes. 
+ unsafe { + if let Some(base) = frame.materialize_localsplus() { + vm.datastack_pop(base); } - result } + result } } @@ -689,11 +685,11 @@ impl Py { .intersects(bytecode::CodeFlags::VARARGS | bytecode::CodeFlags::VARKEYWORDS) ); debug_assert_eq!(code.kwonlyarg_count, 0); - debug_assert!( - !code - .flags - .intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE) - ); + debug_assert!(!code.flags.intersects( + bytecode::CodeFlags::GENERATOR + | bytecode::CodeFlags::COROUTINE + | bytecode::CodeFlags::ASYNC_GENERATOR, + )); let locals = if code.flags.contains(bytecode::CodeFlags::NEWLOCALS) { None @@ -741,10 +737,11 @@ impl Py { // Generator/coroutine code objects are SIMPLE_FUNCTION in call // specialization classification, but their call path must still // go through invoke() to produce generator/coroutine objects. - if code - .flags - .intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE) - { + if code.flags.intersects( + bytecode::CodeFlags::GENERATOR + | bytecode::CodeFlags::COROUTINE + | bytecode::CodeFlags::ASYNC_GENERATOR, + ) { return self.invoke(FuncArgs::from(args), vm); } let frame = self.prepare_exact_args_frame(args, vm); @@ -760,10 +757,11 @@ impl Py { } pub(crate) fn datastack_frame_size_bytes_for_code(code: &Py) -> Option { - if code - .flags - .intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE) - { + if code.flags.intersects( + bytecode::CodeFlags::GENERATOR + | bytecode::CodeFlags::COROUTINE + | bytecode::CodeFlags::ASYNC_GENERATOR, + ) { return None; } let nlocalsplus = code.localspluskinds.len(); @@ -1468,9 +1466,11 @@ pub(crate) fn vectorcall_function( && !code.flags.contains(bytecode::CodeFlags::VARARGS) && !code.flags.contains(bytecode::CodeFlags::VARKEYWORDS) && code.kwonlyarg_count == 0 - && !code - .flags - .intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE); + && !code.flags.intersects( + bytecode::CodeFlags::GENERATOR + | 
bytecode::CodeFlags::COROUTINE + | bytecode::CodeFlags::ASYNC_GENERATOR, + ); if is_simple && nargs == code.arg_count as usize { // FAST PATH: simple positional-only call, exact arg count. diff --git a/crates/vm/src/frame.rs b/crates/vm/src/frame.rs index 1bafe7f26a..f1ed31d718 100644 --- a/crates/vm/src/frame.rs +++ b/crates/vm/src/frame.rs @@ -710,10 +710,11 @@ impl Frame { // For generators/coroutines, initialize prev_line to the def line // so that preamble instructions (RETURN_GENERATOR, POP_TOP) don't // fire spurious LINE events. - let prev_line = if code - .flags - .intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE) - { + let prev_line = if code.flags.intersects( + bytecode::CodeFlags::GENERATOR + | bytecode::CodeFlags::COROUTINE + | bytecode::CodeFlags::ASYNC_GENERATOR, + ) { code.first_line_number.map_or(0, |line| line.get() as u32) } else { 0 @@ -9523,9 +9524,7 @@ impl ExecutingFrame<'_> { // Returns the exception object; RERAISE will re-raise it if arg.fast_isinstance(vm.ctx.exceptions.stop_iteration) { let flags = &self.code.flags; - let msg = if flags - .contains(bytecode::CodeFlags::COROUTINE | bytecode::CodeFlags::GENERATOR) - { + let msg = if flags.contains(bytecode::CodeFlags::ASYNC_GENERATOR) { "async generator raised StopIteration" } else if flags.contains(bytecode::CodeFlags::COROUTINE) { "coroutine raised StopIteration" diff --git a/crates/vm/src/stdlib/builtins.rs b/crates/vm/src/stdlib/builtins.rs index 8358d41b2b..f2d4dcdb64 100644 --- a/crates/vm/src/stdlib/builtins.rs +++ b/crates/vm/src/stdlib/builtins.rs @@ -160,7 +160,11 @@ mod builtins { .map(|&b| b as char) .collect(); - if name.is_empty() { None } else { Some(name) } + if name.is_empty() { + None + } else { + Some(normalize_source_encoding(&name)) + } } // Split into lines (first two only) @@ -186,15 +190,39 @@ mod builtins { lines.next().and_then(find_encoding_in_line) } + /// Match CPython's Parser/tokenizer/helpers.c:get_normal_name(). 
+ #[cfg(feature = "parser")] + fn normalize_source_encoding(name: &str) -> String { + let mut normalized = String::with_capacity(name.len().min(12)); + for ch in name.chars().take(12) { + if ch == '_' { + normalized.push('-'); + } else { + normalized.push(ch.to_ascii_lowercase()); + } + } + + if normalized == "utf-8" || normalized.starts_with("utf-8-") { + "utf-8".to_owned() + } else if normalized == "latin-1" + || normalized == "iso-8859-1" + || normalized == "iso-latin-1" + || normalized.starts_with("latin-1-") + || normalized.starts_with("iso-8859-1-") + || normalized.starts_with("iso-latin-1-") + { + "iso-8859-1".to_owned() + } else { + name.to_owned() + } + } + /// Decode source bytes to a string, handling PEP 263 encoding declarations /// and BOM. Raises SyntaxError for invalid UTF-8 without an encoding /// declaration. - /// Check if an encoding name is a UTF-8 variant after normalization. - /// Matches: utf-8, utf_8, utf8, UTF-8, etc. #[cfg(feature = "parser")] fn is_utf8_encoding(name: &str) -> bool { - let normalized: String = name.chars().filter(|&c| c != '-' && c != '_').collect(); - normalized.eq_ignore_ascii_case("utf8") + name == "utf-8" } #[cfg(feature = "parser")] @@ -206,9 +234,10 @@ mod builtins { // Validate BOM + encoding combination if has_bom && !is_utf8 { + let enc = encoding.as_deref().unwrap_or("utf-8"); return Err(vm.new_exception_msg( vm.ctx.exceptions.syntax_error.to_owned(), - format!("encoding problem for '{filename}': utf-8").into(), + format!("encoding problem: {enc} with BOM").into(), )); } diff --git a/scripts/dis_dump.py b/scripts/dis_dump.py index d888cd23df..813de22e65 100755 --- a/scripts/dis_dump.py +++ b/scripts/dis_dump.py @@ -18,7 +18,6 @@ import json import os import re -import struct import sys import types @@ -109,22 +108,6 @@ def _unescape(m): return argrepr -def _normalize_const_repr(value): - """Return a cross-interpreter representation for LOAD_CONST values.""" - if isinstance(value, float): - return 
f"float:{struct.pack('>d', value).hex()}" - if isinstance(value, tuple): - if not value: - return "()" - parts = [_normalize_const_repr(item) for item in value] - trailing = "," if len(parts) == 1 else "" - return f"({', '.join(parts)}{trailing})" - if isinstance(value, frozenset): - parts = sorted(_normalize_const_repr(item) for item in value) - return f"frozenset({{{', '.join(parts)}}})" - return _normalize_argrepr(repr(value)) - - _IS_RUSTPYTHON = ( hasattr(sys, "implementation") and sys.implementation.name == "rustpython" ) @@ -168,7 +151,7 @@ def _resolve_arg_fallback(code, opname, arg): return _resolve_localsplus_name(code, arg) elif opname == "LOAD_CONST": if 0 <= arg < len(code.co_consts): - return _normalize_const_repr(code.co_consts[arg]) + return _normalize_argrepr(repr(code.co_consts[arg])) elif opname in ( "LOAD_DEREF", "STORE_DEREF", @@ -311,10 +294,7 @@ def _metadata_cache_slot_offsets(inst): elif inst.arg is not None and inst.argrepr: # If argrepr is just a number, try to resolve it via fallback # (RustPython may return raw index instead of variable name) - if opname == "LOAD_CONST" and 0 <= inst.arg < len(code.co_consts): - argrepr = _normalize_const_repr(code.co_consts[inst.arg]) - else: - argrepr = inst.argrepr + argrepr = inst.argrepr if argrepr.isdigit() or (argrepr.startswith("-") and argrepr[1:].isdigit()): resolved = _resolve_arg_fallback(code, opname, inst.arg) if isinstance(resolved, str) and not resolved.isdigit(): From b9efe1053795614659f4a10a9e9cbcfb4810865e Mon Sep 17 00:00:00 2001 From: Ivan Mironov Date: Sun, 24 May 2026 04:43:56 +0000 Subject: [PATCH 09/18] Add missing test for select.select() (#7953) This is a follow up for https://github.com/RustPython/RustPython/pull/7948 --- extra_tests/snippets/stdlib_select.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/extra_tests/snippets/stdlib_select.py b/extra_tests/snippets/stdlib_select.py index 9afd95beca..5263bc344f 100644 --- a/extra_tests/snippets/stdlib_select.py 
+++ b/extra_tests/snippets/stdlib_select.py @@ -62,6 +62,16 @@ def fileno(self): resource.setrlimit(resource.RLIMIT_NOFILE, (soft_max_fds, hard_max_fds)) sockets = [s for _ in range(TOO_MANY_SELECT_FDS // 2) for s in socket.socketpair()] assert_raises(ValueError, select.select, sockets, [], [], 0) +if sys.platform != "win32": + # Try to overflow descriptor bit mask on *nix with a single item + max_fd = -1 + max_fd_sock = None + sockets.reverse() + for sock in sockets: + if sock.fileno() > max_fd: + max_fd = sock.fileno() + max_fd_sock = sock + assert_raises(ValueError, select.select, [max_fd_sock], [], [], 0) del sockets a, b = socket.socketpair() # CPython disallows this on *nix systems too. From e1d9a1123eff81fac38de2c824ac121f41c72585 Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Sun, 24 May 2026 07:48:00 +0300 Subject: [PATCH 10/18] Skip flaky tests (#7961) --- Lib/test/test_asyncio/test_sendfile.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Lib/test/test_asyncio/test_sendfile.py b/Lib/test/test_asyncio/test_sendfile.py index dcd963b335..3acbc37f24 100644 --- a/Lib/test/test_asyncio/test_sendfile.py +++ b/Lib/test/test_asyncio/test_sendfile.py @@ -573,6 +573,10 @@ class PollEventLoopTests(SendfileTestsBase, def create_event_loop(self): return asyncio.SelectorEventLoop(selectors.PollSelector()) + @unittest.skipIf(sys.platform != "win32", "TODO: RUSTPYTHON; Flaky on CI") + def test_sendfile_ssl_pre_and_post_data(self): + return super().test_sendfile_ssl_pre_and_post_data() + # Should always exist. 
class SelectEventLoopTests(SendfileTestsBase, test_utils.TestCase): @@ -580,6 +584,10 @@ class SelectEventLoopTests(SendfileTestsBase, def create_event_loop(self): return asyncio.SelectorEventLoop(selectors.SelectSelector()) + @unittest.skipIf(sys.platform != "win32", "TODO: RUSTPYTHON; Flaky on CI") + def test_sendfile_ssl_pre_and_post_data(self): + return super().test_sendfile_ssl_pre_and_post_data() + if __name__ == '__main__': unittest.main() From b5ff41c2196f6c80d8d01812608daba7085c0f2b Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Sun, 24 May 2026 19:30:42 +0900 Subject: [PATCH 11/18] Align marshal and .pyc with CPython 3.14 (#7958) * Share marshal ref table between code object and its internals read_marshal_bytes, _str, _str_vec, _name_tuple, and _const_tuple now take a shared ref table and resolve TYPE_REF / register FLAG_REF entries. deserialize_code is split into a public wrapper and an inner function that receives the ref table; deserialize_value_depth opens a fresh inner ref space when it hits Type::Code, mirroring CPython's behaviour of putting the code object itself at ref slot 0. Nested code objects inside const tuples reuse the surrounding code's ref space via the new read_const_value helper. * Align PYC magic number, FORMAT_VERSION, and header check with CPython 3.14 PYC_MAGIC_NUMBER changes from 2994 to 3627, matching CPython 3.14's pyc_magic_number_token (0x0a0d0e2b). marshal FORMAT_VERSION drops from 5 to 4 (the encoder/marshal.version value; the decoder already accepts both). check_pyc_magic_number_bytes now compares all four magic bytes instead of the first two. * Add CPython 3.14 .pyc decoding regression tests Two fixture-based tests pin the marshal decoder against actual CPython 3.14 marshal.dumps() output: a trivial module that exercises FLAG_REF plus TYPE_REF for qualname, and a module with a nested function that exercises ref sharing between a const tuple and its surrounding code object. 
* Accept CPython-tagged .pyc as read-only bytecode source SourceFileLoader.get_code now also looks for .pyc files using _RP_FALLBACK_CACHE_TAGS (currently ('cpython-314',)) in addition to sys.implementation.cache_tag. The matched .pyc is only used for reading; recompilation still writes to the RustPython-tagged path, so CPython's .pyc is never overwritten. Source-stat / hash / timestamp validation logic is unchanged. * Apply rustfmt to marshal helpers * Marshal PySlice from format version 4 instead of 5 CPython's marshal supports TYPE_SLICE from format version 4 onwards and that is the default version. Rejecting slice dumps below version 5 made marshal.dumps(slice(...)) fail with the default version and broke test.test_marshal.SliceTestCase.test_slice. * Revert "Accept CPython-tagged .pyc as read-only bytecode source" Lib/importlib/_bootstrap_external.py is CPython's own code copied verbatim; local patches here defeat compatibility tracking. The cpython-XX cache_tag fallback needs to live on the RustPython side (Rust code or sys.implementation.cache_tag policy), not as edits to the imported standard library. This reverts commit 1fc426d0fb5fcdb50d35cad13bbb43e8f6ce1c7f. * Format sys.implementation.cache_tag as cpython-{MAJOR}{MINOR} Use the CPython compatibility version (e.g. cpython-314) instead of the rustpython-{MAJOR_IMPL}_{MINOR_IMPL} interpreter version string. * Set marshal FORMAT_VERSION to 5 to match CPython 3.14.5 Py_MARSHAL_VERSION is 5 in CPython 3.14.5 (Include/marshal.h:16) and TYPE_SLICE serialization rejects version < 5 (Python/marshal.c:720). Restore the same threshold and constant so marshal.version and the slice-marshal gate match CPython. * Thread marshal recursion depth through nested code objects Code objects embedded in const-tuples reset the depth budget on each recursion, so a hostile or pathological marshal stream of code-in-tuple- in-code can blow the stack despite MAX_MARSHAL_STACK_DEPTH. 
Pass the current depth through deserialize_code_inner and read_marshal_const_tuple and decrement at each code-object/tuple boundary. Also route dict keys through deserialize_value_after_header so TYPE_CODE keys decode instead of failing with BadType. --- crates/compiler-core/src/marshal.rs | 401 +++++++++++++++++++++++----- crates/vm/src/import.rs | 2 +- crates/vm/src/stdlib/sys.rs | 3 +- crates/vm/src/version.rs | 4 +- 4 files changed, 343 insertions(+), 67 deletions(-) diff --git a/crates/compiler-core/src/marshal.rs b/crates/compiler-core/src/marshal.rs index 829c1dc951..503369a983 100644 --- a/crates/compiler-core/src/marshal.rs +++ b/crates/compiler-core/src/marshal.rs @@ -194,6 +194,22 @@ pub fn deserialize_code( rdr: &mut R, bag: Bag, ) -> Result> { + let mut refs: Vec> = Vec::new(); + deserialize_code_inner(rdr, bag, MAX_MARSHAL_STACK_DEPTH, &mut refs) +} + +/// Inner code-object deserializer that shares a ref table with caller. +/// Used when decoding a code object embedded in another marshal stream so +/// that TYPE_REF entries inside the code can resolve across nested values. 
+fn deserialize_code_inner( + rdr: &mut R, + bag: Bag, + depth: usize, + refs: &mut Vec>, +) -> Result> { + if depth == 0 { + return Err(MarshalError::InvalidBytecode); + } // 1–5: scalar fields let arg_count = rdr.read_u32()?; let posonlyarg_count = rdr.read_u32()?; @@ -202,24 +218,24 @@ pub fn deserialize_code( let flags = CodeFlags::from_bits_truncate(rdr.read_u32()?); // 6: co_code - let code_bytes = read_marshal_bytes(rdr)?; + let code_bytes = read_marshal_bytes(rdr, &bag, refs)?; // 7: co_consts - let constants = read_marshal_const_tuple(rdr, bag)?; + let constants = read_marshal_const_tuple(rdr, bag, depth, refs)?; // 8: co_names - let names = read_marshal_name_tuple(rdr, &bag)?; + let names = read_marshal_name_tuple(rdr, &bag, refs)?; // 9: co_localsplusnames - let localsplusnames = read_marshal_str_vec(rdr)?; + let localsplusnames = read_marshal_str_vec(rdr, &bag, refs)?; // 10: co_localspluskinds - let localspluskinds = read_marshal_bytes(rdr)?; + let localspluskinds = read_marshal_bytes(rdr, &bag, refs)?; // 11–13: filename, name, qualname - let source_path = bag.make_name(&read_marshal_str(rdr)?); - let obj_name = bag.make_name(&read_marshal_str(rdr)?); - let qualname = bag.make_name(&read_marshal_str(rdr)?); + let source_path = bag.make_name(&read_marshal_str(rdr, &bag, refs)?); + let obj_name = bag.make_name(&read_marshal_str(rdr, &bag, refs)?); + let qualname = bag.make_name(&read_marshal_str(rdr, &bag, refs)?); // 14: co_firstlineno let first_line_raw = rdr.read_u32()? 
as i32; @@ -230,8 +246,8 @@ pub fn deserialize_code( }; // 15–16: linetable, exceptiontable - let linetable = read_marshal_bytes(rdr)?.to_vec().into_boxed_slice(); - let exceptiontable = read_marshal_bytes(rdr)?.to_vec().into_boxed_slice(); + let linetable = read_marshal_bytes(rdr, &bag, refs)?.into_boxed_slice(); + let exceptiontable = read_marshal_bytes(rdr, &bag, refs)?.into_boxed_slice(); // Split localsplusnames/kinds → varnames/cellvars/freevars let lp = split_localplus( @@ -275,72 +291,238 @@ pub fn deserialize_code( }) } -/// Read a marshal bytes object (TYPE_STRING = b's'). -fn read_marshal_bytes(rdr: &mut R) -> Result> { - let type_byte = rdr.read_u8()? & !FLAG_REF; +/// Reserve a ref slot if `FLAG_REF` was present, returning its index. +fn reserve_ref_slot(has_flag: bool, refs: &mut Vec>) -> Option { + if has_flag { + let idx = refs.len(); + refs.push(None); + Some(idx) + } else { + None + } +} + +/// Resolve a TYPE_REF index, returning the previously stored value. +fn resolve_ref(idx: usize, refs: &[Option]) -> Result { + refs.get(idx) + .and_then(|v| v.clone()) + .ok_or(MarshalError::InvalidBytecode) +} + +/// Read a marshal bytes object (TYPE_STRING = b's'), resolving TYPE_REF +/// and registering this read in the ref table when `FLAG_REF` is set. +fn read_marshal_bytes( + rdr: &mut R, + bag: &Bag, + refs: &mut Vec>, +) -> Result> { + let raw = rdr.read_u8()?; + let type_byte = raw & !FLAG_REF; + let has_flag = raw & FLAG_REF != 0; + + if type_byte == Type::Ref as u8 { + let idx = rdr.read_u32()? 
as usize; + let stored = resolve_ref(idx, refs)?; + return match stored.borrow_constant() { + BorrowedConstant::Bytes { value } => Ok(value.to_vec()), + _ => Err(MarshalError::BadType), + }; + } + if type_byte != Type::Bytes as u8 { return Err(MarshalError::BadType); } + + let slot = reserve_ref_slot(has_flag, refs); let len = rdr.read_u32()?; - Ok(rdr.read_slice(len)?.to_vec()) + let bytes = rdr.read_slice(len)?.to_vec(); + if let Some(idx) = slot { + refs[idx] = + Some(bag.make_constant::(BorrowedConstant::Bytes { value: &bytes })); + } + Ok(bytes) } -/// Read a marshal string object. -fn read_marshal_str(rdr: &mut R) -> Result { - let type_byte = rdr.read_u8()? & !FLAG_REF; - let s = match type_byte { +/// Read a marshal string object, resolving TYPE_REF and registering +/// this read in the ref table when `FLAG_REF` is set. +fn read_marshal_str( + rdr: &mut R, + bag: &Bag, + refs: &mut Vec>, +) -> Result { + let raw = rdr.read_u8()?; + let type_byte = raw & !FLAG_REF; + let has_flag = raw & FLAG_REF != 0; + + if type_byte == Type::Ref as u8 { + let idx = rdr.read_u32()? as usize; + let stored = resolve_ref(idx, refs)?; + return match stored.borrow_constant() { + BorrowedConstant::Str { value } => Ok(value.to_string_lossy().into_owned()), + _ => Err(MarshalError::BadType), + }; + } + + let slot = reserve_ref_slot(has_flag, refs); + let owned = match type_byte { b'u' | b't' | b'a' | b'A' => { let len = rdr.read_u32()?; - rdr.read_str(len)? + alloc::string::String::from(rdr.read_str(len)?) } b'z' | b'Z' => { let len = rdr.read_u8()? as u32; - rdr.read_str(len)? + alloc::string::String::from(rdr.read_str(len)?) } _ => return Err(MarshalError::BadType), }; - Ok(alloc::string::String::from(s)) + if let Some(idx) = slot { + refs[idx] = Some(bag.make_constant::(BorrowedConstant::Str { + value: Wtf8::new(owned.as_str()), + })); + } + Ok(owned) } /// Read a marshal tuple of strings, returning owned Strings. 
-fn read_marshal_str_vec(rdr: &mut R) -> Result> { - let type_byte = rdr.read_u8()? & !FLAG_REF; +fn read_marshal_str_vec( + rdr: &mut R, + bag: &Bag, + refs: &mut Vec>, +) -> Result> { + let raw = rdr.read_u8()?; + let type_byte = raw & !FLAG_REF; + let has_flag = raw & FLAG_REF != 0; + + if type_byte == Type::Ref as u8 { + let idx = rdr.read_u32()? as usize; + let stored = resolve_ref(idx, refs)?; + return match stored.borrow_constant() { + BorrowedConstant::Tuple { elements } => elements + .iter() + .map(|c| match c.borrow_constant() { + BorrowedConstant::Str { value } => Ok(value.to_string_lossy().into_owned()), + _ => Err(MarshalError::BadType), + }) + .collect(), + _ => Err(MarshalError::BadType), + }; + } + let n = match type_byte { b'(' => rdr.read_u32()? as usize, b')' => rdr.read_u8()? as usize, _ => return Err(MarshalError::BadType), }; - (0..n).map(|_| read_marshal_str(rdr)).collect() + let slot = reserve_ref_slot(has_flag, refs); + let items: Vec = (0..n) + .map(|_| read_marshal_str(rdr, bag, refs)) + .collect::>()?; + if let Some(idx) = slot { + let elements: Vec = items + .iter() + .map(|s| { + bag.make_constant::(BorrowedConstant::Str { + value: Wtf8::new(s.as_str()), + }) + }) + .collect(); + refs[idx] = Some(bag.make_constant::(BorrowedConstant::Tuple { + elements: &elements, + })); + } + Ok(items) } fn read_marshal_name_tuple( rdr: &mut R, bag: &Bag, + refs: &mut Vec>, ) -> Result::Name]>> { - let type_byte = rdr.read_u8()? & !FLAG_REF; - let n = match type_byte { - b'(' => rdr.read_u32()? as usize, - b')' => rdr.read_u8()? as usize, - _ => return Err(MarshalError::BadType), - }; - (0..n) - .map(|_| Ok(bag.make_name(&read_marshal_str(rdr)?))) - .collect::>>() - .map(Vec::into_boxed_slice) + let names = read_marshal_str_vec(rdr, bag, refs)?; + Ok(names + .iter() + .map(|s| bag.make_name(s)) + .collect::>() + .into_boxed_slice()) } -/// Read a marshal tuple of constants. +/// Read a marshal tuple of constants. 
Shares the ref table with the +/// surrounding code-object decode so that nested TYPE_REF entries (for +/// strings, bytes, code objects, etc.) resolve correctly. fn read_marshal_const_tuple( rdr: &mut R, bag: Bag, + depth: usize, + refs: &mut Vec>, ) -> Result> { - let type_byte = rdr.read_u8()? & !FLAG_REF; + if depth == 0 { + return Err(MarshalError::InvalidBytecode); + } + let raw = rdr.read_u8()?; + let type_byte = raw & !FLAG_REF; + let has_flag = raw & FLAG_REF != 0; + + if type_byte == Type::Ref as u8 { + let idx = rdr.read_u32()? as usize; + let stored = resolve_ref(idx, refs)?; + return match stored.borrow_constant() { + BorrowedConstant::Tuple { elements } => Ok(elements.iter().cloned().collect()), + _ => Err(MarshalError::BadType), + }; + } + let n = match type_byte { b'(' => rdr.read_u32()? as usize, b')' => rdr.read_u8()? as usize, _ => return Err(MarshalError::BadType), }; - (0..n).map(|_| deserialize_value(rdr, bag)).collect() + let slot = reserve_ref_slot(has_flag, refs); + let child_depth = depth - 1; + let items: Vec = (0..n) + .map(|_| read_const_value(rdr, bag, child_depth, refs)) + .collect::>()?; + if let Some(idx) = slot { + refs[idx] = + Some(bag.make_constant::(BorrowedConstant::Tuple { elements: &items })); + } + Ok(items.into_iter().collect()) +} + +/// Read a single value while staying inside an existing code-object ref +/// space. Unlike `deserialize_value_depth`, encountering `Type::Code` +/// here reuses the caller's ref table instead of opening a fresh one — +/// this matches CPython's single global ref space for objects nested +/// inside a code object's const tuple. +fn read_const_value( + rdr: &mut R, + bag: Bag, + depth: usize, + refs: &mut Vec>, +) -> Result { + if depth == 0 { + return Err(MarshalError::InvalidBytecode); + } + let raw = rdr.read_u8()?; + let flag = raw & FLAG_REF != 0; + let type_code = raw & !FLAG_REF; + + if type_code == Type::Ref as u8 { + let idx = rdr.read_u32()? 
as usize; + return resolve_ref(idx, refs); + } + + let slot = reserve_ref_slot(flag, refs); + let typ = Type::try_from(type_code)?; + let value = if matches!(typ, Type::Code) { + let code = deserialize_code_inner(rdr, bag, depth - 1, refs)?; + bag.make_code(code) + } else { + deserialize_value_typed(rdr, bag, depth, refs, typ)? + }; + if let Some(idx) = slot { + refs[idx] = Some(value.clone()); + } + Ok(value) } pub trait MarshalBag: Copy { @@ -506,6 +688,23 @@ fn deserialize_value_depth( return Err(MarshalError::InvalidBytecode); } let raw = rdr.read_u8()?; + deserialize_value_after_header(rdr, bag, depth, refs, raw) +} + +/// Continue deserializing a value after the header byte has already been +/// consumed. Shared by `deserialize_value_depth` and the dict-key branch, +/// where the header byte is read up front to detect the TYPE_NULL +/// terminator. +fn deserialize_value_after_header( + rdr: &mut R, + bag: Bag, + depth: usize, + refs: &mut Vec>, + raw: u8, +) -> Result { + if depth == 0 { + return Err(MarshalError::InvalidBytecode); + } let flag = raw & FLAG_REF != 0; let type_code = raw & !FLAG_REF; @@ -528,7 +727,21 @@ fn deserialize_value_depth( }; let typ = Type::try_from(type_code)?; - let value = deserialize_value_typed(rdr, bag, depth, refs, typ)?; + // Code-objects keep their own inner ref table because Bag::Value (the + // outer marshal value) and the constant-bag's Constant type are not + // in general the same. When the outer header carried FLAG_REF, the + // code object occupies slot 0 of the single global ref space, so we + // mirror that by reserving slot 0 of the inner table. + let value = if matches!(typ, Type::Code) { + let mut inner_refs: Vec::Constant>> = Vec::new(); + if flag { + inner_refs.push(None); + } + let code = deserialize_code_inner(rdr, bag.constant_bag(), depth - 1, &mut inner_refs)?; + bag.make_code(code) + } else { + deserialize_value_typed(rdr, bag, depth, refs, typ)? 
+ }; if let Some(idx) = slot { refs[idx] = Some(value.clone()); @@ -630,32 +843,10 @@ fn deserialize_value_typed( let mut pairs = Vec::new(); loop { let raw = rdr.read_u8()?; - let type_code = raw & !FLAG_REF; - if type_code == b'0' { + if raw & !FLAG_REF == b'0' { break; } - // TYPE_REF for key - let k = if type_code == Type::Ref as u8 { - let idx = rdr.read_u32()? as usize; - refs.get(idx) - .and_then(|v| v.clone()) - .ok_or(MarshalError::InvalidBytecode)? - } else { - let flag = raw & FLAG_REF != 0; - let key_slot = if flag { - let idx = refs.len(); - refs.push(None); - Some(idx) - } else { - None - }; - let key_type = Type::try_from(type_code)?; - let k = deserialize_value_typed(rdr, bag, d, refs, key_type)?; - if let Some(idx) = key_slot { - refs[idx] = Some(k.clone()); - } - k - }; + let k = deserialize_value_after_header(rdr, bag, d, refs, raw)?; let v = deserialize_value_depth(rdr, bag, d, refs)?; pairs.push((k, v)); } @@ -667,7 +858,7 @@ fn deserialize_value_typed( let value = rdr.read_slice(len)?; bag.make_bytes(value) } - Type::Code => bag.make_code(deserialize_code(rdr, bag.constant_bag())?), + Type::Code => return Err(MarshalError::BadType), Type::Slice => { let d = depth - 1; let start = deserialize_value_depth(rdr, bag, d, refs)?; @@ -1288,3 +1479,87 @@ fn lt_read_signed_varint(data: &[u8], pos: &mut usize) -> i32 { (val >> 1) as i32 } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::bytecode::{BasicBag, ConstantData}; + + fn hex_to_bytes(hex: &str) -> Vec { + (0..hex.len()) + .step_by(2) + .map(|i| u8::from_str_radix(&hex[i..i + 2], 16).unwrap()) + .collect() + } + + fn decode_code(hex: &str) -> CodeObject { + let bytes = hex_to_bytes(hex); + let value = deserialize_value(&mut &bytes[..], BasicBag).expect("decode failed"); + match value { + ConstantData::Code { code } => *code, + other => panic!("expected Code, got {other:?}"), + } + } + + /// CPython 3.14 marshal output for: `compile("x = 1", "", "exec")`. 
+ /// Exercises FLAG_REF on the code object and TYPE_REF for qualname + /// pointing back at the obj_name slot. + #[test] + fn cpython_314_trivial_assignment() { + let hex = "e30000000000000000000000000100000000000000f30a00000080005e017400520123002902\ + e9010000004e2901da0178a900f300000000da033c743eda083c6d6f64756c653e72070000000100\ + 0000730a000000f003010101d8040582017205000000"; + let code = decode_code(hex); + assert_eq!(code.obj_name.as_str(), ""); + assert_eq!(code.qualname.as_str(), ""); + assert_eq!(code.source_path.as_str(), ""); + assert_eq!(code.arg_count, 0); + assert_eq!(code.max_stackdepth, 1); + assert_eq!(code.names.len(), 1); + assert_eq!(code.names[0].as_str(), "x"); + assert_eq!(code.constants.len(), 2); + // (1, None) + let consts: &[ConstantData] = &code.constants; + assert!(matches!( + consts[0], + ConstantData::Integer { ref value } if *value == 1.into(), + )); + assert!(matches!(consts[1], ConstantData::None)); + } + + /// CPython 3.14 marshal output for a module with a nested function + /// and a string constant. Verifies that nested code objects inside + /// a const tuple share the surrounding code's ref space. 
+ #[test] + fn cpython_314_nested_code_and_string_const() { + let hex = "e30000000000000000000000000100000000000000f310000000800052001700740052017401\ + 520223002903630200000000000000000000000200000003000000f3120000008000570\ + 12c0000000000000000000000230029014ea9002902da0161da016273020000002626da033c743e\ + da0361646472070000000200000073090000008000d80b0c8d35804cf300000000da0568656c6c\ + 6f4e29027207000000da084752454554494e47720300000072080000007206000000da083c6d6f\ + 64756c653e720b000000010000007311000000f003010101f204010111f006000c1382087208000000"; + let code = decode_code(hex); + assert_eq!(code.obj_name.as_str(), ""); + assert_eq!(code.names.len(), 2); + assert_eq!(code.names[0].as_str(), "add"); + assert_eq!(code.names[1].as_str(), "GREETING"); + assert_eq!(code.constants.len(), 3); + // Inner code, "hello", None + let consts: &[ConstantData] = &code.constants; + let inner = match &consts[0] { + ConstantData::Code { code } => code, + other => panic!("expected nested Code, got {other:?}"), + }; + assert_eq!(inner.obj_name.as_str(), "add"); + assert_eq!(inner.qualname.as_str(), "add"); + assert_eq!(inner.arg_count, 2); + assert_eq!(inner.varnames.len(), 2); + assert_eq!(inner.varnames[0].as_str(), "a"); + assert_eq!(inner.varnames[1].as_str(), "b"); + assert!(matches!( + consts[1], + ConstantData::Str { ref value } if value.as_str().ok() == Some("hello"), + )); + assert!(matches!(consts[2], ConstantData::None)); + } +} diff --git a/crates/vm/src/import.rs b/crates/vm/src/import.rs index f33290ff83..798bc258b7 100644 --- a/crates/vm/src/import.rs +++ b/crates/vm/src/import.rs @@ -9,7 +9,7 @@ use crate::{ }; pub(crate) fn check_pyc_magic_number_bytes(buf: &[u8]) -> bool { - buf.starts_with(&crate::version::PYC_MAGIC_NUMBER_BYTES[..2]) + buf.starts_with(&crate::version::PYC_MAGIC_NUMBER_BYTES) } pub(crate) fn init_importlib_base(vm: &mut VirtualMachine) -> PyResult { diff --git a/crates/vm/src/stdlib/sys.rs b/crates/vm/src/stdlib/sys.rs index 
44e692f178..68d57d225a 100644 --- a/crates/vm/src/stdlib/sys.rs +++ b/crates/vm/src/stdlib/sys.rs @@ -658,7 +658,8 @@ pub mod sys { fn implementation(vm: &VirtualMachine) -> PyRef { const NAME: &str = "rustpython"; - let cache_tag = format!("{NAME}-{}_{}", version::MAJOR_IMPL, version::MINOR_IMPL); + // cache tag uses 'cpython' because our compiler is cpython compatible + let cache_tag = format!("cpython-{}{}", version::MAJOR, version::MINOR); let ctx = &vm.ctx; py_namespace!(vm, { "name" => ctx.new_str(NAME), diff --git a/crates/vm/src/version.rs b/crates/vm/src/version.rs index f30a6a4942..05eb12a294 100644 --- a/crates/vm/src/version.rs +++ b/crates/vm/src/version.rs @@ -69,8 +69,8 @@ pub const RUSTPYTHON_VERSION: &str = const { }; // Must be aligned to Lib/importlib/_bootstrap_external.py -// Bumped to 2994 for new CommonConstant discriminants (BuiltinList, BuiltinSet) -pub const PYC_MAGIC_NUMBER: u16 = 2994; +// Matches CPython 3.14 (Include/internal/pycore_magic_number.h). +pub const PYC_MAGIC_NUMBER: u16 = 3627; // CPython format: magic_number | ('\r' << 16) | ('\n' << 24) // This protects against text-mode file reads From d7d936575c2b3c39fbca1a703ed0d45ed9e7f5f9 Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Sun, 24 May 2026 13:55:22 +0300 Subject: [PATCH 12/18] General code nitpicks (#7955) --- crates/derive-impl/src/from_args.rs | 5 +-- crates/derive-impl/src/pyclass.rs | 12 +++---- crates/derive-impl/src/pymodule.rs | 22 ++++++------ crates/derive-impl/src/pytraverse.rs | 52 +++++++++++++--------------- crates/derive-impl/src/util.rs | 27 +++++++++------ crates/vm/src/format.rs | 6 +++- 6 files changed, 63 insertions(+), 61 deletions(-) diff --git a/crates/derive-impl/src/from_args.rs b/crates/derive-impl/src/from_args.rs index 8149a3aa65..adcbbd418d 100644 --- a/crates/derive-impl/src/from_args.rs +++ b/crates/derive-impl/src/from_args.rs @@ -6,7 +6,7 @@ use syn::{Attribute, Data, DeriveInput, Expr, 
Field, Ident, Result, Token, parse /// The kind of the python parameter, this corresponds to the value of Parameter.kind /// (https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind) -#[derive(Default)] +#[derive(Clone, Copy, Default, Eq, PartialEq)] enum ParameterKind { PositionalOnly, #[default] @@ -77,9 +77,10 @@ impl ArgAttribute { } fn parse_argument(&mut self, meta: ParseNestedMeta<'_>) -> Result<()> { - if let ParameterKind::Flatten = self.kind { + if self.kind == ParameterKind::Flatten { return Err(meta.error("can't put additional arguments on a flatten arg")); } + if meta.path.is_ident("default") && meta.input.peek(Token![=]) { if matches!(self.default, Some(Some(_))) { return Err(meta.error("Default already set")); diff --git a/crates/derive-impl/src/pyclass.rs b/crates/derive-impl/src/pyclass.rs index 5e25e381df..0cc0faa113 100644 --- a/crates/derive-impl/src/pyclass.rs +++ b/crates/derive-impl/src/pyclass.rs @@ -644,8 +644,8 @@ pub(crate) fn impl_pyclass(attr: PunctuatedNestedMeta, item: Item) -> Result T // their own __init__ in __dict__. let slot_init = quote!(); - let extra_attrs_tokens = if extra_attrs.is_empty() { - quote!() - } else { - quote!(, #(#extra_attrs),*) - }; + let extra_attrs_tokens = quote!(#(#extra_attrs),*); quote! 
{ - #[pyclass(flags(BASETYPE, HAS_DICT), with(#(#with_items),*) #extra_attrs_tokens)] + #[pyclass(flags(BASETYPE, HAS_DICT), with(#(#with_items),*), #extra_attrs_tokens)] impl #generics #self_ty { #(#items)* } diff --git a/crates/derive-impl/src/pymodule.rs b/crates/derive-impl/src/pymodule.rs index cee8b1be4a..32d7a0fa6b 100644 --- a/crates/derive-impl/src/pymodule.rs +++ b/crates/derive-impl/src/pymodule.rs @@ -84,13 +84,13 @@ fn negate_cfg_attrs(cfg_attrs: &[Attribute]) -> Vec { if cfg_attrs.is_empty() { return vec![]; } - let predicates: Vec<_> = cfg_attrs + let predicates = cfg_attrs .iter() .map(|attr| match &attr.meta { syn::Meta::List(list) => list.tokens.clone(), _ => unreachable!("only #[cfg(...)] should be here"), }) - .collect(); + .collect::>(); if predicates.len() == 1 { let predicate = &predicates[0]; vec![parse_quote!(#[cfg(not(#predicate))])] @@ -295,19 +295,20 @@ pub(crate) fn impl_pymodule(args: PyModuleArgs, module_item: Item) -> Result, Vec<_>) = with_items.iter().partition(|w| w.cfg_attrs.is_empty()); - let uncond_paths: Vec<_> = uncond_withs.iter().map(|w| &w.path).collect(); + let uncond_paths = uncond_withs.iter().map(|w| &w.path).collect::>(); let method_defs = if with_items.is_empty() { quote!(#function_items) } else { // For cfg-gated with items, generate conditional const declarations // so the total array size adapts to the cfg at compile time - let cond_const_names: Vec<_> = cond_withs + let cond_const_names = cond_withs .iter() .enumerate() .map(|(i, _)| format_ident!("__WITH_METHODS_{}", i)) - .collect(); - let cond_const_decls: Vec<_> = cond_withs + .collect::>(); + + let cond_const_decls= cond_withs .iter() .zip(&cond_const_names) .map(|(w, name)| { @@ -321,7 +322,7 @@ pub(crate) fn impl_pymodule(args: PyModuleArgs, module_item: Item) -> Result>(); quote!({ const OWN_METHODS: &'static [::rustpython_vm::function::PyMethodDef] = &#function_items; @@ -340,7 +341,7 @@ pub(crate) fn impl_pymodule(args: PyModuleArgs, module_item: Item) 
-> Result = with_items + let init_with_calls = with_items .iter() .map(|w| { let cfg_attrs = &w.cfg_attrs; @@ -350,7 +351,7 @@ pub(crate) fn impl_pymodule(args: PyModuleArgs, module_item: Item) -> Result>(); items.extend([ parse_quote! { @@ -702,8 +703,7 @@ impl ModuleItem for FunctionItem { let r = loop_unit(); args.context.errors.ok_or_push(r); } - let py_names: Vec<_> = py_names.into_iter().collect(); - py_names + py_names.into_iter().collect::>() } }; let call_flags = infer_native_call_flags(func.sig(), 0); diff --git a/crates/derive-impl/src/pytraverse.rs b/crates/derive-impl/src/pytraverse.rs index c4ec382329..c75eee87ba 100644 --- a/crates/derive-impl/src/pytraverse.rs +++ b/crates/derive-impl/src/pytraverse.rs @@ -38,30 +38,31 @@ fn field_to_traverse_code(field: &Field) -> Result { .iter() .filter_map(pytraverse_arg) .collect::, _>>()?; - let do_trace = if pytraverse_attrs.len() > 1 { + + if pytraverse_attrs.len() > 1 { bail_span!( field, "found multiple #[pytraverse] attributes on the same field, expect at most one" ) - } else if pytraverse_attrs.is_empty() { - // default to always traverse every field - true - } else { - !pytraverse_attrs[0].skip - }; + } + let name = field.ident.as_ref().ok_or_else(|| { syn::Error::new_spanned( field.clone(), "Field should have a name in non-tuple struct", ) })?; - if do_trace { - Ok(quote!( + + // default to always traverse every field + let do_trace = pytraverse_attrs.first().is_none_or(|attr| !attr.skip); + + Ok(if do_trace { + quote!( ::rustpython_vm::object::Traverse::traverse(&self.#name, tracer_fn); - )) + ) } else { - Ok(quote!()) - } + quote!() + }) } /// not trace corresponding field @@ -76,20 +77,16 @@ fn gen_trace_code(item: &mut DeriveInput) -> Result { .iter_mut() .map(|f| -> Result { field_to_traverse_code(f) }) .collect::>()?; - let res = res.into_iter().collect::(); - Ok(res) - } - syn::Fields::Unnamed(fields) => { - let res: TokenStream = (0..fields.unnamed.len()) - .map(|i| { - let i = 
syn::Index::from(i); - quote!( - ::rustpython_vm::object::Traverse::traverse(&self.#i, tracer_fn); - ) - }) - .collect(); - Ok(res) + Ok(res.into_iter().collect::()) } + syn::Fields::Unnamed(fields) => Ok((0..fields.unnamed.len()) + .map(|i| { + let i = syn::Index::from(i); + quote!( + ::rustpython_vm::object::Traverse::traverse(&self.#i, tracer_fn); + ) + }) + .collect::()), _ => Err(syn::Error::new_spanned( fields, "Only named and unnamed fields are supported", @@ -116,12 +113,11 @@ pub(crate) fn impl_pytraverse(mut item: DeriveInput) -> Result { let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); - let ret = quote! { + Ok(quote! { unsafe impl #impl_generics ::rustpython_vm::object::Traverse for #ty #ty_generics #where_clause { fn traverse(&self, tracer_fn: &mut ::rustpython_vm::object::TraverseFn) { #trace_code } } - }; - Ok(ret) + }) } diff --git a/crates/derive-impl/src/util.rs b/crates/derive-impl/src/util.rs index 3ad41679c3..52e1fa236f 100644 --- a/crates/derive-impl/src/util.rs +++ b/crates/derive-impl/src/util.rs @@ -80,7 +80,7 @@ impl ToTokens for ValidatedItemNursery { let cfgs = &item.cfgs; let tokens = &item.tokens; quote! 
{ - #( #cfgs )* + #(#cfgs)* { #tokens } @@ -99,12 +99,15 @@ pub(crate) trait ContentItem { type AttrName: core::str::FromStr + core::fmt::Display; fn inner(&self) -> &ContentItemInner; + fn index(&self) -> usize { self.inner().index } + fn attr_name(&self) -> &Self::AttrName { &self.inner().attr_name } + fn new_syn_error(&self, span: Span, message: &str) -> syn::Error { syn::Error::new(span, format!("#[{}] {}", self.attr_name(), message)) } @@ -142,6 +145,7 @@ impl ItemMetaInner { Ok(None) } })?; + if !lits.is_empty() { bail_span!(meta_ident, "#[{meta_ident}(..)] cannot contain literal") } @@ -153,6 +157,10 @@ impl ItemMetaInner { }) } + pub(crate) fn contains_key(&self, key: &str) -> bool { + self.meta_map.contains_key(key) + } + pub(crate) fn item_name(&self) -> String { self.item_ident.to_string() } @@ -215,10 +223,6 @@ impl ItemMetaInner { Ok(value) } - pub(crate) fn _has_key(&self, key: &str) -> bool { - matches!(self.meta_map.get(key), Some((_, _))) - } - pub(crate) fn _bool(&self, key: &str) -> Result { let value = if let Some((_, meta)) = self.meta_map.get(key) { match meta { @@ -517,6 +521,7 @@ impl ExceptionItemMeta { impl core::ops::Deref for ExceptionItemMeta { type Target = ClassItemMeta; + fn deref(&self) -> &Self::Target { &self.0 } @@ -524,8 +529,11 @@ impl core::ops::Deref for ExceptionItemMeta { pub(crate) trait AttributeExt: SynAttributeExt { fn promoted_nested(&self) -> Result; + fn ident_and_promoted_nested(&self) -> Result<(&Ident, PunctuatedNestedMeta)>; + fn try_remove_name(&mut self, name: &str) -> Result>; + fn fill_nested_meta(&mut self, name: &str, new_item: F) -> Result<()> where F: Fn() -> NestedMeta; @@ -544,6 +552,7 @@ impl AttributeExt for Attribute { })?; Ok(list.nested) } + fn ident_and_promoted_nested(&self) -> Result<(&Ident, PunctuatedNestedMeta)> { Ok((self.get_ident().unwrap(), self.promoted_nested()?)) } @@ -564,14 +573,10 @@ impl AttributeExt for Attribute { let mut found = None; for (i, item) in nested.iter().enumerate() { 
- let ident = if let Some(ident) = item.get_ident() { - ident - } else { - continue; - }; - if *ident != item_name { + if item.get_ident().is_none_or(|ident| ident != item_name) { continue; } + if found.is_some() { return Err(syn::Error::new( item.span(), diff --git a/crates/vm/src/format.rs b/crates/vm/src/format.rs index 2e3d6ad48a..7e18dd75d2 100644 --- a/crates/vm/src/format.rs +++ b/crates/vm/src/format.rs @@ -13,7 +13,11 @@ use crate::common::wtf8::{Wtf8, Wtf8Buf}; #[cfg(any(unix, windows))] pub(crate) fn get_locale_info() -> LocaleInfo { let lc = crate::host_env::locale::localeconv_data(); - let mut grouping: Vec = lc.grouping.iter().map(|&c| c as u8).collect(); + #[allow( + clippy::unnecessary_cast, + reason = "libc::c_char is not u8 on all platforms" + )] + let mut grouping = lc.grouping.iter().map(|&c| c as u8).collect::>(); if !grouping.is_empty() { grouping.push(0); } From 2fabf38d8fb2bbef7a90c060914b4aef7e126e61 Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Sun, 24 May 2026 13:56:35 +0300 Subject: [PATCH 13/18] Impl `sys.audithook` (#7960) * Update tests * Add basic audit support * Add audit for `time.sleep` * Add some for `socket` * Some syslog * Some sys related audits * some marshal * monitoring callback * Mark failing tests * clippy * Clippy * clippy * mark failing test * mark more * Update `test_sys_setprofile.py` to 3.14.5 * Mark failing tests --- Lib/test/audit-tests.py | 681 +++++++++++++++++++++++++ Lib/test/test_audit.py | 34 +- Lib/test/test_bdb.py | 1 + Lib/test/test_sys_setprofile.py | 151 +++++- crates/stdlib/src/socket.rs | 23 + crates/stdlib/src/syslog.rs | 38 +- crates/vm/src/stdlib/marshal.rs | 8 + crates/vm/src/stdlib/sys.rs | 96 +++- crates/vm/src/stdlib/sys/monitoring.rs | 10 + crates/vm/src/stdlib/time.rs | 4 + crates/vm/src/vm/mod.rs | 2 + crates/vm/src/vm/thread.rs | 1 + 12 files changed, 1014 insertions(+), 35 deletions(-) create mode 100644 Lib/test/audit-tests.py diff --git 
a/Lib/test/audit-tests.py b/Lib/test/audit-tests.py new file mode 100644 index 0000000000..6884ac0dbe --- /dev/null +++ b/Lib/test/audit-tests.py @@ -0,0 +1,681 @@ +"""This script contains the actual auditing tests. + +It should not be imported directly, but should be run by the test_audit +module with arguments identifying each test. + +""" + +import contextlib +import os +import sys + + +class TestHook: + """Used in standard hook tests to collect any logged events. + + Should be used in a with block to ensure that it has no impact + after the test completes. + """ + + def __init__(self, raise_on_events=None, exc_type=RuntimeError): + self.raise_on_events = raise_on_events or () + self.exc_type = exc_type + self.seen = [] + self.closed = False + + def __enter__(self, *a): + sys.addaudithook(self) + return self + + def __exit__(self, *a): + self.close() + + def close(self): + self.closed = True + + @property + def seen_events(self): + return [i[0] for i in self.seen] + + def __call__(self, event, args): + if self.closed: + return + self.seen.append((event, args)) + if event in self.raise_on_events: + raise self.exc_type("saw event " + event) + + +# Simple helpers, since we are not in unittest here +def assertEqual(x, y): + if x != y: + raise AssertionError(f"{x!r} should equal {y!r}") + + +def assertIn(el, series): + if el not in series: + raise AssertionError(f"{el!r} should be in {series!r}") + + +def assertNotIn(el, series): + if el in series: + raise AssertionError(f"{el!r} should not be in {series!r}") + + +def assertSequenceEqual(x, y): + if len(x) != len(y): + raise AssertionError(f"{x!r} should equal {y!r}") + if any(ix != iy for ix, iy in zip(x, y)): + raise AssertionError(f"{x!r} should equal {y!r}") + + +@contextlib.contextmanager +def assertRaises(ex_type): + try: + yield + assert False, f"expected {ex_type}" + except BaseException as ex: + if isinstance(ex, AssertionError): + raise + assert type(ex) is ex_type, f"{ex} should be {ex_type}" + + +def 
test_basic(): + with TestHook() as hook: + sys.audit("test_event", 1, 2, 3) + assertEqual(hook.seen[0][0], "test_event") + assertEqual(hook.seen[0][1], (1, 2, 3)) + + +def test_block_add_hook(): + # Raising an exception should prevent a new hook from being added, + # but will not propagate out. + with TestHook(raise_on_events="sys.addaudithook") as hook1: + with TestHook() as hook2: + sys.audit("test_event") + assertIn("test_event", hook1.seen_events) + assertNotIn("test_event", hook2.seen_events) + + +def test_block_add_hook_baseexception(): + # Raising BaseException will propagate out when adding a hook + with assertRaises(BaseException): + with TestHook( + raise_on_events="sys.addaudithook", exc_type=BaseException + ) as hook1: + # Adding this next hook should raise BaseException + with TestHook() as hook2: + pass + + +def test_marshal(): + import marshal + o = ("a", "b", "c", 1, 2, 3) + payload = marshal.dumps(o) + + with TestHook() as hook: + assertEqual(o, marshal.loads(marshal.dumps(o))) + + try: + with open("test-marshal.bin", "wb") as f: + marshal.dump(o, f) + with open("test-marshal.bin", "rb") as f: + assertEqual(o, marshal.load(f)) + finally: + os.unlink("test-marshal.bin") + + actual = [(a[0], a[1]) for e, a in hook.seen if e == "marshal.dumps"] + assertSequenceEqual(actual, [(o, marshal.version)] * 2) + + actual = [a[0] for e, a in hook.seen if e == "marshal.loads"] + assertSequenceEqual(actual, [payload]) + + actual = [e for e, a in hook.seen if e == "marshal.load"] + assertSequenceEqual(actual, ["marshal.load"]) + + +def test_pickle(): + import pickle + + class PicklePrint: + def __reduce_ex__(self, p): + return str, ("Pwned!",) + + payload_1 = pickle.dumps(PicklePrint()) + payload_2 = pickle.dumps(("a", "b", "c", 1, 2, 3)) + + # Before we add the hook, ensure our malicious pickle loads + assertEqual("Pwned!", pickle.loads(payload_1)) + + with TestHook(raise_on_events="pickle.find_class") as hook: + with assertRaises(RuntimeError): + # With the hook 
enabled, loading globals is not allowed + pickle.loads(payload_1) + # pickles with no globals are okay + pickle.loads(payload_2) + + +def test_monkeypatch(): + class A: + pass + + class B: + pass + + class C(A): + pass + + a = A() + + with TestHook() as hook: + # Catch name changes + C.__name__ = "X" + # Catch type changes + C.__bases__ = (B,) + # Ensure bypassing __setattr__ is still caught + type.__dict__["__bases__"].__set__(C, (B,)) + # Catch attribute replacement + C.__init__ = B.__init__ + # Catch attribute addition + C.new_attr = 123 + # Catch class changes + a.__class__ = B + + actual = [(a[0], a[1]) for e, a in hook.seen if e == "object.__setattr__"] + assertSequenceEqual( + [(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (a, "__class__")], actual + ) + + +def test_open(testfn): + # SSLContext.load_dh_params uses Py_fopen() rather than normal open() + try: + import ssl + + load_dh_params = ssl.create_default_context().load_dh_params + except ImportError: + load_dh_params = None + + try: + import readline + except ImportError: + readline = None + + def rl(name): + if readline: + return getattr(readline, name, None) + else: + return None + + # Try a range of "open" functions. 
+ # All of them should fail + with TestHook(raise_on_events={"open"}) as hook: + for fn, *args in [ + (open, testfn, "r"), + (open, sys.executable, "rb"), + (open, 3, "wb"), + (open, testfn, "w", -1, None, None, None, False, lambda *a: 1), + (load_dh_params, testfn), + (rl("read_history_file"), testfn), + (rl("read_history_file"), None), + (rl("write_history_file"), testfn), + (rl("write_history_file"), None), + (rl("append_history_file"), 0, testfn), + (rl("append_history_file"), 0, None), + (rl("read_init_file"), testfn), + (rl("read_init_file"), None), + ]: + if not fn: + continue + with assertRaises(RuntimeError): + try: + fn(*args) + except NotImplementedError: + if fn == load_dh_params: + # Not callable in some builds + load_dh_params = None + raise RuntimeError + else: + raise + + actual_mode = [(a[0], a[1]) for e, a in hook.seen if e == "open" and a[1]] + actual_flag = [(a[0], a[2]) for e, a in hook.seen if e == "open" and not a[1]] + assertSequenceEqual( + [ + i + for i in [ + (testfn, "r"), + (sys.executable, "r"), + (3, "w"), + (testfn, "w"), + (testfn, "rb") if load_dh_params else None, + (testfn, "r") if readline else None, + ("~/.history", "r") if readline else None, + (testfn, "w") if readline else None, + ("~/.history", "w") if readline else None, + (testfn, "a") if rl("append_history_file") else None, + ("~/.history", "a") if rl("append_history_file") else None, + (testfn, "r") if readline else None, + ("", "r") if readline else None, + ] + if i is not None + ], + actual_mode, + ) + assertSequenceEqual([], actual_flag) + + +def test_cantrace(): + traced = [] + + def trace(frame, event, *args): + if frame.f_code == TestHook.__call__.__code__: + traced.append(event) + + old = sys.settrace(trace) + try: + with TestHook() as hook: + # No traced call + eval("1") + + # No traced call + hook.__cantrace__ = False + eval("2") + + # One traced call + hook.__cantrace__ = True + eval("3") + + # Two traced calls (writing to private member, eval) + 
hook.__cantrace__ = 1 + eval("4") + + # One traced call (writing to private member) + hook.__cantrace__ = 0 + finally: + sys.settrace(old) + + assertSequenceEqual(["call"] * 4, traced) + + +def test_mmap(): + import mmap + + with TestHook() as hook: + mmap.mmap(-1, 8) + assertEqual(hook.seen[0][1][:2], (-1, 8)) + + +def test_ctypes_call_function(): + import ctypes + import _ctypes + + with TestHook() as hook: + _ctypes.call_function(ctypes._memmove_addr, (0, 0, 0)) + assert ("ctypes.call_function", (ctypes._memmove_addr, (0, 0, 0))) in hook.seen, f"{ctypes._memmove_addr=} {hook.seen=}" + + ctypes.CFUNCTYPE(ctypes.c_voidp)(ctypes._memset_addr)(1, 0, 0) + assert ("ctypes.call_function", (ctypes._memset_addr, (1, 0, 0))) in hook.seen, f"{ctypes._memset_addr=} {hook.seen=}" + + with TestHook() as hook: + ctypes.cast(ctypes.c_voidp(0), ctypes.POINTER(ctypes.c_char)) + assert "ctypes.call_function" in hook.seen_events + + with TestHook() as hook: + ctypes.string_at(id("ctypes.string_at") + 40) + assert "ctypes.call_function" in hook.seen_events + assert "ctypes.string_at" in hook.seen_events + + +def test_posixsubprocess(): + import multiprocessing.util + + exe = b"xxx" + args = [b"yyy", b"zzz"] + with TestHook() as hook: + multiprocessing.util.spawnv_passfds(exe, args, ()) + assert ("_posixsubprocess.fork_exec", ([exe], args, None)) in hook.seen + + +def test_excepthook(): + def excepthook(exc_type, exc_value, exc_tb): + if exc_type is not RuntimeError: + sys.__excepthook__(exc_type, exc_value, exc_tb) + + def hook(event, args): + if event == "sys.excepthook": + if not isinstance(args[2], args[1]): + raise TypeError(f"Expected isinstance({args[2]!r}, " f"{args[1]!r})") + if args[0] != excepthook: + raise ValueError(f"Expected {args[0]} == {excepthook}") + print(event, repr(args[2])) + + sys.addaudithook(hook) + sys.excepthook = excepthook + raise RuntimeError("fatal-error") + + +def test_unraisablehook(): + from _testcapi import err_formatunraisable + + def 
unraisablehook(hookargs): + pass + + def hook(event, args): + if event == "sys.unraisablehook": + if args[0] != unraisablehook: + raise ValueError(f"Expected {args[0]} == {unraisablehook}") + print(event, repr(args[1].exc_value), args[1].err_msg) + + sys.addaudithook(hook) + sys.unraisablehook = unraisablehook + err_formatunraisable(RuntimeError("nonfatal-error"), + "Exception ignored for audit hook test") + + +def test_winreg(): + from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE + + def hook(event, args): + if not event.startswith("winreg."): + return + print(event, *args) + + sys.addaudithook(hook) + + k = OpenKey(HKEY_LOCAL_MACHINE, "Software") + EnumKey(k, 0) + try: + EnumKey(k, 10000) + except OSError: + pass + else: + raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail") + + kv = k.Detach() + CloseKey(kv) + + +def test_socket(): + import socket + + def hook(event, args): + if event.startswith("socket."): + print(event, *args) + + sys.addaudithook(hook) + + socket.gethostname() + + # Don't care if this fails, we just want the audit message + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + # Don't care if this fails, we just want the audit message + sock.bind(('127.0.0.1', 8080)) + except Exception: + pass + finally: + sock.close() + + +def test_gc(): + import gc + + def hook(event, args): + if event.startswith("gc."): + print(event, *args) + + sys.addaudithook(hook) + + gc.get_objects(generation=1) + + x = object() + y = [x] + + gc.get_referrers(x) + gc.get_referents(y) + + +def test_http_client(): + import http.client + + def hook(event, args): + if event.startswith("http.client."): + print(event, *args[1:]) + + sys.addaudithook(hook) + + conn = http.client.HTTPConnection('www.python.org') + try: + conn.request('GET', '/') + except OSError: + print('http.client.send', '[cannot send]') + finally: + conn.close() + + +def test_sqlite3(): + import sqlite3 + + def hook(event, *args): + if event.startswith("sqlite3."): + 
print(event, *args) + + sys.addaudithook(hook) + cx1 = sqlite3.connect(":memory:") + cx2 = sqlite3.Connection(":memory:") + + # Configured without --enable-loadable-sqlite-extensions + try: + if hasattr(sqlite3.Connection, "enable_load_extension"): + cx1.enable_load_extension(False) + try: + cx1.load_extension("test") + except sqlite3.OperationalError: + pass + else: + raise RuntimeError("Expected sqlite3.load_extension to fail") + finally: + cx1.close() + cx2.close() + +def test_sys_getframe(): + import sys + + def hook(event, args): + if event.startswith("sys."): + print(event, args[0].f_code.co_name) + + sys.addaudithook(hook) + sys._getframe() + + +def test_sys_getframemodulename(): + import sys + + def hook(event, args): + if event.startswith("sys."): + print(event, *args) + + sys.addaudithook(hook) + sys._getframemodulename() + + +def test_threading(): + import _thread + + def hook(event, args): + if event.startswith(("_thread.", "cpython.PyThreadState", "test.")): + print(event, args) + + sys.addaudithook(hook) + + lock = _thread.allocate_lock() + lock.acquire() + + class test_func: + def __repr__(self): return "" + def __call__(self): + sys.audit("test.test_func") + lock.release() + + i = _thread.start_new_thread(test_func(), ()) + lock.acquire() + + handle = _thread.start_joinable_thread(test_func()) + handle.join() + + +def test_threading_abort(): + # Ensures that aborting PyThreadState_New raises the correct exception + import _thread + + class ThreadNewAbortError(Exception): + pass + + def hook(event, args): + if event == "cpython.PyThreadState_New": + raise ThreadNewAbortError() + + sys.addaudithook(hook) + + try: + _thread.start_new_thread(lambda: None, ()) + except ThreadNewAbortError: + # Other exceptions are raised and the test will fail + pass + + +def test_wmi_exec_query(): + import _wmi + + def hook(event, args): + if event.startswith("_wmi."): + print(event, args[0]) + + sys.addaudithook(hook) + try: + _wmi.exec_query("SELECT * FROM 
Win32_OperatingSystem") + except WindowsError as e: + # gh-112278: WMI may be slow response when first called, but we still + # get the audit event, so just ignore the timeout + if e.winerror != 258: + raise + +def test_syslog(): + import syslog + + def hook(event, args): + if event.startswith("syslog."): + print(event, *args) + + sys.addaudithook(hook) + syslog.openlog('python') + syslog.syslog('test') + syslog.setlogmask(syslog.LOG_DEBUG) + syslog.closelog() + # implicit open + syslog.syslog('test2') + # open with default ident + syslog.openlog(logoption=syslog.LOG_NDELAY, facility=syslog.LOG_LOCAL0) + sys.argv = None + syslog.openlog() + syslog.closelog() + + +def test_not_in_gc(): + import gc + + hook = lambda *a: None + sys.addaudithook(hook) + + for o in gc.get_objects(): + if isinstance(o, list): + assert hook not in o + + +def test_time(mode): + import time + + def hook(event, args): + if event.startswith("time."): + if mode == 'print': + print(event, *args) + elif mode == 'fail': + raise AssertionError('hook failed') + sys.addaudithook(hook) + + time.sleep(0) + time.sleep(0.0625) # 1/16, a small exact float + try: + time.sleep(-1) + except ValueError: + pass + +def test_sys_monitoring_register_callback(): + import sys + + def hook(event, args): + if event.startswith("sys.monitoring"): + print(event, args) + + sys.addaudithook(hook) + sys.monitoring.register_callback(1, 1, None) + + +def test_winapi_createnamedpipe(pipe_name): + import _winapi + + def hook(event, args): + if event == "_winapi.CreateNamedPipe": + print(event, args) + + sys.addaudithook(hook) + _winapi.CreateNamedPipe(pipe_name, _winapi.PIPE_ACCESS_DUPLEX, 8, 2, 0, 0, 0, 0) + + +def test_assert_unicode(): + import sys + sys.addaudithook(lambda *args: None) + try: + sys.audit(9) + except TypeError: + pass + else: + raise RuntimeError("Expected sys.audit(9) to fail.") + +def test_sys_remote_exec(): + import tempfile + + pid = os.getpid() + event_pid = -1 + event_script_path = "" + 
remote_event_script_path = "" + def hook(event, args): + if event not in ["sys.remote_exec", "cpython.remote_debugger_script"]: + return + print(event, args) + match event: + case "sys.remote_exec": + nonlocal event_pid, event_script_path + event_pid = args[0] + event_script_path = args[1] + case "cpython.remote_debugger_script": + nonlocal remote_event_script_path + remote_event_script_path = args[0] + + sys.addaudithook(hook) + with tempfile.NamedTemporaryFile(mode='w+', delete=True) as tmp_file: + tmp_file.write("a = 1+1\n") + tmp_file.flush() + sys.remote_exec(pid, tmp_file.name) + assertEqual(event_pid, pid) + assertEqual(event_script_path, tmp_file.name) + assertEqual(remote_event_script_path, tmp_file.name) + +if __name__ == "__main__": + from test.support import suppress_msvcrt_asserts + + suppress_msvcrt_asserts() + + test = sys.argv[1] + globals()[test](*sys.argv[2:]) diff --git a/Lib/test/test_audit.py b/Lib/test/test_audit.py index ddd9f95114..d01d36ad3d 100644 --- a/Lib/test/test_audit.py +++ b/Lib/test/test_audit.py @@ -23,6 +23,7 @@ def run_test_in_subprocess(self, *args): with subprocess.Popen( [sys.executable, "-X utf8", AUDIT_TESTS_PY, *args], encoding="utf-8", + errors="backslashreplace", stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) as p: @@ -57,6 +58,7 @@ def test_block_add_hook(self): def test_block_add_hook_baseexception(self): self.do_test("test_block_add_hook_baseexception") + @unittest.expectedFailure # TODO: RUSTPYTHON def test_marshal(self): import_helper.import_module("marshal") @@ -67,18 +69,33 @@ def test_pickle(self): self.do_test("test_pickle") + @unittest.expectedFailure # TODO: RUSTPYTHON def test_monkeypatch(self): self.do_test("test_monkeypatch") + @unittest.expectedFailure # TODO: RUSTPYTHON def test_open(self): self.do_test("test_open", os_helper.TESTFN) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_cantrace(self): self.do_test("test_cantrace") + @unittest.expectedFailure # TODO: RUSTPYTHON def test_mmap(self): 
self.do_test("test_mmap") + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_ctypes_call_function(self): + import_helper.import_module("ctypes") + self.do_test("test_ctypes_call_function") + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_posixsubprocess(self): + import_helper.import_module("_posixsubprocess") + self.do_test("test_posixsubprocess") + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_excepthook(self): returncode, events, stderr = self.run_python("test_excepthook") if not returncode: @@ -100,6 +117,7 @@ def test_unraisablehook(self): "RuntimeError('nonfatal-error') Exception ignored for audit hook test", ) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_winreg(self): import_helper.import_module("winreg") returncode, events, stderr = self.run_python("test_winreg") @@ -125,8 +143,9 @@ def test_socket(self): self.assertEqual(events[0][0], "socket.gethostname") self.assertEqual(events[1][0], "socket.__new__") self.assertEqual(events[2][0], "socket.bind") - self.assertTrue(events[2][2].endswith("('127.0.0.1', 8080)")) + self.assertEndsWith(events[2][2], "('127.0.0.1', 8080)") + @unittest.expectedFailure # TODO: RUSTPYTHON def test_gc(self): returncode, events, stderr = self.run_python("test_gc") if returncode: @@ -156,6 +175,7 @@ def test_http(self): self.assertIn('HTTP', events[1][2]) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_sqlite3(self): sqlite3 = import_helper.import_module("sqlite3") returncode, events, stderr = self.run_python("test_sqlite3") @@ -200,6 +220,7 @@ def test_sys_getframemodulename(self): self.assertEqual(actual, expected) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_threading(self): returncode, events, stderr = self.run_python("test_threading") if returncode: @@ -218,6 +239,7 @@ def test_threading(self): self.assertEqual(actual, expected) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_wmi_exec_query(self): import_helper.import_module("_wmi") returncode, events, stderr = 
self.run_python("test_wmi_exec_query") @@ -231,6 +253,7 @@ def test_wmi_exec_query(self): self.assertEqual(actual, expected) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_syslog(self): syslog = import_helper.import_module("syslog") @@ -292,6 +315,7 @@ def test_sys_monitoring_register_callback(self): self.assertEqual(actual, expected) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_winapi_createnamedpipe(self): winapi = import_helper.import_module("_winapi") @@ -313,6 +337,14 @@ def test_assert_unicode(self): if returncode: self.fail(stderr) + @support.support_remote_exec_only + @support.cpython_only + def test_sys_remote_exec(self): + returncode, events, stderr = self.run_python("test_sys_remote_exec") + self.assertTrue(any(["sys.remote_exec" in event for event in events])) + self.assertTrue(any(["cpython.remote_debugger_script" in event for event in events])) + if returncode: + self.fail(stderr) if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_bdb.py b/Lib/test/test_bdb.py index f15dae13eb..590d8166b6 100644 --- a/Lib/test/test_bdb.py +++ b/Lib/test/test_bdb.py @@ -728,6 +728,7 @@ def test_until_in_caller_frame(self): with TracerRun(self) as tracer: tracer.runcall(tfunc_main) + @unittest.skipIf(hasattr(__import__("sys"), "addaudithook"), "TODO: RUSTPYTHON; Currently no conditional tracing toggle") @patch_list(sys.meta_path) def test_skip(self): # Check that tracing is skipped over the import statement in diff --git a/Lib/test/test_sys_setprofile.py b/Lib/test/test_sys_setprofile.py index 21a09b5192..813adff2a3 100644 --- a/Lib/test/test_sys_setprofile.py +++ b/Lib/test/test_sys_setprofile.py @@ -30,9 +30,9 @@ def callback(self, frame, event, arg): if (event == "call" or event == "return" or event == "exception"): - self.add_event(event, frame) + self.add_event(event, frame, arg) - def add_event(self, event, frame=None): + def add_event(self, event, frame=None, arg=None): """Add an event to the log.""" if frame is None: frame 
= sys._getframe(1) @@ -43,7 +43,7 @@ def add_event(self, event, frame=None): frameno = len(self.frames) self.frames.append(frame) - self.events.append((frameno, event, ident(frame))) + self.events.append((frameno, event, ident(frame), arg)) def get_events(self): """Remove calls to add_event().""" @@ -89,11 +89,16 @@ def trace_pass(self, frame): class TestCaseBase(unittest.TestCase): - def check_events(self, callable, expected): + def check_events(self, callable, expected, check_args=False): events = capture_events(callable, self.new_watcher()) - if events != expected: - self.fail("Expected events:\n%s\nReceived events:\n%s" - % (pprint.pformat(expected), pprint.pformat(events))) + if check_args: + if events != expected: + self.fail("Expected events:\n%s\nReceived events:\n%s" + % (pprint.pformat(expected), pprint.pformat(events))) + else: + if [(frameno, event, ident) for frameno, event, ident, arg in events] != expected: + self.fail("Expected events:\n%s\nReceived events:\n%s" + % (pprint.pformat(expected), pprint.pformat(events))) class ProfileHookTestCase(TestCaseBase): @@ -119,7 +124,7 @@ def f(p): def test_caught_exception(self): def f(p): try: 1/0 - except: pass + except ZeroDivisionError: pass f_ident = ident(f) self.check_events(f, [(1, 'call', f_ident), (1, 'return', f_ident), @@ -128,7 +133,7 @@ def f(p): def test_caught_nested_exception(self): def f(p): try: 1/0 - except: pass + except ZeroDivisionError: pass f_ident = ident(f) self.check_events(f, [(1, 'call', f_ident), (1, 'return', f_ident), @@ -151,9 +156,9 @@ def f(p): def g(p): try: f(p) - except: + except ZeroDivisionError: try: f(p) - except: pass + except ZeroDivisionError: pass f_ident = ident(f) g_ident = ident(g) self.check_events(g, [(1, 'call', g_ident), @@ -164,6 +169,7 @@ def g(p): (1, 'return', g_ident), ]) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_exception_propagation(self): def f(p): 1/0 @@ -182,7 +188,7 @@ def g(p): def test_raise_twice(self): def f(p): try: 1/0 - 
except: 1/0 + except ZeroDivisionError: 1/0 f_ident = ident(f) self.check_events(f, [(1, 'call', f_ident), (1, 'return', f_ident), @@ -191,7 +197,7 @@ def f(p): def test_raise_reraise(self): def f(p): try: 1/0 - except: raise + except ZeroDivisionError: raise f_ident = ident(f) self.check_events(f, [(1, 'call', f_ident), (1, 'return', f_ident), @@ -255,6 +261,24 @@ def g(p): (1, 'return', g_ident), ]) + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_unfinished_generator(self): + def f(): + for i in range(2): + yield i + def g(p): + next(f()) + + f_ident = ident(f) + g_ident = ident(g) + self.check_events(g, [(1, 'call', g_ident, None), + (2, 'call', f_ident, None), + (2, 'return', f_ident, 0), + (2, 'call', f_ident, None), + (2, 'return', f_ident, None), + (1, 'return', g_ident, None), + ], check_args=True) + def test_stop_iteration(self): def f(): for i in range(2): @@ -300,7 +324,7 @@ def f(p): def test_caught_exception(self): def f(p): try: 1/0 - except: pass + except ZeroDivisionError: pass f_ident = ident(f) self.check_events(f, [(1, 'call', f_ident), (1, 'return', f_ident), @@ -415,5 +439,104 @@ def show_events(callable): pprint.pprint(capture_events(callable)) +class TestEdgeCases(unittest.TestCase): + + def setUp(self): + self.addCleanup(sys.setprofile, sys.getprofile()) + sys.setprofile(None) + + def test_reentrancy(self): + def foo(*args): + ... + + def bar(*args): + ... + + class A: + def __call__(self, *args): + pass + + def __del__(self): + sys.setprofile(bar) + + sys.setprofile(A()) + sys.setprofile(foo) + self.assertEqual(sys.getprofile(), bar) + + def test_same_object(self): + def foo(*args): + ... + + sys.setprofile(foo) + del foo + sys.setprofile(sys.getprofile()) + + def test_profile_after_trace_opcodes(self): + def f(): + ... 
+ + sys._getframe().f_trace_opcodes = True + prev_trace = sys.gettrace() + sys.settrace(lambda *args: None) + f() + sys.settrace(prev_trace) + sys.setprofile(lambda *args: None) + f() + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_method_with_c_function(self): + # gh-122029 + # When we have a PyMethodObject whose im_func is a C function, we + # should record both the call and the return. f = classmethod(repr) + # is just a way to create a PyMethodObject with a C function. + class A: + f = classmethod(repr) + events = [] + sys.setprofile(lambda frame, event, args: events.append(event)) + A().f() + sys.setprofile(None) + # The last c_call is the call to sys.setprofile + self.assertEqual(events, ['c_call', 'c_return', 'c_call']) + + class B: + f = classmethod(max) + events = [] + sys.setprofile(lambda frame, event, args: events.append(event)) + # Not important, we only want to trigger INSTRUMENTED_CALL_KW + B().f(1, key=lambda x: 0) + sys.setprofile(None) + # The last c_call is the call to sys.setprofile + self.assertEqual( + events, + ['c_call', + 'call', 'return', + 'call', 'return', + 'c_return', + 'c_call' + ] + ) + + # Test CALL_FUNCTION_EX + events = [] + sys.setprofile(lambda frame, event, args: events.append(event)) + # Not important, we only want to trigger INSTRUMENTED_CALL_KW + args = (1,) + m = B().f + m(*args, key=lambda x: 0) + sys.setprofile(None) + # The last c_call is the call to sys.setprofile + # INSTRUMENTED_CALL_FUNCTION_EX has different behavior than the other + # instrumented call bytecodes, it does not unpack the callable before + # calling it. This is probably not ideal because it's not consistent, + # but at least we get a consistent call stack (no unmatched c_call). 
+ self.assertEqual( + events, + ['call', 'return', + 'call', 'return', + 'c_call' + ] + ) + + if __name__ == "__main__": unittest.main() diff --git a/crates/stdlib/src/socket.rs b/crates/stdlib/src/socket.rs index 6aff5d452c..366de2ecc2 100644 --- a/crates/stdlib/src/socket.rs +++ b/crates/stdlib/src/socket.rs @@ -1426,6 +1426,13 @@ mod _socket { let mut socket_kind = args.r#type.unwrap_or(-1); let mut proto = args.proto.unwrap_or(-1); + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call( + (vm.ctx.new_str("socket.__new__"), family, socket_kind, proto), + vm, + )?; + } + let fileno = args.fileno; let sock; @@ -1555,6 +1562,18 @@ mod _socket { #[pymethod] fn bind(&self, address: PyObjectRef, vm: &VirtualMachine) -> Result<(), IoOrPyException> { let sock_addr = self.extract_address(address, "bind", vm)?; + + if let Some(addr) = sock_addr.as_socket() + && let Ok(audit) = vm.sys_module.get_attr("audit", vm) + { + let (ip, port) = match addr { + SocketAddr::V4(addr) => (addr.ip().to_string(), addr.port()), + SocketAddr::V6(addr) => (addr.ip().to_string(), addr.port()), + }; + + audit.call((vm.ctx.new_str("socket.bind"), (ip, port)), vm)?; + } + Ok(self.sock()?.bind(&sock_addr)?) 
} @@ -2300,6 +2319,10 @@ mod _socket { #[pyfunction] fn gethostname(vm: &VirtualMachine) -> PyResult { + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call((vm.ctx.new_str("socket.gethostname"),), vm)?; + } + gethostname::gethostname() .into_string() .map(|hostname| vm.ctx.new_str(hostname)) diff --git a/crates/stdlib/src/syslog.rs b/crates/stdlib/src/syslog.rs index fab82f14f5..52424972a0 100644 --- a/crates/stdlib/src/syslog.rs +++ b/crates/stdlib/src/syslog.rs @@ -53,12 +53,29 @@ mod syslog { fn openlog(args: OpenLogArgs, vm: &VirtualMachine) -> PyResult<()> { let logoption = args.logoption.unwrap_or(0); let facility = args.facility.unwrap_or(LOG_USER); - let ident = match args.ident.flatten() { + let ident = match args.ident.clone().flatten() { Some(args) => Some(args.to_cstring(vm)?), None => get_argv(vm).map(|argv| argv.to_cstring(vm)).transpose()?, } .map(|ident| ident.into_boxed_c_str()); + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + let audit_ident: PyObjectRef = args.ident.flatten().map_or_else( + || get_argv(vm).map_or_else(|| vm.ctx.none(), Into::into), + Into::into, + ); + + audit.call( + ( + vm.ctx.new_str("syslog.openlog"), + audit_ident, + logoption, + facility, + ), + vm, + )?; + } + host_syslog::openlog(ident, logoption, facility); Ok(()) } @@ -78,6 +95,10 @@ mod syslog { None => (LOG_INFO, args.priority.try_into_value(vm)?), }; + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call((vm.ctx.new_str("syslog.syslog"), priority, msg.clone()), vm)?; + } + if !host_syslog::is_open() { openlog(OpenLogArgs::default(), vm)?; } @@ -88,13 +109,22 @@ mod syslog { } #[pyfunction] - fn closelog() { + fn closelog(vm: &VirtualMachine) -> PyResult<()> { + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call((vm.ctx.new_str("syslog.closelog"),), vm)?; + } + host_syslog::closelog(); + Ok(()) } #[pyfunction] - fn setlogmask(maskpri: i32) -> i32 { - host_syslog::setlogmask(maskpri) + fn 
setlogmask(maskpri: i32, vm: &VirtualMachine) -> PyResult { + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call((vm.ctx.new_str("syslog.setlogmask"), maskpri), vm)?; + } + + Ok(host_syslog::setlogmask(maskpri)) } #[inline] diff --git a/crates/vm/src/stdlib/marshal.rs b/crates/vm/src/stdlib/marshal.rs index 60ecb1792f..6e0fc4e7f5 100644 --- a/crates/vm/src/stdlib/marshal.rs +++ b/crates/vm/src/stdlib/marshal.rs @@ -107,6 +107,14 @@ mod decl { _version, } = args; let version = _version.unwrap_or(marshal::FORMAT_VERSION as i32); + + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call( + (vm.ctx.new_str("marshal.dumps"), value.clone(), version), + vm, + )?; + } + if !allow_code { check_no_code(&value, vm)?; } diff --git a/crates/vm/src/stdlib/sys.rs b/crates/vm/src/stdlib/sys.rs index 68d57d225a..5e3ba265ea 100644 --- a/crates/vm/src/stdlib/sys.rs +++ b/crates/vm/src/stdlib/sys.rs @@ -755,11 +755,6 @@ pub mod sys { Ok(()) } - #[pyfunction] - fn audit(_args: FuncArgs) { - // TODO: sys.audit implementation - } - #[pyfunction] const fn _is_gil_enabled() -> bool { false // RustPython has no GIL (like free-threaded Python) @@ -973,25 +968,40 @@ pub mod sys { #[pyfunction] fn _getframe(offset: OptionalArg, vm: &VirtualMachine) -> PyResult { let offset = offset.into_option().unwrap_or(0); - let frames = vm.frames.borrow(); - if offset >= frames.len() { - return Err(vm.new_value_error("call stack is not deep enough")); + let frame_ref = { + let frames = vm.frames.borrow(); + if offset >= frames.len() { + return Err(vm.new_value_error("call stack is not deep enough")); + } + + let idx = frames.len() - offset - 1; + // SAFETY: the FrameRef is alive on the call stack while it's in the Vec + let py: &crate::Py = unsafe { frames[idx].as_ref() }; + py.to_owned() + }; + + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call((vm.ctx.new_str("sys._getframe"), frame_ref.to_owned()), vm)?; } - let idx = frames.len() - offset - 1; - 
// SAFETY: the FrameRef is alive on the call stack while it's in the Vec - let py: &crate::Py = unsafe { frames[idx].as_ref() }; - Ok(py.to_owned()) + + Ok(frame_ref) } #[pyfunction] - fn _getframemodulename(depth: OptionalArg, vm: &VirtualMachine) -> PyObjectRef { + fn _getframemodulename( + depth: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { let depth = depth.into_option().unwrap_or(0); + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call((vm.ctx.new_str("sys._getframemodulename"), depth), vm)?; + } // Get the frame at the specified depth let func_obj = { let frames = vm.frames.borrow(); if depth >= frames.len() { - return vm.ctx.none(); + return Ok(vm.ctx.none()); } let idx = frames.len() - depth - 1; // SAFETY: the FrameRef is alive on the call stack while it's in the Vec @@ -1000,7 +1010,7 @@ pub mod sys { }; // If the frame has a function object, return its __module__ attribute - if let Some(func_obj) = func_obj { + Ok(if let Some(func_obj) = func_obj { func_obj .get_attr(identifier!(vm, __module__), vm) .unwrap_or_else( @@ -1009,7 +1019,7 @@ pub mod sys { ) } else { vm.ctx.none() - } + }) } /// Return a dictionary mapping each thread's identifier to the topmost stack frame @@ -1721,6 +1731,60 @@ pub mod sys { #[pyclass(with(PyStructSequence))] impl PyUnraisableHookArgs {} + + pub(crate) fn run_audit_hooks( + event: PyStrRef, + args: &PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + let hooks = vm.audit_hooks.borrow().clone(); + + if hooks.is_empty() { + return Ok(()); + } + + for hook in hooks { + hook.call((event.clone(), args.clone()), vm)?; + } + + Ok(()) + } + + #[pyfunction] + fn audit(event: PyStrRef, args: PosArgs, vm: &VirtualMachine) -> PyResult<()> { + if vm.audit_hooks.borrow().is_empty() { + return Ok(()); + } + + let args_tup = vm.ctx.new_tuple(args.into_vec()).into(); + run_audit_hooks(event, &args_tup, vm) + } + + #[pyfunction] + fn addaudithook(hook: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + 
let hooks = vm.audit_hooks.borrow().clone(); + + if hooks.is_empty() { + vm.audit_hooks.borrow_mut().push(hook); + return Ok(()); + } + + let args: PyObjectRef = vm.ctx.new_tuple(vec![]).into(); + let event: PyObjectRef = vm.ctx.new_str("sys.addaudithook").into(); + + for existing_hook in hooks { + let Err(exc) = existing_hook.call((event.clone(), args.clone()), vm) else { + continue; + }; + if exc.class().fast_issubclass(vm.ctx.exceptions.runtime_error) { + return Ok(()); + } + return Err(exc); + } + + vm.audit_hooks.borrow_mut().push(hook); + Ok(()) + } } pub(crate) fn init_module(vm: &VirtualMachine, module: &Py, builtins: &Py) { diff --git a/crates/vm/src/stdlib/sys/monitoring.rs b/crates/vm/src/stdlib/sys/monitoring.rs index bf113c7937..6e61692507 100644 --- a/crates/vm/src/stdlib/sys/monitoring.rs +++ b/crates/vm/src/stdlib/sys/monitoring.rs @@ -598,6 +598,16 @@ fn register_callback( let tool = check_valid_tool(tool_id, vm)?; let event_id = parse_single_event(event, vm)?; + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call( + ( + vm.ctx.new_str("sys.monitoring.register_callback"), + func.clone(), + ), + vm, + )?; + } + let mut state = vm.state.monitoring.lock(); let prev = state .callbacks diff --git a/crates/vm/src/stdlib/time.rs b/crates/vm/src/stdlib/time.rs index 93d04341d2..a56dffe08c 100644 --- a/crates/vm/src/stdlib/time.rs +++ b/crates/vm/src/stdlib/time.rs @@ -74,6 +74,10 @@ mod decl { #[pyfunction] fn sleep(seconds: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + if let Ok(audit) = vm.sys_module.get_attr("audit", vm) { + audit.call((vm.ctx.new_str("time.sleep"), seconds.clone()), vm)?; + } + let seconds_type_name = seconds.class().name().to_owned(); let dur = seconds.try_into_value::(vm).map_err(|e| { if e.class().is(vm.ctx.exceptions.value_error) diff --git a/crates/vm/src/vm/mod.rs b/crates/vm/src/vm/mod.rs index 85416a0912..a99012cd14 100644 --- a/crates/vm/src/vm/mod.rs +++ b/crates/vm/src/vm/mod.rs @@ -100,6 +100,7 @@ 
pub struct VirtualMachine { /// Current running asyncio task for this thread pub asyncio_running_task: RefCell>, pub(crate) callable_cache: CallableCache, + pub(crate) audit_hooks: RefCell>, } /// Non-owning frame pointer for the frames stack. @@ -750,6 +751,7 @@ impl VirtualMachine { asyncio_running_loop: RefCell::new(None), asyncio_running_task: RefCell::new(None), callable_cache: CallableCache::default(), + audit_hooks: RefCell::new(vec![]), }; if vm.state.hash_secret.hash_str("") diff --git a/crates/vm/src/vm/thread.rs b/crates/vm/src/vm/thread.rs index a81396bff3..f89cf818d2 100644 --- a/crates/vm/src/vm/thread.rs +++ b/crates/vm/src/vm/thread.rs @@ -730,6 +730,7 @@ impl VirtualMachine { asyncio_running_loop: RefCell::new(None), asyncio_running_task: RefCell::new(None), callable_cache: self.callable_cache.clone(), + audit_hooks: RefCell::new(vec![]), }; ThreadedVirtualMachine { vm } } From 438925401f3411045259662e11b1ebb5fa564955 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 24 May 2026 19:57:20 +0900 Subject: [PATCH 14/18] Bump qs and express in /wasm/demo (#7959) Bumps [qs](https://github.com/ljharb/qs) and [express](https://github.com/expressjs/express). These dependencies needed to be updated together. Updates `qs` from 6.14.2 to 6.15.2 - [Changelog](https://github.com/ljharb/qs/blob/main/CHANGELOG.md) - [Commits](https://github.com/ljharb/qs/compare/v6.14.2...v6.15.2) Updates `express` from 4.22.1 to 4.22.2 - [Release notes](https://github.com/expressjs/express/releases) - [Changelog](https://github.com/expressjs/express/blob/v4.22.2/History.md) - [Commits](https://github.com/expressjs/express/compare/v4.22.1...v4.22.2) --- updated-dependencies: - dependency-name: express dependency-version: 4.22.2 dependency-type: indirect - dependency-name: qs dependency-version: 6.15.2 dependency-type: indirect ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- wasm/demo/package-lock.json | 32 ++++++++------------------------ 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/wasm/demo/package-lock.json b/wasm/demo/package-lock.json index 9169b00a67..14744312ac 100644 --- a/wasm/demo/package-lock.json +++ b/wasm/demo/package-lock.json @@ -1282,22 +1282,6 @@ "url": "https://opencollective.com/express" } }, - "node_modules/body-parser/node_modules/qs": { - "version": "6.15.1", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.15.1.tgz", - "integrity": "sha512-6YHEFRL9mfgcAvql/XhwTvf5jKcOiiupt2FiJxHkiX1z4j7WL8J/jRHYLluORvc1XxB5rV20KoeK00gVJamspg==", - "dev": true, - "license": "BSD-3-Clause", - "dependencies": { - "side-channel": "^1.1.0" - }, - "engines": { - "node": ">=0.6" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, "node_modules/body-parser/node_modules/statuses": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.2.tgz", @@ -2470,15 +2454,15 @@ } }, "node_modules/express": { - "version": "4.22.1", - "resolved": "https://registry.npmjs.org/express/-/express-4.22.1.tgz", - "integrity": "sha512-F2X8g9P1X7uCPZMA3MVf9wcTqlyNp7IhH5qPCI0izhaOIYXaW9L535tGA3qmjRzpH+bZczqq7hVKxTR4NWnu+g==", + "version": "4.22.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.22.2.tgz", + "integrity": "sha512-IuL+Elrou2ZvCFHs18/CIzy2Nzvo25nZ1/D2eIZlz7c+QUayAcYoiM2BthCjs+EBHVpjYjcuLDAiCWgeIX3X1Q==", "dev": true, "license": "MIT", "dependencies": { "accepts": "~1.3.8", "array-flatten": "1.1.1", - "body-parser": "~1.20.3", + "body-parser": "~1.20.5", "content-disposition": "~0.5.4", "content-type": "~1.0.4", "cookie": "~0.7.1", @@ -2497,7 +2481,7 @@ "parseurl": "~1.3.3", "path-to-regexp": "~0.1.12", "proxy-addr": "~2.0.7", - "qs": "~6.14.0", + "qs": "~6.15.1", "range-parser": "~1.2.1", "safe-buffer": "5.2.1", "send": "~0.19.0", @@ -4260,9 
+4244,9 @@ } }, "node_modules/qs": { - "version": "6.14.2", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.2.tgz", - "integrity": "sha512-V/yCWTTF7VJ9hIh18Ugr2zhJMP01MY7c5kh4J870L7imm6/DIzBsNLTXzMwUA3yZ5b/KBqLx8Kp3uRvd7xSe3Q==", + "version": "6.15.2", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.15.2.tgz", + "integrity": "sha512-Rzq0KEyX/w/tEybncDgdkZrJgVUsUMk3xjh3t5bv3S1HTAtg+uOYt72+ZfwiQwKdysThkTBdL/rTi6HDmX9Ddw==", "dev": true, "license": "BSD-3-Clause", "dependencies": { From 7011942e4eade964f0e4ffcdf4fc3d260a362f82 Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Sun, 24 May 2026 13:58:16 +0300 Subject: [PATCH 15/18] Add `builtin.PythonFinalizationError` (#7966) * Add PythonFinalizationError to builtins * Patch failing tests (unrelated) * Unmark passing test * Update `exception_hierarchy.txt` to 3.14.5 --- Lib/test/exception_hierarchy.txt | 1 + Lib/test/test_builtin.py | 4 ++++ Lib/test/test_pickle.py | 1 - crates/vm/src/stdlib/builtins.rs | 1 + 4 files changed, 6 insertions(+), 1 deletion(-) diff --git a/Lib/test/exception_hierarchy.txt b/Lib/test/exception_hierarchy.txt index 1eca123be0..f2649aa2d4 100644 --- a/Lib/test/exception_hierarchy.txt +++ b/Lib/test/exception_hierarchy.txt @@ -40,6 +40,7 @@ BaseException ├── ReferenceError ├── RuntimeError │ ├── NotImplementedError + │ ├── PythonFinalizationError │ └── RecursionError ├── StopAsyncIteration ├── StopIteration diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py index cf0268c2ce..13b0d4a2a2 100644 --- a/Lib/test/test_builtin.py +++ b/Lib/test/test_builtin.py @@ -2696,6 +2696,7 @@ def detach_readline(self): else: yield + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: got 0 lines in pipe but expected 2, child output was: quux def test_input_tty(self): # Test input() functionality when wired to a tty self.check_input_tty("prompt", b"quux") @@ -2710,17 +2711,20 @@ def test_input_tty_non_ascii_unicode_errors(self): # 
Check stdin/stdout error handler is used when invoking PyOS_Readline() self.check_input_tty("prompté", b"quux\xe9", "ascii") + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: got 0 lines in pipe but expected 2, child output was: quux def test_input_tty_null_in_prompt(self): self.check_input_tty("prompt\0", b"", expected='ValueError: input: prompt string cannot contain ' 'null characters') + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: got 0 lines in pipe but expected 2, child output was: quux def test_input_tty_nonencodable_prompt(self): self.check_input_tty("prompté", b"quux", "ascii", stdout_errors='strict', expected="UnicodeEncodeError: 'ascii' codec can't encode " "character '\\xe9' in position 6: ordinal not in " "range(128)") + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: got 0 lines in pipe but expected 2, child output was: quux def test_input_tty_nondecodable_input(self): self.check_input_tty("prompt", b"quux\xe9", "ascii", stdin_errors='strict', expected="UnicodeDecodeError: 'ascii' codec can't decode " diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py index c9d4a34844..bfb0c17337 100644 --- a/Lib/test/test_pickle.py +++ b/Lib/test/test_pickle.py @@ -775,7 +775,6 @@ def test_reverse_name_mapping(self): module, name = mapping(module, name) self.assertEqual((module, name), (module3, name3)) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_exceptions(self): self.assertEqual(mapping('exceptions', 'StandardError'), ('builtins', 'Exception')) diff --git a/crates/vm/src/stdlib/builtins.rs b/crates/vm/src/stdlib/builtins.rs index f2d4dcdb64..403e8c12e6 100644 --- a/crates/vm/src/stdlib/builtins.rs +++ b/crates/vm/src/stdlib/builtins.rs @@ -1474,6 +1474,7 @@ pub fn init_module(vm: &VirtualMachine, module: &Py) { "TimeoutError" => ctx.exceptions.timeout_error.to_owned(), "ReferenceError" => ctx.exceptions.reference_error.to_owned(), "RuntimeError" => ctx.exceptions.runtime_error.to_owned(), + 
"PythonFinalizationError" => ctx.exceptions.python_finalization_error.to_owned(), "NotImplementedError" => ctx.exceptions.not_implemented_error.to_owned(), "RecursionError" => ctx.exceptions.recursion_error.to_owned(), "SyntaxError" => ctx.exceptions.syntax_error.to_owned(), From 52305c0c7268aabcdca681aff8f6d95f8e0b1e2e Mon Sep 17 00:00:00 2001 From: Shahar Naveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Sun, 24 May 2026 14:00:32 +0300 Subject: [PATCH 16/18] Skip flaky test (#7967) --- Lib/test/test_asyncio/test_sendfile.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Lib/test/test_asyncio/test_sendfile.py b/Lib/test/test_asyncio/test_sendfile.py index 3acbc37f24..e266d57742 100644 --- a/Lib/test/test_asyncio/test_sendfile.py +++ b/Lib/test/test_asyncio/test_sendfile.py @@ -566,6 +566,10 @@ class EPollEventLoopTests(SendfileTestsBase, def create_event_loop(self): return asyncio.SelectorEventLoop(selectors.EpollSelector()) + @unittest.skipIf(sys.platform != "win32", "TODO: RUSTPYTHON; Flaky on CI") + def test_sendfile_ssl_pre_and_post_data(self): + return super().test_sendfile_ssl_pre_and_post_data() + if hasattr(selectors, 'PollSelector'): class PollEventLoopTests(SendfileTestsBase, test_utils.TestCase): From bc3d00e879396cea7171d24b44e64bf197e8a2ca Mon Sep 17 00:00:00 2001 From: fanninpm <27117322+fanninpm@users.noreply.github.com> Date: Mon, 25 May 2026 00:53:54 -0400 Subject: [PATCH 17/18] Replace `ahash` with `rapidhash` (#7954) * Add `rapidhash` to list of dependencies * Use `rapidhash::quality::RandomState` in `codegen` crate * Use `rapidhash::quality::RandomState` in `stdlib` crate * Use `rapidhash::quality::RandomState` in `vm` crate * Remove `ahash` from lists of dependencies --- .github/dependabot.yml | 1 - Cargo.lock | 28 ++++++++++++---------------- Cargo.toml | 2 +- crates/codegen/Cargo.toml | 2 +- crates/codegen/src/lib.rs | 4 ++-- crates/stdlib/Cargo.toml | 2 +- crates/stdlib/src/contextvars.rs | 2 +- crates/vm/Cargo.toml | 2 +- 
crates/vm/src/builtins/type.rs | 3 ++- crates/vm/src/intern.rs | 2 +- crates/vm/src/vm/interpreter.rs | 7 +++++-- crates/vm/src/vm/mod.rs | 2 +- 12 files changed, 28 insertions(+), 29 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index e3f9ba3b7a..7533ce3680 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -85,7 +85,6 @@ updates: - "quote-use*" random: patterns: - - "ahash" - "getrandom" - "mt19937" - "rand*" diff --git a/Cargo.lock b/Cargo.lock index eab7707b5b..0d191fc99a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -25,19 +25,6 @@ dependencies = [ "cpufeatures", ] -[[package]] -name = "ahash" -version = "0.8.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" -dependencies = [ - "cfg-if", - "getrandom 0.3.4", - "once_cell", - "version_check", - "zerocopy", -] - [[package]] name = "aho-corasick" version = "1.1.4" @@ -2928,6 +2915,15 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rapidhash" +version = "4.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e48930979c155e2f33aa36ab3119b5ee81332beb6482199a8ecd6029b80b59" +dependencies = [ + "rustversion", +] + [[package]] name = "rayon" version = "1.12.0" @@ -3241,7 +3237,6 @@ dependencies = [ name = "rustpython-codegen" version = "0.5.0" dependencies = [ - "ahash", "bitflags 2.11.1", "indexmap", "itertools 0.14.0", @@ -3250,6 +3245,7 @@ dependencies = [ "memchr", "num-complex", "num-traits", + "rapidhash", "rustpython-compiler-core", "rustpython-literal", "rustpython-ruff_python_ast", @@ -3498,7 +3494,6 @@ name = "rustpython-stdlib" version = "0.5.0" dependencies = [ "adler32", - "ahash", "ascii", "aws-lc-rs", "base64", @@ -3549,6 +3544,7 @@ dependencies = [ "pkcs8", "pymath", "rand_core 0.9.5", + "rapidhash", "rustls", "rustls-native-certs", "rustls-pemfile", @@ -3588,7 +3584,6 @@ version = "0.5.0" name = "rustpython-vm" 
version = "0.5.0" dependencies = [ - "ahash", "ascii", "bitflags 2.11.1", "bstr", @@ -3619,6 +3614,7 @@ dependencies = [ "parking_lot", "paste", "psm", + "rapidhash", "result-like", "rustpython-codegen", "rustpython-common", diff --git a/Cargo.toml b/Cargo.toml index 567e68a91d..9fbdd9d71e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -176,7 +176,6 @@ ruff_source_file = { package = "rustpython-ruff_source_file", version = "0.15.8" der = { version = "0.8", features = ["alloc", "oid", "pem", "zeroize"] } phf = { version = "0.13.1", default-features = false, features = ["macros"]} adler32 = "1.2.0" -ahash = "0.8.12" approx = "0.5.1" ascii = "1.1" aws-lc-rs = "1.16.3" @@ -260,6 +259,7 @@ quote = "1.0.45" radium = "1.1.1" rand = "0.9" rand_core = { version = "0.9", features = ["os_rng"] } +rapidhash = "4.4.1" result-like = "0.5.0" rustix = { version = "1.1", features = ["event", "param", "system"] } rustls = { version = "0.23.39", default-features = false } diff --git a/crates/codegen/Cargo.toml b/crates/codegen/Cargo.toml index 4f32fcc3c3..031f3b9652 100644 --- a/crates/codegen/Cargo.toml +++ b/crates/codegen/Cargo.toml @@ -19,7 +19,6 @@ rustpython-wtf8 = { workspace = true } ruff_python_ast = { workspace = true } ruff_text_size = { workspace = true } -ahash = { workspace = true } bitflags = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } @@ -29,6 +28,7 @@ num-traits = { workspace = true } thiserror = { workspace = true } malachite-bigint = { workspace = true } memchr = { workspace = true } +rapidhash = { workspace = true } unicode_names2 = { workspace = true } [dev-dependencies] diff --git a/crates/codegen/src/lib.rs b/crates/codegen/src/lib.rs index 8d6ad98435..b598ab7e93 100644 --- a/crates/codegen/src/lib.rs +++ b/crates/codegen/src/lib.rs @@ -8,8 +8,8 @@ extern crate log; extern crate alloc; -type IndexMap = indexmap::IndexMap; -type IndexSet = indexmap::IndexSet; +type IndexMap = indexmap::IndexMap; +type IndexSet = 
indexmap::IndexSet; pub mod compile; pub mod error; diff --git a/crates/stdlib/Cargo.toml b/crates/stdlib/Cargo.toml index f77d467d63..88afc24327 100644 --- a/crates/stdlib/Cargo.toml +++ b/crates/stdlib/Cargo.toml @@ -37,7 +37,6 @@ ruff_python_ast = { workspace = true } ruff_text_size = { workspace = true } ruff_source_file = { workspace = true } -ahash = { workspace = true } ascii = { workspace = true } crossbeam-utils = { workspace = true } flame = { workspace = true, optional = true } @@ -51,6 +50,7 @@ num-traits = { workspace = true } num_enum = { workspace = true } parking_lot = { workspace = true } phf = { workspace = true, default-features = true, features = ["macros"] } +rapidhash = { workspace = true } memchr = { workspace = true } base64 = { workspace = true } diff --git a/crates/stdlib/src/contextvars.rs b/crates/stdlib/src/contextvars.rs index ee5942755c..0a6e0f1231 100644 --- a/crates/stdlib/src/contextvars.rs +++ b/crates/stdlib/src/contextvars.rs @@ -28,7 +28,7 @@ mod _contextvars { use indexmap::IndexMap; // TODO: Real hamt implementation - type Hamt = IndexMap, PyObjectRef, ahash::RandomState>; + type Hamt = IndexMap, PyObjectRef, rapidhash::quality::RandomState>; #[pyclass(no_attr, name = "Hamt", module = "contextvars")] #[derive(Debug, PyPayload)] diff --git a/crates/vm/Cargo.toml b/crates/vm/Cargo.toml index 8ff078b818..83e41fa1f5 100644 --- a/crates/vm/Cargo.toml +++ b/crates/vm/Cargo.toml @@ -44,7 +44,6 @@ rustpython-literal = { workspace = true } rustpython-sre_engine = { workspace = true } ascii = { workspace = true } -ahash = { workspace = true } bitflags = { workspace = true } bstr = { workspace = true } crossbeam-utils = { workspace = true } @@ -64,6 +63,7 @@ num-traits = { workspace = true } num_enum = { workspace = true } parking_lot = { workspace = true } paste = { workspace = true } +rapidhash = { workspace = true } scopeguard = { workspace = true } serde = { workspace = true, optional = true } static_assertions = { workspace = true 
} diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index 8b920c2fee..fa673d4203 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -408,7 +408,8 @@ cfg_select! { /// For attributes we do not use a dict, but an IndexMap, which is an Hash Table /// that maintains order and is compatible with the standard HashMap This is probably /// faster and only supports strings as keys. -pub(crate) type PyAttributes = IndexMap<&'static PyStrInterned, PyObjectRef, ahash::RandomState>; +pub(crate) type PyAttributes = + IndexMap<&'static PyStrInterned, PyObjectRef, rapidhash::quality::RandomState>; unsafe impl Traverse for PyAttributes { fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) { diff --git a/crates/vm/src/intern.rs b/crates/vm/src/intern.rs index da1d63f879..981732f2dd 100644 --- a/crates/vm/src/intern.rs +++ b/crates/vm/src/intern.rs @@ -11,7 +11,7 @@ use core::{borrow::Borrow, ops::Deref}; #[derive(Debug)] pub(crate) struct StringPool { - inner: PyRwLock>, + inner: PyRwLock>, } impl Default for StringPool { diff --git a/crates/vm/src/vm/interpreter.rs b/crates/vm/src/vm/interpreter.rs index 505986acae..bb0b80bb63 100644 --- a/crates/vm/src/vm/interpreter.rs +++ b/crates/vm/src/vm/interpreter.rs @@ -95,8 +95,11 @@ where } as usize); // Initialize frozen modules (core + user-provided) - let mut frozen: std::collections::HashMap<&'static str, FrozenModule, ahash::RandomState> = - core_frozen_inits().collect(); + let mut frozen: std::collections::HashMap< + &'static str, + FrozenModule, + rapidhash::quality::RandomState, + > = core_frozen_inits().collect(); frozen.extend(frozen_modules); // Create PyGlobalState diff --git a/crates/vm/src/vm/mod.rs b/crates/vm/src/vm/mod.rs index a99012cd14..ca89e91857 100644 --- a/crates/vm/src/vm/mod.rs +++ b/crates/vm/src/vm/mod.rs @@ -589,7 +589,7 @@ pub(crate) struct CallableCache { pub struct PyGlobalState { pub config: PyConfig, pub module_defs: BTreeMap<&'static str, 
&'static builtins::PyModuleDef>, - pub frozen: HashMap<&'static str, FrozenModule, ahash::RandomState>, + pub frozen: HashMap<&'static str, FrozenModule, rapidhash::quality::RandomState>, pub stacksize: AtomicCell, pub thread_count: AtomicCell, pub hash_secret: HashSecret, From a5775e0c07b8a8b15ef6bb0c5b10e8dc5919ffee Mon Sep 17 00:00:00 2001 From: James Clarke Date: Mon, 25 May 2026 05:55:32 +0100 Subject: [PATCH 18/18] Fix thread teardown panic when weakref callback fires during cleanup (#7965) --- .gitignore | 2 +- crates/vm/src/stdlib/_thread.rs | 31 ++++++++++++----- extra_tests/snippets/stdlib_threading.py | 43 ++++++++++++++++++++++++ extra_tests/snippets/test_threading.py | 24 ------------- 4 files changed, 66 insertions(+), 34 deletions(-) delete mode 100644 extra_tests/snippets/test_threading.py diff --git a/.gitignore b/.gitignore index 338a6437ca..b5887be53b 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,4 @@ Lib/site-packages/* Lib/test/data/* !Lib/test/data/README cpython/ - +.claude/scheduled_tasks.lock \ No newline at end of file diff --git a/crates/vm/src/stdlib/_thread.rs b/crates/vm/src/stdlib/_thread.rs index b6fbf146a9..0af8d7add3 100644 --- a/crates/vm/src/stdlib/_thread.rs +++ b/crates/vm/src/stdlib/_thread.rs @@ -530,10 +530,16 @@ pub(crate) mod _thread { // Increment thread count when thread actually starts executing vm.state.thread_count.fetch_add(1); - match func.invoke(args, vm) { - Ok(_obj) => {} - Err(e) if e.fast_isinstance(vm.ctx.exceptions.system_exit) => {} - Err(exc) => { + // Inner scope: drop `func` (and its Python refs) before the thread + // slot is torn down below. Otherwise the parameter `func` would drop + // at end-of-function, after cleanup_current_thread_frames has cleared + // CURRENT_THREAD_SLOT, and a weakref callback fired during that drop + // would panic in push_thread_frame. 
+ { + let func = func; + if let Err(exc) = func.invoke(args, vm) + && !exc.fast_isinstance(vm.ctx.exceptions.system_exit) + { vm.run_unraisable( exc, Some("Exception ignored in thread started by".to_owned()), @@ -1663,11 +1669,18 @@ pub(crate) mod _thread { // Increment thread count when thread actually starts executing vm_state.thread_count.fetch_add(1); - // Run the function - match func.invoke((), vm) { - Ok(_) => {} - Err(e) if e.fast_isinstance(vm.ctx.exceptions.system_exit) => {} - Err(exc) => { + // Inner scope: drop `func` (and its Python refs) before the + // outer scopeguard::defer tears down the thread slot. As a + // `move` closure capture, `func` would otherwise drop after + // all locals (including the scopeguard `_guard`), and a + // weakref callback fired during that drop would panic in + // push_thread_frame. + { + let func = func; + // Run the function + if let Err(exc) = func.invoke((), vm) + && !exc.fast_isinstance(vm.ctx.exceptions.system_exit) + { vm.run_unraisable( exc, Some("Exception ignored in thread started by".to_owned()), diff --git a/extra_tests/snippets/stdlib_threading.py b/extra_tests/snippets/stdlib_threading.py index f35d7e9d08..cb989e1fd3 100644 --- a/extra_tests/snippets/stdlib_threading.py +++ b/extra_tests/snippets/stdlib_threading.py @@ -1,6 +1,7 @@ import multiprocessing import os import threading +import time def import_in_thread(module_name): @@ -62,6 +63,48 @@ def start_fork_process_after_thread(): assert process.exitcode == 0, process.exitcode +def thread_join_ordering(): + output = [] + + def thread_function(name): + output.append((name, 0)) + time.sleep(2.0) + output.append((name, 1)) + + output.append((0, 0)) + x = threading.Thread(target=thread_function, args=(1,)) + output.append((0, 1)) + x.start() + output.append((0, 2)) + x.join() + output.append((0, 3)) + + assert len(output) == 6, output + # CPython has [(1, 0), (0, 2)] for the middle 2, but we have [(0, 2), (1, 0)] + # TODO: maybe fix this, if it turns out to 
be a problem? + # assert output == [(0, 0), (0, 1), (1, 0), (0, 2), (1, 1), (0, 3)] + + +def thread_exit_without_join(): + # Regression for https://github.com/RustPython/RustPython/issues/7813: + # a thread started without ``.join()`` must exit cleanly even when the + # captured target callable drops during teardown (which can fire + # weakref callbacks that re-enter the VM). + output = [] + + def runner(): + output.append("runner done") + + threading.Thread(target=runner).start() + time.sleep(1) + output.append("main done") + assert "runner done" in output, output + assert "main done" in output, output + + +thread_join_ordering() +thread_exit_without_join() + import_in_thread("functools") import_in_thread("tempfile") import_in_thread("multiprocessing.connection") diff --git a/extra_tests/snippets/test_threading.py b/extra_tests/snippets/test_threading.py deleted file mode 100644 index 4d7c29f509..0000000000 --- a/extra_tests/snippets/test_threading.py +++ /dev/null @@ -1,24 +0,0 @@ -import threading -import time - -output = [] - - -def thread_function(name): - output.append((name, 0)) - time.sleep(2.0) - output.append((name, 1)) - - -output.append((0, 0)) -x = threading.Thread(target=thread_function, args=(1,)) -output.append((0, 1)) -x.start() -output.append((0, 2)) -x.join() -output.append((0, 3)) - -assert len(output) == 6, output -# CPython has [(1, 0), (0, 2)] for the middle 2, but we have [(0, 2), (1, 0)] -# TODO: maybe fix this, if it turns out to be a problem? -# assert output == [(0, 0), (0, 1), (1, 0), (0, 2), (1, 1), (0, 3)]