diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index 1a06b426f71..b60c7452f3f 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -2876,7 +2876,6 @@ def test_get_event_loop_after_set_none(self): policy.set_event_loop(None) self.assertRaises(RuntimeError, policy.get_event_loop) - @unittest.expectedFailure # TODO: RUSTPYTHON; - mock.patch doesn't work correctly with threading.current_thread @mock.patch('asyncio.events.threading.current_thread') def test_get_event_loop_thread(self, m_current_thread): diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py index 8db0bbdb949..17693ae093f 100644 --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -1162,8 +1162,6 @@ def import_threading(): self.assertEqual(out, b'') self.assertEqual(err, b'') - # TODO: RUSTPYTHON - __del__ not called during interpreter finalization (no cyclic GC) - @unittest.expectedFailure def test_start_new_thread_at_finalization(self): code = """if 1: import _thread diff --git a/crates/common/src/lock/thread_mutex.rs b/crates/common/src/lock/thread_mutex.rs index 5b5b89f4eb1..884556c4476 100644 --- a/crates/common/src/lock/thread_mutex.rs +++ b/crates/common/src/lock/thread_mutex.rs @@ -54,6 +54,18 @@ impl RawThreadMutex { .is_some() } + /// Like `lock()` but wraps the blocking wait in `wrap_fn`. + /// The caller can use this to detach thread state while waiting. + pub fn lock_wrapped(&self, wrap_fn: F) -> bool { + let id = self.get_thread_id.nonzero_thread_id().get(); + if self.owner.load(Ordering::Relaxed) == id { + return false; + } + wrap_fn(&|| self.mutex.lock()); + self.owner.store(id, Ordering::Relaxed); + true + } + /// Returns `Some(true)` if able to successfully lock without blocking, `Some(false)` /// otherwise, and `None` when the mutex is already locked on the current thread. pub fn try_lock(&self) -> Option { @@ -135,6 +147,23 @@ impl ThreadMutex { None } } + + /// Like `lock()` but wraps the blocking wait in `wrap_fn`. + /// The caller can use this to detach thread state while waiting. + pub fn lock_wrapped( + &self, + wrap_fn: F, + ) -> Option> { + if self.raw.lock_wrapped(wrap_fn) { + Some(ThreadMutexGuard { + mu: self, + marker: PhantomData, + }) + } else { + None + } + } + pub fn try_lock(&self) -> Result, TryLockThreadError> { match self.raw.try_lock() { Some(true) => Ok(ThreadMutexGuard { diff --git a/crates/stdlib/src/multiprocessing.rs b/crates/stdlib/src/multiprocessing.rs index cab0fc4c159..64049df5599 100644 --- a/crates/stdlib/src/multiprocessing.rs +++ b/crates/stdlib/src/multiprocessing.rs @@ -484,7 +484,7 @@ mod _multiprocessing { tv_sec: (delay / 1_000_000) as _, tv_usec: (delay % 1_000_000) as _, }; - unsafe { + vm.allow_threads(|| unsafe { libc::select( 0, core::ptr::null_mut(), @@ -492,7 +492,7 @@ mod _multiprocessing { core::ptr::null_mut(), &mut tv_delay, ) - }; + }); // check for signals - preserve the exception (e.g., KeyboardInterrupt) if let Err(exc) = vm.check_signals() { @@ -710,13 +710,13 @@ mod _multiprocessing { #[cfg(not(target_vendor = "apple"))] { loop { + let sem_ptr = self.handle.as_ptr(); // Py_BEGIN_ALLOW_THREADS / Py_END_ALLOW_THREADS - // RustPython doesn't have GIL, so we just do the wait - if let Some(ref dl) = deadline { - res = unsafe { libc::sem_timedwait(self.handle.as_ptr(), dl) }; + res = if let Some(ref dl) = deadline { + vm.allow_threads(|| unsafe { libc::sem_timedwait(sem_ptr, dl) }) } else { - res = unsafe { libc::sem_wait(self.handle.as_ptr()) }; - } + vm.allow_threads(|| unsafe { libc::sem_wait(sem_ptr) }) + }; if res >= 0 { break; @@ -750,7 +750,8 @@ mod _multiprocessing { } else { // No timeout: use sem_wait (available on macOS) loop { - res = unsafe { libc::sem_wait(self.handle.as_ptr()) }; + let sem_ptr = self.handle.as_ptr(); + res = vm.allow_threads(|| unsafe { libc::sem_wait(sem_ptr) }); if res >= 0 { break; } diff --git a/crates/stdlib/src/select.rs b/crates/stdlib/src/select.rs index 05d9eaa550d..aeb2271735b 100644 --- a/crates/stdlib/src/select.rs +++ b/crates/stdlib/src/select.rs @@ -280,9 +280,7 @@ mod decl { loop { let mut tv = timeout.map(sec_to_timeval); - let res = vm.allow_threads(|| { - super::select(nfds, &mut r, &mut w, &mut x, tv.as_mut()) - }); + let res = vm.allow_threads(|| super::select(nfds, &mut r, &mut w, &mut x, tv.as_mut())); match res { Ok(_) => break, @@ -504,7 +502,9 @@ mod decl { let deadline = timeout.map(|d| Instant::now() + d); let mut poll_timeout = timeout_ms; loop { - let res = unsafe { libc::poll(fds.as_mut_ptr(), fds.len() as _, poll_timeout) }; + let res = vm.allow_threads(|| unsafe { + libc::poll(fds.as_mut_ptr(), fds.len() as _, poll_timeout) + }); match nix::Error::result(res) { Ok(_) => break, Err(nix::Error::EINTR) => vm.check_signals()?, @@ -697,11 +697,13 @@ mod decl { loop { events.clear(); - match epoll::wait( - epoll, - rustix::buffer::spare_capacity(&mut events), - poll_timeout.as_ref(), - ) { + match vm.allow_threads(|| { + epoll::wait( + epoll, + rustix::buffer::spare_capacity(&mut events), + poll_timeout.as_ref(), + ) + }) { Ok(_) => break, Err(rustix::io::Errno::INTR) => vm.check_signals()?, Err(e) => return Err(e.into_pyexception(vm)), diff --git a/crates/vm/src/frame.rs b/crates/vm/src/frame.rs index 69e4c062994..81a4c5d5683 100644 --- a/crates/vm/src/frame.rs +++ b/crates/vm/src/frame.rs @@ -1199,7 +1199,9 @@ impl ExecutingFrame<'_> { } } - if let Err(exception) = vm.check_signals() { + if vm.eval_breaker_tripped() + && let Err(exception) = vm.check_signals() + { #[cold] fn handle_signal_exception( frame: &mut ExecutingFrame<'_>, diff --git a/crates/vm/src/signal.rs b/crates/vm/src/signal.rs index 87c4fe2749f..177ba06b84d 100644 --- a/crates/vm/src/signal.rs +++ b/crates/vm/src/signal.rs @@ -91,6 +91,11 @@ pub(crate) fn set_triggered() { ANY_TRIGGERED.store(true, Ordering::Release); } +#[inline(always)] +pub(crate) fn is_triggered() -> bool { + ANY_TRIGGERED.load(Ordering::Relaxed) +} + /// Reset all signal trigger state after fork in child process. /// Stale triggers from the parent must not fire in the child. #[cfg(unix)] diff --git a/crates/vm/src/stdlib/io.rs b/crates/vm/src/stdlib/io.rs index a313b3d98df..dadde9e8e32 100644 --- a/crates/vm/src/stdlib/io.rs +++ b/crates/vm/src/stdlib/io.rs @@ -1580,7 +1580,7 @@ mod _io { fn lock(&self, vm: &VirtualMachine) -> PyResult> { self.data() - .lock() + .lock_wrapped(|do_lock| vm.allow_threads(do_lock)) .ok_or_else(|| vm.new_runtime_error("reentrant call inside buffered io")) } @@ -2812,7 +2812,7 @@ mod _io { vm: &VirtualMachine, ) -> PyResult>> { self.data - .lock() + .lock_wrapped(|do_lock| vm.allow_threads(do_lock)) .ok_or_else(|| vm.new_runtime_error("reentrant call inside textio")) } @@ -4158,7 +4158,7 @@ mod _io { vm: &VirtualMachine, ) -> PyResult>> { self.data - .lock() + .lock_wrapped(|do_lock| vm.allow_threads(do_lock)) .ok_or_else(|| vm.new_runtime_error("reentrant call inside nldecoder")) } @@ -5336,7 +5336,7 @@ mod fileio { types::{Constructor, DefaultConstructor, Destructor, Initializer, Representable}, }; use crossbeam_utils::atomic::AtomicCell; - use std::io::{Read, Write}; + use std::io::Read; bitflags::bitflags! { #[derive(Copy, Clone, Debug, PartialEq)] @@ -5740,12 +5740,12 @@ mod fileio { "File or stream is not readable".to_owned(), )); } - let mut handle = zelf.get_fd(vm)?; + let handle = zelf.get_fd(vm)?; let bytes = if let Some(read_byte) = read_byte.to_usize() { let mut bytes = vec![0; read_byte]; // Loop on EINTR (PEP 475) let n = loop { - match handle.read(&mut bytes) { + match vm.allow_threads(|| crt_fd::read(handle, &mut bytes)) { Ok(n) => break n, Err(e) if e.raw_os_error() == Some(libc::EINTR) => { vm.check_signals()?; @@ -5764,7 +5764,10 @@ mod fileio { let mut bytes = vec![]; // Loop on EINTR (PEP 475) loop { - match handle.read_to_end(&mut bytes) { + match vm.allow_threads(|| { + let mut h = handle; + h.read_to_end(&mut bytes) + }) { Ok(_) => break, Err(e) if e.raw_os_error() == Some(libc::EINTR) => { vm.check_signals()?; @@ -5802,10 +5805,9 @@ mod fileio { let handle = zelf.get_fd(vm)?; let mut buf = obj.borrow_buf_mut(); - let mut f = handle.take(buf.len() as _); // Loop on EINTR (PEP 475) let ret = loop { - match f.read(&mut buf) { + match vm.allow_threads(|| crt_fd::read(handle, &mut buf)) { Ok(n) => break n, Err(e) if e.raw_os_error() == Some(libc::EINTR) => { vm.check_signals()?; @@ -5835,11 +5837,11 @@ mod fileio { )); } - let mut handle = zelf.get_fd(vm)?; + let handle = zelf.get_fd(vm)?; // Loop on EINTR (PEP 475) let len = loop { - match obj.with_ref(|b| handle.write(b)) { + match obj.with_ref(|b| vm.allow_threads(|| crt_fd::write(handle, b))) { Ok(n) => break n, Err(e) if e.raw_os_error() == Some(libc::EINTR) => { vm.check_signals()?; diff --git a/crates/vm/src/stdlib/posix.rs b/crates/vm/src/stdlib/posix.rs index 6f2342ec3e3..5ecc72c7087 100644 --- a/crates/vm/src/stdlib/posix.rs +++ b/crates/vm/src/stdlib/posix.rs @@ -872,57 +872,129 @@ pub mod module { run_at_forkers(after_forkers_parent, false, vm); } - /// Warn if forking from a multi-threaded process - fn warn_if_multi_threaded(name: &str, vm: &VirtualMachine) { - // Only check threading if it was already imported - // Avoid vm.import() which can execute arbitrary Python code in the fork path - let threading = match vm - .sys_module - .get_attr("modules", vm) - .and_then(|m| m.get_item("threading", vm)) + /// Best-effort number of OS threads in this process. + /// Returns <= 0 when unavailable, mirroring CPython fallback behavior. + fn get_number_of_os_threads() -> isize { + #[cfg(target_os = "macos")] { - Ok(m) => m, - Err(_) => return, - }; - let active = threading.get_attr("_active", vm).ok(); - let limbo = threading.get_attr("_limbo", vm).ok(); + type MachPortT = libc::c_uint; + type KernReturnT = libc::c_int; + type MachMsgTypeNumberT = libc::c_uint; + type ThreadActArrayT = *mut MachPortT; + const KERN_SUCCESS: KernReturnT = 0; + unsafe extern "C" { + fn mach_task_self() -> MachPortT; + fn task_for_pid( + task: MachPortT, + pid: libc::c_int, + target_task: *mut MachPortT, + ) -> KernReturnT; + fn task_threads( + target_task: MachPortT, + act_list: *mut ThreadActArrayT, + act_list_cnt: *mut MachMsgTypeNumberT, + ) -> KernReturnT; + fn vm_deallocate( + target_task: MachPortT, + address: libc::uintptr_t, + size: libc::uintptr_t, + ) -> KernReturnT; + } - let count_dict = |obj: Option| -> usize { - obj.and_then(|o| o.length_opt(vm)) - .and_then(|r| r.ok()) - .unwrap_or(0) - }; + let self_task = unsafe { mach_task_self() }; + let mut proc_task: MachPortT = 0; + if unsafe { task_for_pid(self_task, libc::getpid(), &mut proc_task) } == KERN_SUCCESS { + let mut threads: ThreadActArrayT = core::ptr::null_mut(); + let mut n_threads: MachMsgTypeNumberT = 0; + if unsafe { task_threads(proc_task, &mut threads, &mut n_threads) } == KERN_SUCCESS + { + if !threads.is_null() { + let _ = unsafe { + vm_deallocate( + self_task, + threads as libc::uintptr_t, + (n_threads as usize * core::mem::size_of::()) + as libc::uintptr_t, + ) + }; + } + return n_threads as isize; + } + } + 0 + } + #[cfg(target_os = "linux")] + { + use std::io::Read as _; + let mut file = match std::fs::File::open("/proc/self/stat") { + Ok(f) => f, + Err(_) => return 0, + }; + let mut buf = [0u8; 160]; + let n = match file.read(&mut buf) { + Ok(n) => n, + Err(_) => return 0, + }; + let line = match core::str::from_utf8(&buf[..n]) { + Ok(s) => s, + Err(_) => return 0, + }; + if let Some(field) = line.split_whitespace().nth(19) { + return field.parse::().unwrap_or(0); + } + 0 + } + #[cfg(not(any(target_os = "macos", target_os = "linux")))] + { + 0 + } + } - let num_threads = count_dict(active) + count_dict(limbo); - if num_threads > 1 { - // Use Python warnings module to ensure filters are applied correctly - let Ok(warnings) = vm.import("warnings", 0) else { - return; + /// Warn if forking from a multi-threaded process. + /// `num_os_threads` should be captured before parent after-fork hooks run. + fn warn_if_multi_threaded(name: &str, num_os_threads: isize, vm: &VirtualMachine) { + let num_threads = if num_os_threads > 0 { + num_os_threads as usize + } else { + // CPython fallback: if OS-level count isn't available, use the + // threading module's active+limbo view. + // Only check threading if it was already imported. Avoid vm.import() + // which can execute arbitrary Python code in the fork path. + let threading = match vm + .sys_module + .get_attr("modules", vm) + .and_then(|m| m.get_item("threading", vm)) + { + Ok(m) => m, + Err(_) => return, }; - let Ok(warn_fn) = warnings.get_attr("warn", vm) else { - return; + let active = threading.get_attr("_active", vm).ok(); + let limbo = threading.get_attr("_limbo", vm).ok(); + + // Match threading module internals and avoid sequence overcounting: + // count only dict-backed _active/_limbo containers. + let count_dict = |obj: Option| -> usize { + obj.and_then(|o| { + o.downcast_ref::() + .map(|d| d.__len__()) + }) + .unwrap_or(0) }; + count_dict(active) + count_dict(limbo) + }; + + if num_threads > 1 { let pid = unsafe { libc::getpid() }; let msg = format!( "This process (pid={}) is multi-threaded, use of {}() may lead to deadlocks in the child.", pid, name ); - // Call warnings.warn(message, DeprecationWarning, stacklevel=2) - // stacklevel=2 to point to the caller of fork() - let args = crate::function::FuncArgs::new( - vec![ - vm.ctx.new_str(msg).into(), - vm.ctx.exceptions.deprecation_warning.as_object().to_owned(), - ], - crate::function::KwArgs::new( - [("stacklevel".to_owned(), vm.ctx.new_int(2).into())] - .into_iter() - .collect(), - ), - ); - let _ = warn_fn.call(args, vm); + // Match PyErr_WarnFormat(..., stacklevel=1) in CPython. + // Best effort: ignore failures like CPython does in this path. + let _ = + crate::stdlib::warnings::warn(vm.ctx.exceptions.deprecation_warning, msg, 1, vm); } } @@ -953,9 +1025,12 @@ pub mod module { if pid == 0 { py_os_after_fork_child(vm); } else { + // Match CPython timing: capture this before parent after-fork hooks + // in case those hooks start threads. + let num_os_threads = get_number_of_os_threads(); py_os_after_fork_parent(vm); // Match CPython timing: warn only after parent callback path resumes world. - warn_if_multi_threaded("fork", vm); + warn_if_multi_threaded("fork", num_os_threads, vm); } if pid == -1 { Err(nix::Error::from_raw(saved_errno).into_pyexception(vm)) diff --git a/crates/vm/src/stdlib/thread.rs b/crates/vm/src/stdlib/thread.rs index 38abc47138d..f6849a11696 100644 --- a/crates/vm/src/stdlib/thread.rs +++ b/crates/vm/src/stdlib/thread.rs @@ -27,6 +27,7 @@ pub(crate) mod _thread { RawMutex, RawThreadId, lock_api::{RawMutex as RawMutexT, RawMutexTimed, RawReentrantMutex}, }; + use rustpython_common::str::levenshtein::{MOVE_COST, levenshtein_distance}; use std::thread; // PYTHREAD_NAME: show current thread name @@ -296,7 +297,7 @@ pub(crate) mod _thread { if count == 0 { return Ok(()); } - self.mu.lock(); + vm.allow_threads(|| self.mu.lock()); self.count .store(count, core::sync::atomic::Ordering::Relaxed); Ok(()) @@ -377,8 +378,8 @@ pub(crate) mod _thread { vm, )?; d.set_item( - "detach_wait_yields", - vm.ctx.new_int(stats.detach_wait_yields).into(), + "world_stopped", + vm.ctx.new_bool(stats.world_stopped).into(), vm, )?; Ok(d) @@ -435,7 +436,7 @@ pub(crate) mod _thread { /// This is important for fork compatibility - the ID must remain stable after fork #[cfg(unix)] fn current_thread_id() -> u64 { - // pthread_self() like CPython for fork compatibility + // pthread_self() for fork compatibility unsafe { libc::pthread_self() as u64 } } @@ -487,12 +488,68 @@ pub(crate) mod _thread { } #[pyfunction] - fn start_new_thread( - func: ArgCallable, - args: PyTupleRef, - kwargs: OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { + fn start_new_thread(mut f_args: FuncArgs, vm: &VirtualMachine) -> PyResult { + if !f_args.kwargs.is_empty() { + return Err(vm.new_type_error("start_new_thread() takes no keyword arguments")); + } + let given = f_args.args.len(); + if given < 2 { + return Err(vm.new_type_error(format!( + "start_new_thread expected at least 2 arguments, got {given}" + ))); + } + if given > 3 { + return Err(vm.new_type_error(format!( + "start_new_thread expected at most 3 arguments, got {given}" + ))); + } + + let func_obj = f_args.take_positional().unwrap(); + let args_obj = f_args.take_positional().unwrap(); + let kwargs_obj = f_args.take_positional(); + + if func_obj.to_callable().is_none() { + return Err(vm.new_type_error("first arg must be callable")); + } + if !args_obj.fast_isinstance(vm.ctx.types.tuple_type) { + return Err(vm.new_type_error("2nd arg must be a tuple")); + } + if kwargs_obj + .as_ref() + .is_some_and(|obj| !obj.fast_isinstance(vm.ctx.types.dict_type)) + { + return Err(vm.new_type_error("optional 3rd arg must be a dictionary")); + } + + let func: ArgCallable = func_obj.clone().try_into_value(vm)?; + let args: PyTupleRef = args_obj.clone().try_into_value(vm)?; + let kwargs: Option = kwargs_obj.map(|obj| obj.try_into_value(vm)).transpose()?; + + vm.sys_module.get_attr("audit", vm)?.call( + ( + "_thread.start_new_thread", + func_obj, + args_obj, + kwargs + .as_ref() + .map_or_else(|| vm.ctx.none(), |k| k.clone().into()), + ), + vm, + )?; + + if vm + .state + .finalizing + .load(core::sync::atomic::Ordering::Acquire) + { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.python_finalization_error.to_owned(), + "can't create new thread at interpreter shutdown" + .to_owned() + .into(), + )); + } + let args = FuncArgs::new( args.to_vec(), kwargs @@ -512,7 +569,7 @@ pub(crate) mod _thread { .make_spawn_func(move |vm| run_thread(func, args, vm)), ) .map(|handle| thread_to_id(&handle)) - .map_err(|err| vm.new_runtime_error(format!("can't start new thread: {err}"))) + .map_err(|_err| vm.new_runtime_error("can't start new thread")) } fn run_thread(func: ArgCallable, args: FuncArgs, vm: &VirtualMachine) { @@ -630,14 +687,17 @@ pub(crate) mod _thread { }; match handle_to_join { - Some((_, done_event)) => { - // Wait for this thread to finish (infinite timeout) - // Only check done flag to avoid lock ordering issues - // (done_event lock vs inner lock) - let (lock, cvar) = &*done_event; - let mut done = lock.lock(); - while !*done { - vm.allow_threads(|| cvar.wait(&mut done)); + Some((inner, done_event)) => { + if let Err(exc) = ThreadHandle::join_internal(&inner, &done_event, None, vm) { + vm.run_unraisable( + exc, + Some( + "Exception ignored while joining a thread in _thread._shutdown()" + .to_owned(), + ), + vm.ctx.none(), + ); + return; } } None => break, // No more threads to wait on @@ -655,6 +715,24 @@ pub(crate) mod _thread { handles.push((Arc::downgrade(inner), Arc::downgrade(done_event))); } + fn remove_from_shutdown_handles( + vm: &VirtualMachine, + inner: &Arc>, + done_event: &Arc<(parking_lot::Mutex, parking_lot::Condvar)>, + ) { + let mut handles = vm.state.shutdown_handles.lock(); + handles.retain(|(inner_weak, done_event_weak): &ShutdownEntry| { + let Some(registered_inner) = inner_weak.upgrade() else { + return false; + }; + let Some(registered_done_event) = done_event_weak.upgrade() else { + return false; + }; + !(Arc::ptr_eq(®istered_inner, inner) + && Arc::ptr_eq(®istered_done_event, done_event)) + }); + } + #[pyfunction] fn _make_thread_handle(ident: u64, vm: &VirtualMachine) -> PyRef { let handle = ThreadHandle::new(vm); @@ -1108,7 +1186,7 @@ pub(crate) mod _thread { done_event: Arc<(parking_lot::Mutex, parking_lot::Condvar)>, } - #[pyclass] + #[pyclass(with(Representable))] impl ThreadHandle { fn new(vm: &VirtualMachine) -> Self { let inner = Arc::new(parking_lot::Mutex::new(ThreadHandleInner { @@ -1130,55 +1208,55 @@ pub(crate) mod _thread { Self { inner, done_event } } - #[pygetset] - fn ident(&self) -> u64 { - self.inner.lock().ident - } - - #[pymethod] - fn is_done(&self) -> bool { - self.inner.lock().state == ThreadHandleState::Done - } - - #[pymethod] - fn _set_done(&self) { - self.inner.lock().state = ThreadHandleState::Done; - // Signal waiting threads that this thread is done - let (lock, cvar) = &*self.done_event; - *lock.lock() = true; - cvar.notify_all(); - } - - #[pymethod] - fn join( - &self, - timeout: OptionalArg>>, + fn join_internal( + inner: &Arc>, + done_event: &Arc<(parking_lot::Mutex, parking_lot::Condvar)>, + timeout_duration: Option, vm: &VirtualMachine, ) -> PyResult<()> { - // Convert timeout to Duration (None or negative = infinite wait) - let timeout_duration = match timeout.flatten() { - Some(Either::A(t)) if t >= 0.0 => Some(Duration::from_secs_f64(t)), - Some(Either::B(t)) if t >= 0 => Some(Duration::from_secs(t as u64)), - _ => None, - }; + Self::check_started(inner, vm)?; - // Check for self-join first - { - let inner = self.inner.lock(); - let current_ident = get_ident(); - if inner.ident == current_ident && inner.state == ThreadHandleState::Running { - return Err(vm.new_runtime_error("cannot join current thread")); - } - } + let deadline = + timeout_duration.and_then(|timeout| std::time::Instant::now().checked_add(timeout)); // Wait for thread completion using Condvar (supports timeout) // Loop to handle spurious wakeups - let (lock, cvar) = &*self.done_event; + let (lock, cvar) = &**done_event; let mut done = lock.lock(); + // ThreadHandle_join semantics: self-join/finalizing checks + // apply only while target thread has not reported it is exiting yet. + if !*done { + let inner_guard = inner.lock(); + let current_ident = get_ident(); + if inner_guard.ident == current_ident + && inner_guard.state == ThreadHandleState::Running + { + return Err(vm.new_runtime_error("Cannot join current thread")); + } + if vm + .state + .finalizing + .load(core::sync::atomic::Ordering::Acquire) + { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.python_finalization_error.to_owned(), + "cannot join thread at interpreter shutdown" + .to_owned() + .into(), + )); + } + } + while !*done { if let Some(timeout) = timeout_duration { - let result = vm.allow_threads(|| cvar.wait_for(&mut done, timeout)); + let remaining = deadline.map_or(timeout, |deadline| { + deadline.saturating_duration_since(std::time::Instant::now()) + }); + if remaining.is_zero() { + return Ok(()); + } + let result = vm.allow_threads(|| cvar.wait_for(&mut done, remaining)); if result.timed_out() && !*done { // Timeout occurred and done is still false return Ok(()); @@ -1192,18 +1270,18 @@ pub(crate) mod _thread { // Thread is done, now perform cleanup let join_handle = { - let mut inner = self.inner.lock(); + let mut inner_guard = inner.lock(); // If already joined, return immediately (idempotent) - if inner.joined { + if inner_guard.joined { return Ok(()); } // If another thread is already joining, wait for them to finish - if inner.joining { - drop(inner); + if inner_guard.joining { + drop(inner_guard); // Wait on done_event - let (lock, cvar) = &*self.done_event; + let (lock, cvar) = &**done_event; let mut done = lock.lock(); while !*done { vm.allow_threads(|| cvar.wait(&mut done)); @@ -1212,10 +1290,10 @@ pub(crate) mod _thread { } // Mark that we're joining - inner.joining = true; + inner_guard.joining = true; // Take the join handle if available - inner.join_handle.take() + inner_guard.join_handle.take() }; // Perform the actual join outside the lock @@ -1226,14 +1304,158 @@ pub(crate) mod _thread { // Mark as joined and clear joining flag { - let mut inner = self.inner.lock(); - inner.joined = true; - inner.joining = false; + let mut inner_guard = inner.lock(); + inner_guard.joined = true; + inner_guard.joining = false; } Ok(()) } + fn check_started( + inner: &Arc>, + vm: &VirtualMachine, + ) -> PyResult<()> { + let state = inner.lock().state; + if matches!( + state, + ThreadHandleState::NotStarted | ThreadHandleState::Starting + ) { + return Err(vm.new_runtime_error("thread not started")); + } + Ok(()) + } + + fn set_done_internal( + inner: &Arc>, + done_event: &Arc<(parking_lot::Mutex, parking_lot::Condvar)>, + vm: &VirtualMachine, + ) -> PyResult<()> { + Self::check_started(inner, vm)?; + { + let mut inner_guard = inner.lock(); + inner_guard.state = ThreadHandleState::Done; + // _set_done() detach path. Dropping the JoinHandle + // detaches the underlying Rust thread. + inner_guard.join_handle = None; + inner_guard.joining = false; + inner_guard.joined = true; + } + remove_from_shutdown_handles(vm, inner, done_event); + + let (lock, cvar) = &**done_event; + *lock.lock() = true; + cvar.notify_all(); + Ok(()) + } + + fn parse_join_timeout( + timeout_obj: Option, + vm: &VirtualMachine, + ) -> PyResult> { + const JOIN_TIMEOUT_MAX_SECONDS: i64 = TIMEOUT_MAX_IN_MICROSECONDS / 1_000_000; + let Some(timeout_obj) = timeout_obj else { + return Ok(None); + }; + + if let Some(t) = timeout_obj.try_index_opt(vm) { + let t: i64 = t?.try_to_primitive(vm).map_err(|_| { + vm.new_overflow_error("timestamp too large to convert to C PyTime_t") + })?; + if !(-JOIN_TIMEOUT_MAX_SECONDS..=JOIN_TIMEOUT_MAX_SECONDS).contains(&t) { + return Err( + vm.new_overflow_error("timestamp too large to convert to C PyTime_t") + ); + } + if t < 0 { + return Ok(None); + } + return Ok(Some(Duration::from_secs(t as u64))); + } + + if let Some(t) = timeout_obj.try_float_opt(vm) { + let t = t?.to_f64(); + if t.is_nan() { + return Err(vm.new_value_error("Invalid value NaN (not a number)")); + } + if !t.is_finite() || !(-TIMEOUT_MAX..=TIMEOUT_MAX).contains(&t) { + return Err(vm.new_overflow_error("timestamp out of range for platform time_t")); + } + if t < 0.0 { + return Ok(None); + } + return Ok(Some(Duration::from_secs_f64(t))); + } + + Err(vm.new_type_error(format!( + "'{}' object cannot be interpreted as an integer or float", + timeout_obj.class().name() + ))) + } + + #[pygetset] + fn ident(&self) -> u64 { + self.inner.lock().ident + } + + #[pymethod] + fn is_done(&self, f_args: FuncArgs, vm: &VirtualMachine) -> PyResult { + if !f_args.kwargs.is_empty() { + return Err(vm.new_type_error("_ThreadHandle.is_done() takes no keyword arguments")); + } + let given = f_args.args.len(); + if given != 0 { + return Err(vm.new_type_error(format!( + "_ThreadHandle.is_done() takes no arguments ({given} given)" + ))); + } + + // If completion was observed, perform one-time join cleanup + // before returning True. + let done = { + let (lock, _) = &*self.done_event; + *lock.lock() + }; + if !done { + return Ok(false); + } + Self::join_internal(&self.inner, &self.done_event, Some(Duration::ZERO), vm)?; + Ok(true) + } + + #[pymethod] + fn _set_done(&self, f_args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + if !f_args.kwargs.is_empty() { + return Err( + vm.new_type_error("_ThreadHandle._set_done() takes no keyword arguments") + ); + } + let given = f_args.args.len(); + if given != 0 { + return Err(vm.new_type_error(format!( + "_ThreadHandle._set_done() takes no arguments ({given} given)" + ))); + } + + Self::set_done_internal(&self.inner, &self.done_event, vm) + } + + #[pymethod] + fn join(&self, mut f_args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + if !f_args.kwargs.is_empty() { + return Err(vm.new_type_error("_ThreadHandle.join() takes no keyword arguments")); + } + let given = f_args.args.len(); + if given > 1 { + return Err( + vm.new_type_error(format!("join() takes at most 1 argument ({given} given)")) + ); + } + let timeout = f_args.take_positional().filter(|obj| !vm.is_none(obj)); + let timeout_duration = Self::parse_join_timeout(timeout, vm)?; + Self::join_internal(&self.inner, &self.done_event, timeout_duration, vm) + } + #[pyslot] fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { ThreadHandle::new(vm) @@ -1242,38 +1464,174 @@ pub(crate) mod _thread { } } - #[derive(FromArgs)] - struct StartJoinableThreadArgs { - #[pyarg(positional)] - function: ArgCallable, - #[pyarg(any, optional)] - handle: OptionalArg>, - #[pyarg(any, default = true)] - daemon: bool, + impl Representable for ThreadHandle { + fn repr_str(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + let ident = zelf.inner.lock().ident; + Ok(format!( + "<{} object: ident={ident}>", + zelf.class().slot_name() + )) + } } #[pyfunction] fn start_joinable_thread( - args: StartJoinableThreadArgs, + mut f_args: FuncArgs, vm: &VirtualMachine, ) -> PyResult> { - let handle = match args.handle { - OptionalArg::Present(h) => h, - OptionalArg::Missing => ThreadHandle::new(vm).into_ref(&vm.ctx), + let given = f_args.args.len() + f_args.kwargs.len(); + if given > 3 { + return Err(vm.new_type_error(format!( + "start_joinable_thread() takes at most 3 arguments ({given} given)" + ))); + } + + let function_pos = f_args.take_positional(); + let function_kw = f_args.take_keyword("function"); + if function_pos.is_some() && function_kw.is_some() { + return Err(vm.new_type_error( + "argument for start_joinable_thread() given by name ('function') and position (1)", + )); + } + let Some(function_obj) = function_pos.or(function_kw) else { + return Err(vm.new_type_error( + "start_joinable_thread() missing required argument 'function' (pos 1)", + )); + }; + + let handle_pos = f_args.take_positional(); + let handle_kw = f_args.take_keyword("handle"); + if handle_pos.is_some() && handle_kw.is_some() { + return Err(vm.new_type_error( + "argument for start_joinable_thread() given by name ('handle') and position (2)", + )); + } + let handle_obj = handle_pos.or(handle_kw); + + let daemon_pos = f_args.take_positional(); + let daemon_kw = f_args.take_keyword("daemon"); + if daemon_pos.is_some() && daemon_kw.is_some() { + return Err(vm.new_type_error( + "argument for start_joinable_thread() given by name ('daemon') and position (3)", + )); + } + let daemon = daemon_pos + .or(daemon_kw) + .map_or(Ok(true), |obj| obj.try_to_bool(vm))?; + + // Match CPython parser precedence: + // - required positional/keyword argument errors are raised before + // unknown keyword errors when `function` is missing. + if let Some(unexpected) = f_args.kwargs.keys().next() { + let suggestion = ["function", "handle", "daemon"] + .iter() + .filter_map(|candidate| { + let max_distance = (unexpected.len() + candidate.len() + 3) * MOVE_COST / 6; + let distance = levenshtein_distance( + unexpected.as_bytes(), + candidate.as_bytes(), + max_distance, + ); + (distance <= max_distance).then_some((distance, *candidate)) + }) + .min_by_key(|(distance, _)| *distance) + .map(|(_, candidate)| candidate); + let msg = if let Some(suggestion) = suggestion { + format!( + "start_joinable_thread() got an unexpected keyword argument '{unexpected}'. Did you mean '{suggestion}'?" + ) + } else { + format!("start_joinable_thread() got an unexpected keyword argument '{unexpected}'") + }; + return Err(vm.new_type_error(msg)); + } + + if function_obj.to_callable().is_none() { + return Err(vm.new_type_error("thread function must be callable")); + } + let function: ArgCallable = function_obj.clone().try_into_value(vm)?; + + let thread_handle_type = ThreadHandle::class(&vm.ctx); + let handle = if let Some(handle_obj) = handle_obj { + if vm.is_none(&handle_obj) { + None + } else if !handle_obj.class().is(thread_handle_type) { + return Err(vm.new_type_error("'handle' must be a _ThreadHandle")); + } else { + Some( + handle_obj + .downcast::() + .map_err(|_| vm.new_type_error("'handle' must be a _ThreadHandle"))?, + ) + } + } else { + None }; - // Mark as starting - handle.inner.lock().state = ThreadHandleState::Starting; + vm.sys_module.get_attr("audit", vm)?.call( + ( + "_thread.start_joinable_thread", + function_obj, + daemon, + handle + .as_ref() + .map_or_else(|| vm.ctx.none(), |h| h.clone().into()), + ), + vm, + )?; + + if vm + .state + .finalizing + .load(core::sync::atomic::Ordering::Acquire) + { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.python_finalization_error.to_owned(), + "can't create new thread at interpreter shutdown" + .to_owned() + .into(), + )); + } + + let handle = match handle { + Some(h) => h, + None => ThreadHandle::new(vm).into_ref(&vm.ctx), + }; + + // Must only start once (ThreadHandle_start). + { + let mut inner = handle.inner.lock(); + if inner.state != ThreadHandleState::NotStarted { + return Err(vm.new_runtime_error("thread already started")); + } + inner.state = ThreadHandleState::Starting; + inner.ident = 0; + inner.join_handle = None; + inner.joining = false; + inner.joined = false; + } + // Starting a handle always resets the completion event. + { + let (done_lock, _) = &*handle.done_event; + *done_lock.lock() = false; + } // Add non-daemon threads to shutdown registry so _shutdown() will wait for them - if !args.daemon { + if !daemon { add_to_shutdown_handles(vm, &handle.inner, &handle.done_event); } - let func = args.function; + let func = function; let handle_clone = handle.clone(); let inner_clone = handle.inner.clone(); let done_event_clone = handle.done_event.clone(); + // Use std::sync (pthread-based) instead of parking_lot for these + // events so they remain fork-safe without the parking_lot_core patch. + let started_event = Arc::new((std::sync::Mutex::new(false), std::sync::Condvar::new())); + let started_event_clone = Arc::clone(&started_event); + let handle_ready_event = + Arc::new((std::sync::Mutex::new(false), std::sync::Condvar::new())); + let handle_ready_event_clone = Arc::clone(&handle_ready_event); let mut thread_builder = thread::Builder::new(); let stacksize = vm.state.stacksize.load(); @@ -1283,11 +1641,27 @@ pub(crate) mod _thread { let join_handle = thread_builder .spawn(vm.new_thread().make_spawn_func(move |vm| { - // Set ident and mark as running + // Publish ident for the parent starter thread. + { + inner_clone.lock().ident = get_ident(); + } { - let mut inner = inner_clone.lock(); - inner.ident = get_ident(); - inner.state = ThreadHandleState::Running; + let (started_lock, started_cvar) = &*started_event_clone; + *started_lock.lock().unwrap() = true; + started_cvar.notify_all(); + } + // Don't execute the target function until parent marks the + // handle as running. + { + let (ready_lock, ready_cvar) = &*handle_ready_event_clone; + let mut ready = ready_lock.lock().unwrap(); + while !*ready { + // Short timeout so we stay responsive to STW requests. + let (guard, _) = ready_cvar + .wait_timeout(ready, core::time::Duration::from_millis(1)) + .unwrap(); + ready = guard; + } } // Ensure cleanup happens even if the function panics @@ -1313,6 +1687,9 @@ pub(crate) mod _thread { vm_state.thread_count.fetch_sub(1); + // The runtime no longer needs to wait for this thread. + remove_from_shutdown_handles(vm, &inner_for_cleanup, &done_event_for_cleanup); + // Signal waiting threads that this thread is done // This must be LAST to ensure all cleanup is complete before join() returns { @@ -1338,10 +1715,52 @@ pub(crate) mod _thread { } } })) - .map_err(|err| vm.new_runtime_error(format!("can't start new thread: {err}")))?; + .map_err(|_err| { + // force_done + remove_from_shutdown_handles on start failure. + { + let mut inner = handle.inner.lock(); + inner.state = ThreadHandleState::Done; + inner.join_handle = None; + inner.joining = false; + inner.joined = true; + } + { + let (done_lock, done_cvar) = &*handle.done_event; + *done_lock.lock() = true; + done_cvar.notify_all(); + } + if !daemon { + remove_from_shutdown_handles(vm, &handle.inner, &handle.done_event); + } + vm.new_runtime_error("can't start new thread") + })?; - // Store the join handle - handle.inner.lock().join_handle = Some(join_handle); + // Wait until the new thread has reported its ident. + { + let (started_lock, started_cvar) = &*started_event; + let mut started = started_lock.lock().unwrap(); + while !*started { + let (guard, _) = started_cvar + .wait_timeout(started, core::time::Duration::from_millis(1)) + .unwrap(); + started = guard; + } + } + + // Mark the handle running in the parent thread (like CPython's + // ThreadHandle_start sets THREAD_HANDLE_RUNNING after spawn succeeds). + { + let mut inner = handle.inner.lock(); + inner.join_handle = Some(join_handle); + inner.state = ThreadHandleState::Running; + } + + // Unblock the started thread once handle state is fully published. + { + let (ready_lock, ready_cvar) = &*handle_ready_event; + *ready_lock.lock().unwrap() = true; + ready_cvar.notify_all(); + } Ok(handle_clone) } diff --git a/crates/vm/src/vm/mod.rs b/crates/vm/src/vm/mod.rs index 502d892d895..6040b0b6f39 100644 --- a/crates/vm/src/vm/mod.rs +++ b/crates/vm/src/vm/mod.rs @@ -40,6 +40,8 @@ use crate::{ warn::WarningsState, }; use alloc::{borrow::Cow, collections::BTreeMap}; +#[cfg(all(unix, feature = "threading"))] +use core::sync::atomic::AtomicI64; use core::{ cell::{Cell, OnceCell, RefCell}, ptr::NonNull, @@ -131,11 +133,15 @@ struct ExceptionStack { pub struct StopTheWorldState { /// Fast-path flag checked in the bytecode loop (like `_PY_EVAL_PLEASE_STOP_BIT`) pub(crate) requested: AtomicBool, + /// Whether the world is currently stopped (`stw->world_stopped`). + world_stopped: AtomicBool, /// Ident of the thread that requested the stop (like `stw->requester`) requester: AtomicU64, /// Signaled by suspending threads when their state transitions to SUSPENDED notify_mutex: std::sync::Mutex<()>, notify_cv: std::sync::Condvar, + /// Number of non-requester threads still expected to park for current stop request. + thread_countdown: AtomicI64, /// Number of stop-the-world attempts. stats_stop_calls: AtomicU64, /// Most recent stop-the-world wait duration in ns. @@ -156,8 +162,6 @@ pub struct StopTheWorldState { stats_attach_wait_yields: AtomicU64, /// Number of yield loops while suspend waited on SUSPENDED->DETACHED. stats_suspend_wait_yields: AtomicU64, - /// Number of yield loops while detach waited on SUSPENDED->DETACHED. - stats_detach_wait_yields: AtomicU64, } #[cfg(all(unix, feature = "threading"))] @@ -173,7 +177,7 @@ pub struct StopTheWorldStats { pub suspend_notifications: u64, pub attach_wait_yields: u64, pub suspend_wait_yields: u64, - pub detach_wait_yields: u64, + pub world_stopped: bool, } #[cfg(all(unix, feature = "threading"))] @@ -188,9 +192,11 @@ impl StopTheWorldState { pub const fn new() -> Self { Self { requested: AtomicBool::new(false), + world_stopped: AtomicBool::new(false), requester: AtomicU64::new(0), notify_mutex: std::sync::Mutex::new(()), notify_cv: std::sync::Condvar::new(), + thread_countdown: AtomicI64::new(0), stats_stop_calls: AtomicU64::new(0), stats_last_wait_ns: AtomicU64::new(0), stats_total_wait_ns: AtomicU64::new(0), @@ -201,7 +207,6 @@ impl StopTheWorldState { stats_suspend_notifications: AtomicU64::new(0), stats_attach_wait_yields: AtomicU64::new(0), stats_suspend_wait_yields: AtomicU64::new(0), - stats_detach_wait_yields: AtomicU64::new(0), } } @@ -209,18 +214,49 @@ impl StopTheWorldState { pub(crate) fn notify_suspended(&self) { self.stats_suspend_notifications .fetch_add(1, Ordering::Relaxed); - // Just signal the condvar; the requester holds the mutex. + // Synchronize with requester wait loop to avoid lost wakeups. + let _guard = self.notify_mutex.lock().unwrap(); + self.decrement_thread_countdown(1); self.notify_cv.notify_one(); } + #[inline] + fn init_thread_countdown(&self, vm: &VirtualMachine) -> i64 { + let requester = self.requester.load(Ordering::Relaxed); + let registry = vm.state.thread_frames.lock(); + // Keep requested/count initialization serialized with thread-slot + // registration (which also takes this lock), matching the + // HEAD_LOCK-guarded stop-the-world bookkeeping. + self.requested.store(true, Ordering::Release); + let count = registry + .keys() + .filter(|&&thread_id| thread_id != requester) + .count(); + let count = (count.min(i64::MAX as usize)) as i64; + self.thread_countdown.store(count, Ordering::Release); + count + } + + #[inline] + fn decrement_thread_countdown(&self, n: u64) { + if n == 0 { + return; + } + let n = (n.min(i64::MAX as u64)) as i64; + let prev = self.thread_countdown.fetch_sub(n, Ordering::AcqRel); + if prev <= n { + // Clamp at 0 for safety in case of duplicate notifications. + self.thread_countdown.store(0, Ordering::Release); + } + } + /// Try to CAS detached threads directly to SUSPENDED and check whether - /// all non-requester threads are now SUSPENDED. - /// Like CPython's `park_detached_threads`. + /// stop countdown reached zero after parking detached threads + /// (`park_detached_threads`), matching CPython behavior class. fn park_detached_threads(&self, vm: &VirtualMachine) -> bool { use thread::{THREAD_ATTACHED, THREAD_DETACHED, THREAD_SUSPENDED}; let requester = self.requester.load(Ordering::Relaxed); let registry = vm.state.thread_frames.lock(); - let mut all_suspended = true; let mut attached_seen = 0u64; let mut forced_parks = 0u64; for (&id, slot) in registry.iter() { @@ -230,17 +266,40 @@ impl StopTheWorldState { let state = slot.state.load(Ordering::Relaxed); if state == THREAD_DETACHED { // CAS DETACHED → SUSPENDED (park without thread cooperation) - let _ = slot.state.compare_exchange( + match slot.state.compare_exchange( THREAD_DETACHED, THREAD_SUSPENDED, Ordering::AcqRel, Ordering::Relaxed, - ); - all_suspended = false; // re-check on next poll - forced_parks = forced_parks.saturating_add(1); + ) { + Ok(_) => { + slot.stop_requested.store(false, Ordering::Release); + forced_parks = forced_parks.saturating_add(1); + } + Err(THREAD_ATTACHED) => { + // Set per-thread stop bit (_PY_EVAL_PLEASE_STOP_BIT). + slot.stop_requested.store(true, Ordering::Release); + // Raced with a thread re-attaching; it will self-suspend. + attached_seen = attached_seen.saturating_add(1); + } + Err(THREAD_DETACHED) => { + // Extremely unlikely race; next poll will handle it. + } + Err(THREAD_SUSPENDED) => { + slot.stop_requested.store(false, Ordering::Release); + // Another path parked it first. + } + Err(other) => { + debug_assert!( + false, + "unexpected thread state in park_detached_threads: {other}" + ); + } + } } else if state == THREAD_ATTACHED { + // Set per-thread stop bit (_PY_EVAL_PLEASE_STOP_BIT). + slot.stop_requested.store(true, Ordering::Release); // Thread is in bytecode — it will see `requested` and self-suspend - all_suspended = false; attached_seen = attached_seen.saturating_add(1); } // THREAD_SUSPENDED → already parked @@ -250,13 +309,14 @@ impl StopTheWorldState { .fetch_add(attached_seen, Ordering::Relaxed); } if forced_parks != 0 { + self.decrement_thread_countdown(forced_parks); self.stats_forced_parks .fetch_add(forced_parks, Ordering::Relaxed); } - all_suspended + forced_parks != 0 && self.thread_countdown.load(Ordering::Acquire) == 0 } - /// Stop all non-requester threads. Like CPython's `stop_the_world`. + /// Stop all non-requester threads (`stop_the_world`). /// /// 1. Sets `requested`, marking the requester thread. /// 2. CAS detached threads to SUSPENDED. @@ -266,9 +326,18 @@ impl StopTheWorldState { let start = std::time::Instant::now(); let requester_ident = crate::stdlib::thread::get_ident(); self.requester.store(requester_ident, Ordering::Relaxed); - self.requested.store(true, Ordering::Release); self.stats_stop_calls.fetch_add(1, Ordering::Relaxed); + let initial_countdown = self.init_thread_countdown(vm); stw_trace(format_args!("stop begin requester={requester_ident}")); + if initial_countdown == 0 { + self.world_stopped.store(true, Ordering::Release); + #[cfg(debug_assertions)] + self.debug_assert_all_non_requester_suspended(vm); + stw_trace(format_args!( + "stop end requester={requester_ident} wait_ns=0 polls=0" + )); + return; + } let mut polls = 0u64; loop { @@ -276,8 +345,15 @@ impl StopTheWorldState { break; } polls = polls.saturating_add(1); - // Wait up to 1 ms for a thread to notify us it suspended + // Wait up to 1 ms for a thread to notify us it suspended. + // Re-check under the wait mutex first to avoid a lost-wake race: + // a thread may have suspended and notified right before we enter wait. let guard = self.notify_mutex.lock().unwrap(); + if self.thread_countdown.load(Ordering::Acquire) == 0 || self.park_detached_threads(vm) + { + drop(guard); + break; + } let _ = self .notify_cv .wait_timeout(guard, core::time::Duration::from_millis(1)); @@ -301,6 +377,7 @@ impl StopTheWorldState { Err(observed) => prev_max = observed, } } + self.world_stopped.store(true, Ordering::Release); #[cfg(debug_assertions)] self.debug_assert_all_non_requester_suspended(vm); stw_trace(format_args!( @@ -308,26 +385,36 @@ impl StopTheWorldState { )); } - /// Resume all suspended threads. Like CPython's `start_the_world`. + /// Resume all suspended threads (`start_the_world`). pub fn start_the_world(&self, vm: &VirtualMachine) { use thread::{THREAD_DETACHED, THREAD_SUSPENDED}; let requester = self.requester.load(Ordering::Relaxed); stw_trace(format_args!("start begin requester={requester}")); + let registry = vm.state.thread_frames.lock(); // Clear the request flag BEFORE waking threads. Otherwise a thread // returning from allow_threads → attach_thread could observe // `requested == true`, re-suspend itself, and stay parked forever. + // Keep this write under the registry lock to serialize with new + // thread-slot initialization. self.requested.store(false, Ordering::Release); - let registry = vm.state.thread_frames.lock(); + self.world_stopped.store(false, Ordering::Release); for (&id, slot) in registry.iter() { if id == requester { continue; } - if slot.state.load(Ordering::Relaxed) == THREAD_SUSPENDED { + slot.stop_requested.store(false, Ordering::Release); + let state = slot.state.load(Ordering::Relaxed); + debug_assert!( + state == THREAD_SUSPENDED, + "non-requester thread not suspended at start-the-world: id={id} state={state}" + ); + if state == THREAD_SUSPENDED { slot.state.store(THREAD_DETACHED, Ordering::Release); slot.thread.unpark(); } } drop(registry); + self.thread_countdown.store(0, Ordering::Release); self.requester.store(0, Ordering::Relaxed); #[cfg(debug_assertions)] self.debug_assert_all_non_requester_detached(vm); @@ -337,10 +424,24 @@ impl StopTheWorldState { /// Reset after fork in the child (only one thread alive). pub fn reset_after_fork(&self) { self.requested.store(false, Ordering::Relaxed); + self.world_stopped.store(false, Ordering::Relaxed); self.requester.store(0, Ordering::Relaxed); + self.thread_countdown.store(0, Ordering::Relaxed); stw_trace(format_args!("reset-after-fork")); } + #[inline] + pub(crate) fn requester_ident(&self) -> u64 { + self.requester.load(Ordering::Relaxed) + } + + #[inline] + pub(crate) fn notify_thread_gone(&self) { + let _guard = self.notify_mutex.lock().unwrap(); + self.decrement_thread_countdown(1); + self.notify_cv.notify_one(); + } + pub fn stats_snapshot(&self) -> StopTheWorldStats { StopTheWorldStats { stop_calls: self.stats_stop_calls.load(Ordering::Relaxed), @@ -353,7 +454,7 @@ impl StopTheWorldState { suspend_notifications: self.stats_suspend_notifications.load(Ordering::Relaxed), attach_wait_yields: self.stats_attach_wait_yields.load(Ordering::Relaxed), suspend_wait_yields: self.stats_suspend_wait_yields.load(Ordering::Relaxed), - detach_wait_yields: self.stats_detach_wait_yields.load(Ordering::Relaxed), + world_stopped: self.world_stopped.load(Ordering::Relaxed), } } @@ -368,7 +469,6 @@ impl StopTheWorldState { self.stats_suspend_notifications.store(0, Ordering::Relaxed); self.stats_attach_wait_yields.store(0, Ordering::Relaxed); self.stats_suspend_wait_yields.store(0, Ordering::Relaxed); - self.stats_detach_wait_yields.store(0, Ordering::Relaxed); } #[inline] @@ -387,17 +487,9 @@ impl StopTheWorldState { } } - #[inline] - pub(crate) fn add_detach_wait_yields(&self, n: u64) { - if n != 0 { - self.stats_detach_wait_yields - .fetch_add(n, Ordering::Relaxed); - } - } - #[cfg(debug_assertions)] fn debug_assert_all_non_requester_suspended(&self, vm: &VirtualMachine) { - use thread::THREAD_ATTACHED; + use thread::THREAD_SUSPENDED; let requester = self.requester.load(Ordering::Relaxed); let registry = vm.state.thread_frames.lock(); for (&id, slot) in registry.iter() { @@ -406,8 +498,8 @@ impl StopTheWorldState { } let state = slot.state.load(Ordering::Relaxed); debug_assert!( - state != THREAD_ATTACHED, - "non-requester thread still attached during stop-the-world: id={id} state={state}" + state == THREAD_SUSPENDED, + "non-requester thread not suspended during stop-the-world: id={id} state={state}" ); } } @@ -1837,6 +1929,26 @@ impl VirtualMachine { self.get_method(obj, method_name) } + #[inline] + pub(crate) fn eval_breaker_tripped(&self) -> bool { + #[cfg(feature = "threading")] + if self.state.finalizing.load(Ordering::Relaxed) && !self.is_main_thread() { + return true; + } + + #[cfg(all(unix, feature = "threading"))] + if thread::stop_requested_for_current_thread() { + return true; + } + + #[cfg(not(target_arch = "wasm32"))] + if crate::signal::is_triggered() { + return true; + } + + false + } + #[inline] /// Checks for triggered signals and calls the appropriate handlers. A no-op on /// platforms where signals are not supported. diff --git a/crates/vm/src/vm/thread.rs b/crates/vm/src/vm/thread.rs index 10297a964fa..80529699738 100644 --- a/crates/vm/src/vm/thread.rs +++ b/crates/vm/src/vm/thread.rs @@ -36,6 +36,9 @@ pub struct ThreadSlot { /// Thread state for stop-the-world: DETACHED / ATTACHED / SUSPENDED #[cfg(unix)] pub state: core::sync::atomic::AtomicI32, + /// Per-thread stop request bit (eval breaker equivalent). + #[cfg(unix)] + pub stop_requested: core::sync::atomic::AtomicBool, /// Handle for waking this thread from park in stop-the-world paths. #[cfg(unix)] pub thread: std::thread::Thread, @@ -93,7 +96,7 @@ pub fn enter_vm(vm: &VirtualMachine, f: impl FnOnce() -> R) -> R { // Outermost exit: transition ATTACHED → DETACHED #[cfg(all(unix, feature = "threading"))] if vms.borrow().len() == 1 { - detach_thread(vm); + detach_thread(); } vms.borrow_mut().pop(); } @@ -123,6 +126,8 @@ fn init_thread_slot_if_needed(vm: &VirtualMachine) { }, ), #[cfg(unix)] + stop_requested: core::sync::atomic::AtomicBool::new(false), + #[cfg(unix)] thread: std::thread::current(), }); registry.insert(thread_id, new_slot.clone()); @@ -139,7 +144,7 @@ fn wait_while_suspended(slot: &ThreadSlot) -> u64 { let mut wait_yields = 0u64; while slot.state.load(Ordering::Acquire) == THREAD_SUSPENDED { wait_yields = wait_yields.saturating_add(1); - std::thread::park_timeout(core::time::Duration::from_micros(50)); + std::thread::park(); } wait_yields } @@ -150,29 +155,6 @@ fn attach_thread(vm: &VirtualMachine) { if let Some(s) = slot.borrow().as_ref() { super::stw_trace(format_args!("attach begin")); loop { - if vm.state.stop_the_world.requested.load(Ordering::Acquire) { - match s.state.compare_exchange( - THREAD_DETACHED, - THREAD_SUSPENDED, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(_) => { - super::stw_trace(format_args!("attach requested DETACHED->SUSPENDED")); - vm.state.stop_the_world.notify_suspended(); - let wait_yields = wait_while_suspended(s); - vm.state.stop_the_world.add_attach_wait_yields(wait_yields); - super::stw_trace(format_args!("attach requested resumed-detached")); - continue; - } - Err(THREAD_SUSPENDED) => { - let wait_yields = wait_while_suspended(s); - vm.state.stop_the_world.add_attach_wait_yields(wait_yields); - continue; - } - Err(_) => {} - } - } match s.state.compare_exchange( THREAD_DETACHED, THREAD_ATTACHED, @@ -202,7 +184,7 @@ fn attach_thread(vm: &VirtualMachine) { /// Transition ATTACHED → DETACHED (like `_PyThreadState_Detach`). #[cfg(all(unix, feature = "threading"))] -fn detach_thread(vm: &VirtualMachine) { +fn detach_thread() { CURRENT_THREAD_SLOT.with(|slot| { if let Some(s) = slot.borrow().as_ref() { match s.state.compare_exchange( @@ -222,25 +204,6 @@ fn detach_thread(vm: &VirtualMachine) { } } super::stw_trace(format_args!("detach ATTACHED->DETACHED")); - - if vm.state.stop_the_world.requested.load(Ordering::Acquire) { - match s.state.compare_exchange( - THREAD_DETACHED, - THREAD_SUSPENDED, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(_) => { - super::stw_trace(format_args!("detach requested DETACHED->SUSPENDED")); - vm.state.stop_the_world.notify_suspended(); - } - Err(THREAD_SUSPENDED) => {} - Err(_) => return, - } - let wait_yields = wait_while_suspended(s); - vm.state.stop_the_world.add_detach_wait_yields(wait_yields); - super::stw_trace(format_args!("detach requested resumed-detached")); - } } }); } @@ -249,10 +212,10 @@ fn detach_thread(vm: &VirtualMachine) { /// running `f`, then re-attach afterwards. This allows `stop_the_world` /// to park this thread during blocking operations. /// -/// Equivalent to CPython's `Py_BEGIN_ALLOW_THREADS` / `Py_END_ALLOW_THREADS`. +/// `Py_BEGIN_ALLOW_THREADS` / `Py_END_ALLOW_THREADS` equivalent. #[cfg(all(unix, feature = "threading"))] pub fn allow_threads(vm: &VirtualMachine, f: impl FnOnce() -> R) -> R { - // Preserve CPython-like save/restore semantics: + // Preserve save/restore semantics: // only detach if this call observed ATTACHED at entry, and always restore // on unwind. let should_transition = CURRENT_THREAD_SLOT.with(|slot| { @@ -264,7 +227,7 @@ pub fn allow_threads(vm: &VirtualMachine, f: impl FnOnce() -> R) -> R { return f(); } - detach_thread(vm); + detach_thread(); let reattach_guard = scopeguard::guard(vm, attach_thread); let result = f(); drop(reattach_guard); @@ -282,9 +245,24 @@ pub fn allow_threads(_vm: &VirtualMachine, f: impl FnOnce() -> R) -> R { /// (like `_PyThreadState_Suspend` + `_PyThreadState_Attach`). #[cfg(all(unix, feature = "threading"))] pub fn suspend_if_needed(stw: &super::StopTheWorldState) { - if !stw.requested.load(Ordering::Relaxed) { + let should_suspend = CURRENT_THREAD_SLOT.with(|slot| { + slot.borrow() + .as_ref() + .is_some_and(|s| s.stop_requested.load(Ordering::Relaxed)) + }); + if !should_suspend { + return; + } + + if !stw.requested.load(Ordering::Acquire) { + CURRENT_THREAD_SLOT.with(|slot| { + if let Some(s) = slot.borrow().as_ref() { + s.stop_requested.store(false, Ordering::Release); + } + }); return; } + do_suspend(stw); } @@ -300,7 +278,10 @@ fn do_suspend(stw: &super::StopTheWorldState) { Ordering::AcqRel, Ordering::Acquire, ) { - Ok(_) => {} + Ok(_) => { + // Consumed this thread's stop request bit. + s.stop_requested.store(false, Ordering::Release); + } Err(THREAD_DETACHED) => { // Leaving VM; caller will re-check on next entry. super::stw_trace(format_args!("suspend skip DETACHED")); @@ -308,6 +289,7 @@ fn do_suspend(stw: &super::StopTheWorldState) { } Err(THREAD_SUSPENDED) => { // Already parked by another path. + s.stop_requested.store(false, Ordering::Release); super::stw_trace(format_args!("suspend skip already-suspended")); return; } @@ -322,6 +304,7 @@ fn do_suspend(stw: &super::StopTheWorldState) { // no one will set us back to DETACHED — we must self-recover. if !stw.requested.load(Ordering::Acquire) { s.state.store(THREAD_ATTACHED, Ordering::Release); + s.stop_requested.store(false, Ordering::Release); super::stw_trace(format_args!("suspend abort requested-cleared")); return; } @@ -334,8 +317,7 @@ fn do_suspend(stw: &super::StopTheWorldState) { let wait_yields = wait_while_suspended(s); stw.add_suspend_wait_yields(wait_yields); - // Re-attach (DETACHED → ATTACHED), mirroring CPython's - // tstate_wait_attach CAS loop. + // Re-attach (DETACHED → ATTACHED), tstate_wait_attach CAS loop. loop { match s.state.compare_exchange( THREAD_DETACHED, @@ -355,11 +337,22 @@ fn do_suspend(stw: &super::StopTheWorldState) { } } } + s.stop_requested.store(false, Ordering::Release); super::stw_trace(format_args!("suspend resume -> ATTACHED")); } }); } +#[cfg(all(unix, feature = "threading"))] +#[inline] +pub fn stop_requested_for_current_thread() -> bool { + CURRENT_THREAD_SLOT.with(|slot| { + slot.borrow() + .as_ref() + .is_some_and(|s| s.stop_requested.load(Ordering::Relaxed)) + }) +} + /// Push a frame pointer onto the current thread's shared frame stack. /// The pointed-to frame must remain alive until the matching pop. #[cfg(feature = "threading")] @@ -432,7 +425,41 @@ pub fn get_all_current_exceptions(vm: &VirtualMachine) -> Vec<(u64, Option registry.remove(&thread_id), + _ => None, + } + } else { + None + }; + #[cfg(all(unix, feature = "threading"))] + if let Some(slot) = &removed + && vm.state.stop_the_world.requested.load(Ordering::Acquire) + && thread_id != vm.state.stop_the_world.requester_ident() + && slot.state.load(Ordering::Relaxed) != THREAD_SUSPENDED + { + // A non-requester thread disappeared while stop-the-world is pending. + // Unblock requester countdown progress. + vm.state.stop_the_world.notify_thread_gone(); + } CURRENT_THREAD_SLOT.with(|s| { *s.borrow_mut() = None; }); @@ -454,6 +481,8 @@ pub fn reinit_frame_slot_after_fork(vm: &VirtualMachine) { #[cfg(unix)] state: core::sync::atomic::AtomicI32::new(THREAD_ATTACHED), #[cfg(unix)] + stop_requested: core::sync::atomic::AtomicBool::new(false), + #[cfg(unix)] thread: std::thread::current(), });