diff --git a/.cspell.dict/cpython.txt b/.cspell.dict/cpython.txt index 5011975f496..0e4e17e2872 100644 --- a/.cspell.dict/cpython.txt +++ b/.cspell.dict/cpython.txt @@ -30,6 +30,7 @@ fromlist heaptype HIGHRES IMMUTABLETYPE +ismine Itertool keeped kwonlyarg @@ -40,6 +41,7 @@ lsprof maxdepth mult multibytecodec +newsemlockobject nkwargs noraise numer diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index c22ce769c48..351887819e6 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -38,7 +38,7 @@ from test.support import socket_helper from test.support import threading_helper from test.support import warnings_helper - +from test.support import subTests # Skip tests if _multiprocessing wasn't built. _multiprocessing = import_helper.import_module('_multiprocessing') @@ -1109,7 +1109,7 @@ def test_put(self): @classmethod def _test_get(cls, queue, child_can_start, parent_can_continue): child_can_start.wait() - #queue.put(1) + queue.put(1) queue.put(2) queue.put(3) queue.put(4) @@ -1133,15 +1133,16 @@ def test_get(self): child_can_start.set() parent_can_continue.wait() - time.sleep(DELTA) + for _ in support.sleeping_retry(support.SHORT_TIMEOUT): + if not queue_empty(queue): + break self.assertEqual(queue_empty(queue), False) - # Hangs unexpectedly, remove for now - #self.assertEqual(queue.get(), 1) + self.assertEqual(queue.get_nowait(), 1) self.assertEqual(queue.get(True, None), 2) self.assertEqual(queue.get(True), 3) self.assertEqual(queue.get(timeout=1), 4) - self.assertEqual(queue.get_nowait(), 5) + self.assertEqual(queue.get(), 5) self.assertEqual(queue_empty(queue), True) @@ -2970,6 +2971,8 @@ def test_map_no_failfast(self): # check that we indeed waited for all jobs self.assertGreater(time.monotonic() - t_start, 0.9) + # TODO: RUSTPYTHON - reference counting differences + @unittest.skip("TODO: RUSTPYTHON") def test_release_task_refs(self): # Issue #29861: task arguments and results should not be kept # alive after we are done with them. @@ -3882,6 +3885,8 @@ def _remote(cls, conn): conn.close() + # TODO: RUSTPYTHON - hangs + @unittest.skip("TODO: RUSTPYTHON") def test_pickling(self): families = self.connection.families @@ -4051,6 +4056,8 @@ def test_heap(self): self.assertEqual(len(heap._allocated_blocks), 0, heap._allocated_blocks) self.assertEqual(len(heap._len_to_seq), 0) + # TODO: RUSTPYTHON - gc.enable() not implemented + @unittest.expectedFailure def test_free_from_gc(self): # Check that freeing of blocks by the garbage collector doesn't deadlock # (issue #12352). @@ -4103,6 +4110,8 @@ def _double(cls, x, y, z, foo, arr, string): for i in range(len(arr)): arr[i] *= 2 + # TODO: RUSTPYTHON - ctypes Structure shared memory not working + @unittest.expectedFailure def test_sharedctypes(self, lock=False): x = Value('i', 7, lock=lock) y = Value(c_double, 1.0/3.0, lock=lock) @@ -4126,6 +4135,8 @@ def test_sharedctypes(self, lock=False): self.assertAlmostEqual(arr[i], i*2) self.assertEqual(string.value, latin('hellohello')) + # TODO: RUSTPYTHON - calls test_sharedctypes which fails + @unittest.expectedFailure def test_synchronize(self): self.test_sharedctypes(lock=True) @@ -4140,6 +4151,19 @@ def test_copy(self): self.assertEqual(bar.z, 2 ** 33) +def resource_tracker_format_subtests(func): + """Run given test using both resource tracker communication formats""" + def _inner(self, *args, **kwargs): + tracker = resource_tracker._resource_tracker + for use_simple_format in False, True: + with ( + self.subTest(use_simple_format=use_simple_format), + unittest.mock.patch.object( + tracker, '_use_simple_format', use_simple_format) + ): + func(self, *args, **kwargs) + return _inner + @unittest.skipUnless(HAS_SHMEM, "requires multiprocessing.shared_memory") @hashlib_helper.requires_hashdigest('sha256') class _TestSharedMemory(BaseTestCase): @@ -4417,6 +4441,7 @@ def test_shared_memory_SharedMemoryServer_ignores_sigint(self): smm.shutdown() @unittest.skipIf(os.name != "posix", "resource_tracker is posix only") + @resource_tracker_format_subtests def test_shared_memory_SharedMemoryManager_reuses_resource_tracker(self): # bpo-36867: test that a SharedMemoryManager uses the # same resource_tracker process as its parent. @@ -4667,6 +4692,7 @@ def test_shared_memory_cleaned_after_process_termination(self): "shared_memory objects to clean up at shutdown", err) @unittest.skipIf(os.name != "posix", "resource_tracker is posix only") + @resource_tracker_format_subtests def test_shared_memory_untracking(self): # gh-82300: When a separate Python process accesses shared memory # with track=False, it must not cause the memory to be deleted @@ -4694,6 +4720,7 @@ def test_shared_memory_untracking(self): mem.close() @unittest.skipIf(os.name != "posix", "resource_tracker is posix only") + @resource_tracker_format_subtests def test_shared_memory_tracking(self): # gh-82300: When a separate Python process accesses shared memory # with track=True, it must cause the memory to be deleted when @@ -4787,6 +4814,8 @@ def test_finalize(self): result = [obj for obj in iter(conn.recv, 'STOP')] self.assertEqual(result, ['a', 'b', 'd10', 'd03', 'd02', 'd01', 'e']) + # TODO: RUSTPYTHON - gc.get_threshold() and gc.set_threshold() not implemented + @unittest.expectedFailure @support.requires_resource('cpu') def test_thread_safety(self): # bpo-24484: _run_finalizers() should be thread-safe @@ -5414,6 +5443,8 @@ def run_in_child(cls, start_method): flags = (tuple(sys.flags), grandchild_flags) print(json.dumps(flags)) + # TODO: RUSTPYTHON - SyntaxError in subprocess after fork + @unittest.expectedFailure def test_flags(self): import json # start child process using unusual flags @@ -6457,28 +6488,13 @@ def test_std_streams_flushed_after_preload(self): if multiprocessing.get_start_method() != "forkserver": self.skipTest("forkserver specific test") - # Create a test module in the temporary directory on the child's path - # TODO: This can all be simplified once gh-126631 is fixed and we can - # use __main__ instead of a module. - dirname = os.path.join(self._temp_dir, 'preloaded_module') - init_name = os.path.join(dirname, '__init__.py') - os.mkdir(dirname) - with open(init_name, "w") as f: - cmd = '''if 1: - import sys - print('stderr', end='', file=sys.stderr) - print('stdout', end='', file=sys.stdout) - ''' - f.write(cmd) - name = os.path.join(os.path.dirname(__file__), 'mp_preload_flush.py') - env = {'PYTHONPATH': self._temp_dir} - _, out, err = test.support.script_helper.assert_python_ok(name, **env) + _, out, err = test.support.script_helper.assert_python_ok(name) # Check stderr first, as it is more likely to be useful to see in the # event of a failure. - self.assertEqual(err.decode().rstrip(), 'stderr') - self.assertEqual(out.decode().rstrip(), 'stdout') + self.assertEqual(err.decode().rstrip(), '__main____mp_main__') + self.assertEqual(out.decode().rstrip(), '__main____mp_main__') class MiscTestCase(unittest.TestCase): @@ -6804,3 +6820,52 @@ class SemLock(_multiprocessing.SemLock): name = f'test_semlock_subclass-{os.getpid()}' s = SemLock(1, 0, 10, name, False) _multiprocessing.sem_unlink(name) + + +@unittest.skipUnless(HAS_SHMEM, "requires multiprocessing.shared_memory") +class TestSharedMemoryNames(unittest.TestCase): + @subTests('use_simple_format', (True, False)) + def test_that_shared_memory_name_with_colons_has_no_resource_tracker_errors( + self, use_simple_format): + # Test script that creates and cleans up shared memory with colon in name + test_script = textwrap.dedent(""" + import sys + from multiprocessing import shared_memory + from multiprocessing import resource_tracker + import time + + resource_tracker._resource_tracker._use_simple_format = %s + + # Test various patterns of colons in names + test_names = [ + "a:b", + "a:b:c", + "test:name:with:many:colons", + ":starts:with:colon", + "ends:with:colon:", + "::double::colons::", + "name\\nwithnewline", + "name-with-trailing-newline\\n", + "\\nname-starts-with-newline", + "colons:and\\nnewlines:mix", + "multi\\nline\\nname", + ] + + for name in test_names: + try: + shm = shared_memory.SharedMemory(create=True, size=100, name=name) + shm.buf[:5] = b'hello' # Write something to the shared memory + shm.close() + shm.unlink() + + except Exception as e: + print(f"Error with name '{name}': {e}", file=sys.stderr) + sys.exit(1) + + print("SUCCESS") + """ % use_simple_format) + + rc, out, err = script_helper.assert_python_ok("-c", test_script) + self.assertIn(b"SUCCESS", out) + self.assertNotIn(b"traceback", err.lower(), err) + self.assertNotIn(b"resource_tracker.py", err, err) diff --git a/Lib/test/mp_fork_bomb.py b/Lib/test/mp_fork_bomb.py new file mode 100644 index 00000000000..017e010ba0e --- /dev/null +++ b/Lib/test/mp_fork_bomb.py @@ -0,0 +1,18 @@ +import multiprocessing, sys + +def foo(): + print("123") + +# Because "if __name__ == '__main__'" is missing this will not work +# correctly on Windows. However, we should get a RuntimeError rather +# than the Windows equivalent of a fork bomb. + +if len(sys.argv) > 1: + multiprocessing.set_start_method(sys.argv[1]) +else: + multiprocessing.set_start_method('spawn') + +p = multiprocessing.Process(target=foo) +p.start() +p.join() +sys.exit(p.exitcode) diff --git a/Lib/test/mp_preload.py b/Lib/test/mp_preload.py new file mode 100644 index 00000000000..5314e8f0b21 --- /dev/null +++ b/Lib/test/mp_preload.py @@ -0,0 +1,18 @@ +import multiprocessing + +multiprocessing.Lock() + + +def f(): + print("ok") + + +if __name__ == "__main__": + ctx = multiprocessing.get_context("forkserver") + modname = "test.mp_preload" + # Make sure it's importable + __import__(modname) + ctx.set_forkserver_preload([modname]) + proc = ctx.Process(target=f) + proc.start() + proc.join() diff --git a/Lib/test/mp_preload_flush.py b/Lib/test/mp_preload_flush.py new file mode 100644 index 00000000000..c195a9ef6b2 --- /dev/null +++ b/Lib/test/mp_preload_flush.py @@ -0,0 +1,11 @@ +import multiprocessing +import sys + +print(__name__, end='', file=sys.stderr) +print(__name__, end='', file=sys.stdout) +if __name__ == '__main__': + multiprocessing.set_start_method('forkserver') + for _ in range(2): + p = multiprocessing.Process() + p.start() + p.join() diff --git a/Lib/test/test_importlib/test_threaded_import.py b/Lib/test/test_importlib/test_threaded_import.py index 148b2e4370b..3ceb86cbea3 100644 --- a/Lib/test/test_importlib/test_threaded_import.py +++ b/Lib/test/test_importlib/test_threaded_import.py @@ -256,8 +256,7 @@ def test_concurrent_futures_circular_import(self): 'partial', 'cfimport.py') script_helper.assert_python_ok(fn) - @unittest.skipUnless(hasattr(_multiprocessing, "SemLock"), "TODO: RUSTPYTHON, pool_in_threads.py needs _multiprocessing.SemLock") - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") + @unittest.skip("TODO: RUSTPYTHON - fails on Linux due to multiprocessing issues") def test_multiprocessing_pool_circular_import(self): # Regression test for bpo-41567 fn = os.path.join(os.path.dirname(__file__), diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index 9004e9ed744..529cb2dc2f1 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -4058,7 +4058,8 @@ def _mpinit_issue121723(qspec, message_to_log): # log a message (this creates a record put in the queue) logging.getLogger().info(message_to_log) - @unittest.expectedFailure # TODO: RUSTPYTHON; ImportError: cannot import name 'SemLock' + # TODO: RUSTPYTHON - SemLock not implemented on Windows + @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") @skip_if_tsan_fork @support.requires_subprocess() def test_multiprocessing_queues(self): @@ -4118,7 +4119,8 @@ def test_90195(self): # Logger should be enabled, since explicitly mentioned self.assertFalse(logger.disabled) - @unittest.expectedFailure # TODO: RUSTPYTHON; ImportError: cannot import name 'SemLock' + # TODO: RUSTPYTHON - SemLock not implemented on Windows + @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def test_111615(self): # See gh-111615 import_helper.import_module('_multiprocessing') # see gh-113692 diff --git a/crates/stdlib/src/multiprocessing.rs b/crates/stdlib/src/multiprocessing.rs index 9ff2d3dc318..21b7bfa85c7 100644 --- a/crates/stdlib/src/multiprocessing.rs +++ b/crates/stdlib/src/multiprocessing.rs @@ -41,6 +41,696 @@ mod _multiprocessing { } } -#[cfg(not(windows))] +// Unix platforms (Linux, macOS, etc.) +// macOS has broken sem_timedwait/sem_getvalue - we use polled fallback +#[cfg(unix)] +#[pymodule] +mod _multiprocessing { + use crate::vm::{ + Context, FromArgs, Py, PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{PyBaseExceptionRef, PyDict, PyType, PyTypeRef}, + function::{FuncArgs, KwArgs}, + types::Constructor, + }; + use libc::sem_t; + use nix::errno::Errno; + use std::{ + ffi::CString, + sync::atomic::{AtomicI32, AtomicU64, Ordering}, + }; + + /// Error type for sem_timedwait operations + #[cfg(target_vendor = "apple")] + enum SemWaitError { + Timeout, + SignalException(PyBaseExceptionRef), + OsError(Errno), + } + + /// macOS fallback for sem_timedwait using select + sem_trywait polling + /// Matches sem_timedwait_save in semaphore.c + #[cfg(target_vendor = "apple")] + fn sem_timedwait_polled( + sem: *mut sem_t, + deadline: &libc::timespec, + vm: &VirtualMachine, + ) -> Result<(), SemWaitError> { + let mut delay: u64 = 0; + + loop { + // poll: try to acquire + if unsafe { libc::sem_trywait(sem) } == 0 { + return Ok(()); + } + let err = Errno::last(); + if err != Errno::EAGAIN { + return Err(SemWaitError::OsError(err)); + } + + // get current time + let mut now = libc::timeval { + tv_sec: 0, + tv_usec: 0, + }; + if unsafe { libc::gettimeofday(&mut now, std::ptr::null_mut()) } < 0 { + return Err(SemWaitError::OsError(Errno::last())); + } + + // check for timeout + let deadline_usec = deadline.tv_sec * 1_000_000 + deadline.tv_nsec / 1000; + #[allow(clippy::unnecessary_cast)] + let now_usec = now.tv_sec as i64 * 1_000_000 + now.tv_usec as i64; + + if now_usec >= deadline_usec { + return Err(SemWaitError::Timeout); + } + + // calculate how much time is left + let difference = (deadline_usec - now_usec) as u64; + + // check delay not too long -- maximum is 20 msecs + delay += 1000; + if delay > 20000 { + delay = 20000; + } + if delay > difference { + delay = difference; + } + + // sleep using select + let mut tv_delay = libc::timeval { + tv_sec: (delay / 1_000_000) as _, + tv_usec: (delay % 1_000_000) as _, + }; + unsafe { + libc::select( + 0, + std::ptr::null_mut(), + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut tv_delay, + ) + }; + + // check for signals - preserve the exception (e.g., KeyboardInterrupt) + if let Err(exc) = vm.check_signals() { + return Err(SemWaitError::SignalException(exc)); + } + } + } + + // These match the values in Lib/multiprocessing/synchronize.py + const RECURSIVE_MUTEX: i32 = 0; + const SEMAPHORE: i32 = 1; + + // #define ISMINE(o) (o->count > 0 && PyThread_get_thread_ident() == o->last_tid) + macro_rules! ismine { + ($self:expr) => { + $self.count.load(Ordering::Acquire) > 0 + && $self.last_tid.load(Ordering::Acquire) == current_thread_id() + }; + } + + #[derive(FromArgs)] + struct SemLockNewArgs { + #[pyarg(positional)] + kind: i32, + #[pyarg(positional)] + value: i32, + #[pyarg(positional)] + maxvalue: i32, + #[pyarg(positional)] + name: String, + #[pyarg(positional)] + unlink: bool, + } + + #[pyattr] + #[pyclass(name = "SemLock", module = "_multiprocessing")] + #[derive(Debug, PyPayload)] + struct SemLock { + handle: SemHandle, + kind: i32, + maxvalue: i32, + name: Option, + last_tid: AtomicU64, // unsigned long + count: AtomicI32, // int + } + + #[derive(Debug)] + struct SemHandle { + raw: *mut sem_t, + } + + unsafe impl Send for SemHandle {} + unsafe impl Sync for SemHandle {} + + impl SemHandle { + fn create( + name: &str, + value: u32, + unlink: bool, + vm: &VirtualMachine, + ) -> PyResult<(Self, Option)> { + let cname = semaphore_name(vm, name)?; + // SEM_CREATE(name, val, max) sem_open(name, O_CREAT | O_EXCL, 0600, val) + let raw = unsafe { + libc::sem_open(cname.as_ptr(), libc::O_CREAT | libc::O_EXCL, 0o600, value) + }; + if raw == libc::SEM_FAILED { + let err = Errno::last(); + return Err(os_error(vm, err)); + } + if unlink { + // SEM_UNLINK(name) sem_unlink(name) + unsafe { + libc::sem_unlink(cname.as_ptr()); + } + Ok((SemHandle { raw }, None)) + } else { + Ok((SemHandle { raw }, Some(name.to_owned()))) + } + } + + fn open_existing(name: &str, vm: &VirtualMachine) -> PyResult { + let cname = semaphore_name(vm, name)?; + let raw = unsafe { libc::sem_open(cname.as_ptr(), 0) }; + if raw == libc::SEM_FAILED { + let err = Errno::last(); + return Err(os_error(vm, err)); + } + Ok(SemHandle { raw }) + } + + #[inline] + fn as_ptr(&self) -> *mut sem_t { + self.raw + } + } + + impl Drop for SemHandle { + fn drop(&mut self) { + // Guard against default/uninitialized state. + // Note: SEM_FAILED is (sem_t*)-1, not null, but valid handles are never null + // and SEM_FAILED is never stored (error is returned immediately on sem_open failure). + if !self.raw.is_null() { + // SEM_CLOSE(sem) sem_close(sem) + unsafe { + libc::sem_close(self.raw); + } + } + } + } + + #[pyclass(with(Constructor), flags(BASETYPE))] + impl SemLock { + #[pygetset] + fn handle(&self) -> isize { + self.handle.as_ptr() as isize + } + + #[pygetset] + fn kind(&self) -> i32 { + self.kind + } + + #[pygetset] + fn maxvalue(&self) -> i32 { + self.maxvalue + } + + #[pygetset] + fn name(&self) -> Option { + self.name.clone() + } + + /// Acquire the semaphore/lock. + // _multiprocessing_SemLock_acquire_impl + #[pymethod] + fn acquire(&self, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + // block=True, timeout=None + + let blocking: bool = args + .kwargs + .get("block") + .or_else(|| args.args.first()) + .map(|o| o.clone().try_to_bool(vm)) + .transpose()? + .unwrap_or(true); + + let timeout_obj = args + .kwargs + .get("timeout") + .or_else(|| args.args.get(1)) + .cloned(); + + if self.kind == RECURSIVE_MUTEX && ismine!(self) { + self.count.fetch_add(1, Ordering::Release); + return Ok(true); + } + + // timeout_obj != Py_None + let use_deadline = timeout_obj.as_ref().is_some_and(|o| !vm.is_none(o)); + + let deadline = if use_deadline { + let timeout_obj = timeout_obj.unwrap(); + // This accepts both int and float, converting to f64 + let timeout: f64 = timeout_obj.try_float(vm)?.to_f64(); + let timeout = if timeout < 0.0 { 0.0 } else { timeout }; + + let mut tv = libc::timeval { + tv_sec: 0, + tv_usec: 0, + }; + let res = unsafe { libc::gettimeofday(&mut tv, std::ptr::null_mut()) }; + if res < 0 { + return Err(vm.new_os_error("gettimeofday failed".to_string())); + } + + // deadline calculation: + // long sec = (long) timeout; + // long nsec = (long) (1e9 * (timeout - sec) + 0.5); + // deadline.tv_sec = now.tv_sec + sec; + // deadline.tv_nsec = now.tv_usec * 1000 + nsec; + // deadline.tv_sec += (deadline.tv_nsec / 1000000000); + // deadline.tv_nsec %= 1000000000; + let sec = timeout as libc::c_long; + let nsec = (1e9 * (timeout - sec as f64) + 0.5) as libc::c_long; + let mut deadline = libc::timespec { + tv_sec: tv.tv_sec + sec as libc::time_t, + tv_nsec: (tv.tv_usec as libc::c_long * 1000 + nsec) as _, + }; + deadline.tv_sec += (deadline.tv_nsec / 1_000_000_000) as libc::time_t; + deadline.tv_nsec %= 1_000_000_000; + Some(deadline) + } else { + None + }; + + // Check whether we can acquire without releasing the GIL and blocking + let mut res; + loop { + res = unsafe { libc::sem_trywait(self.handle.as_ptr()) }; + if res >= 0 { + break; + } + let err = Errno::last(); + if err == Errno::EINTR { + vm.check_signals()?; + continue; + } + break; + } + + // if (res < 0 && errno == EAGAIN && blocking) + if res < 0 && Errno::last() == Errno::EAGAIN && blocking { + // Couldn't acquire immediately, need to block + #[cfg(not(target_vendor = "apple"))] + { + loop { + // Py_BEGIN_ALLOW_THREADS / Py_END_ALLOW_THREADS + // RustPython doesn't have GIL, so we just do the wait + if let Some(ref dl) = deadline { + res = unsafe { libc::sem_timedwait(self.handle.as_ptr(), dl) }; + } else { + res = unsafe { libc::sem_wait(self.handle.as_ptr()) }; + } + + if res >= 0 { + break; + } + let err = Errno::last(); + if err == Errno::EINTR { + vm.check_signals()?; + continue; + } + break; + } + } + #[cfg(target_vendor = "apple")] + { + // macOS: use polled fallback since sem_timedwait is not available + if let Some(ref dl) = deadline { + match sem_timedwait_polled(self.handle.as_ptr(), dl, vm) { + Ok(()) => res = 0, + Err(SemWaitError::Timeout) => { + // Timeout occurred - return false directly + return Ok(false); + } + Err(SemWaitError::SignalException(exc)) => { + // Propagate the original exception (e.g., KeyboardInterrupt) + return Err(exc); + } + Err(SemWaitError::OsError(e)) => { + return Err(os_error(vm, e)); + } + } + } else { + // No timeout: use sem_wait (available on macOS) + loop { + res = unsafe { libc::sem_wait(self.handle.as_ptr()) }; + if res >= 0 { + break; + } + let err = Errno::last(); + if err == Errno::EINTR { + vm.check_signals()?; + continue; + } + break; + } + } + } + } + + // result handling: + if res < 0 { + let err = Errno::last(); + match err { + Errno::EAGAIN | Errno::ETIMEDOUT => return Ok(false), + Errno::EINTR => { + // EINTR should be handled by the check_signals() loop above + // If we reach here, check signals again and propagate any exception + return vm.check_signals().map(|_| false); + } + _ => return Err(os_error(vm, err)), + } + } + + self.count.fetch_add(1, Ordering::Release); + self.last_tid.store(current_thread_id(), Ordering::Release); + + Ok(true) + } + + /// Release the semaphore/lock. + // _multiprocessing_SemLock_release_impl + #[pymethod] + fn release(&self, vm: &VirtualMachine) -> PyResult<()> { + if self.kind == RECURSIVE_MUTEX { + // if (!ISMINE(self)) + if !ismine!(self) { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.assertion_error.to_owned(), + "attempt to release recursive lock not owned by thread".to_owned(), + )); + } + // if (self->count > 1) { --self->count; Py_RETURN_NONE; } + if self.count.load(Ordering::Acquire) > 1 { + self.count.fetch_sub(1, Ordering::Release); + return Ok(()); + } + // assert(self->count == 1); + } else { + // SEMAPHORE case: check value before releasing + #[cfg(not(target_vendor = "apple"))] + { + // Linux: use sem_getvalue + let mut sval: libc::c_int = 0; + let res = unsafe { libc::sem_getvalue(self.handle.as_ptr(), &mut sval) }; + if res < 0 { + return Err(os_error(vm, Errno::last())); + } + if sval >= self.maxvalue { + return Err(vm.new_value_error( + "semaphore or lock released too many times".to_owned(), + )); + } + } + #[cfg(target_vendor = "apple")] + { + // macOS: HAVE_BROKEN_SEM_GETVALUE + // We will only check properly the maxvalue == 1 case + if self.maxvalue == 1 { + // make sure that already locked + if unsafe { libc::sem_trywait(self.handle.as_ptr()) } < 0 { + if Errno::last() != Errno::EAGAIN { + return Err(os_error(vm, Errno::last())); + } + // it is already locked as expected + } else { + // it was not locked so undo wait and raise + if unsafe { libc::sem_post(self.handle.as_ptr()) } < 0 { + return Err(os_error(vm, Errno::last())); + } + return Err(vm.new_value_error( + "semaphore or lock released too many times".to_owned(), + )); + } + } + } + } + + let res = unsafe { libc::sem_post(self.handle.as_ptr()) }; + if res < 0 { + return Err(os_error(vm, Errno::last())); + } + + self.count.fetch_sub(1, Ordering::Release); + Ok(()) + } + + /// Enter the semaphore/lock (context manager). + // _multiprocessing_SemLock___enter___impl + #[pymethod(name = "__enter__")] + fn enter(&self, vm: &VirtualMachine) -> PyResult { + // return _multiprocessing_SemLock_acquire_impl(self, 1, Py_None); + self.acquire( + FuncArgs::new::, KwArgs>( + vec![vm.ctx.new_bool(true).into()], + KwArgs::default(), + ), + vm, + ) + } + + /// Exit the semaphore/lock (context manager). + // _multiprocessing_SemLock___exit___impl + #[pymethod] + fn __exit__(&self, _args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + self.release(vm) + } + + /// Rebuild a SemLock from pickled state. + // _multiprocessing_SemLock__rebuild_impl + #[pyclassmethod(name = "_rebuild")] + fn rebuild( + cls: PyTypeRef, + _handle: isize, + kind: i32, + maxvalue: i32, + name: Option, + vm: &VirtualMachine, + ) -> PyResult { + let Some(ref name_str) = name else { + return Err(vm.new_value_error("cannot rebuild SemLock without name".to_owned())); + }; + let handle = SemHandle::open_existing(name_str, vm)?; + // return newsemlockobject(type, handle, kind, maxvalue, name_copy); + let zelf = SemLock { + handle, + kind, + maxvalue, + name, + last_tid: AtomicU64::new(0), + count: AtomicI32::new(0), + }; + zelf.into_ref_with_type(vm, cls).map(Into::into) + } + + /// Rezero the net acquisition count after fork(). + // _multiprocessing_SemLock__after_fork_impl + #[pymethod] + fn _after_fork(&self) { + self.count.store(0, Ordering::Release); + // Also reset last_tid for safety + self.last_tid.store(0, Ordering::Release); + } + + /// Num of `acquire()`s minus num of `release()`s for this process. + // _multiprocessing_SemLock__count_impl + #[pymethod] + fn _count(&self) -> i32 { + self.count.load(Ordering::Acquire) + } + + /// Whether the lock is owned by this thread. + // _multiprocessing_SemLock__is_mine_impl + #[pymethod] + fn _is_mine(&self) -> bool { + ismine!(self) + } + + /// Get the value of the semaphore. + // _multiprocessing_SemLock__get_value_impl + #[pymethod] + fn _get_value(&self, vm: &VirtualMachine) -> PyResult { + #[cfg(not(target_vendor = "apple"))] + { + // Linux: use sem_getvalue + let mut sval: libc::c_int = 0; + let res = unsafe { libc::sem_getvalue(self.handle.as_ptr(), &mut sval) }; + if res < 0 { + return Err(os_error(vm, Errno::last())); + } + // some posix implementations use negative numbers to indicate + // the number of waiting threads + Ok(if sval < 0 { 0 } else { sval }) + } + #[cfg(target_vendor = "apple")] + { + // macOS: HAVE_BROKEN_SEM_GETVALUE - raise NotImplementedError + Err(vm.new_not_implemented_error(String::new())) + } + } + + /// Return whether semaphore has value zero. + // _multiprocessing_SemLock__is_zero_impl + #[pymethod] + fn _is_zero(&self, vm: &VirtualMachine) -> PyResult { + #[cfg(not(target_vendor = "apple"))] + { + Ok(self._get_value(vm)? == 0) + } + #[cfg(target_vendor = "apple")] + { + // macOS: HAVE_BROKEN_SEM_GETVALUE + // Try to acquire - if EAGAIN, value is 0 + if unsafe { libc::sem_trywait(self.handle.as_ptr()) } < 0 { + if Errno::last() == Errno::EAGAIN { + return Ok(true); + } + return Err(os_error(vm, Errno::last())); + } + // Successfully acquired - undo and return false + if unsafe { libc::sem_post(self.handle.as_ptr()) } < 0 { + return Err(os_error(vm, Errno::last())); + } + Ok(false) + } + } + + #[extend_class] + fn extend_class(ctx: &Context, class: &Py) { + class.set_attr( + ctx.intern_str("RECURSIVE_MUTEX"), + ctx.new_int(RECURSIVE_MUTEX).into(), + ); + class.set_attr(ctx.intern_str("SEMAPHORE"), ctx.new_int(SEMAPHORE).into()); + // SEM_VALUE_MAX from system, or INT_MAX if negative + // We use a reasonable default + let sem_value_max: i32 = unsafe { + let val = libc::sysconf(libc::_SC_SEM_VALUE_MAX); + if val < 0 || val > i32::MAX as libc::c_long { + i32::MAX + } else { + val as i32 + } + }; + class.set_attr( + ctx.intern_str("SEM_VALUE_MAX"), + ctx.new_int(sem_value_max).into(), + ); + } + } + + impl Constructor for SemLock { + type Args = SemLockNewArgs; + + // Create a new SemLock. + // _multiprocessing_SemLock_impl + fn py_new(_cls: &Py, args: Self::Args, vm: &VirtualMachine) -> PyResult { + if args.kind != RECURSIVE_MUTEX && args.kind != SEMAPHORE { + return Err(vm.new_value_error("unrecognized kind".to_owned())); + } + // Value validation + if args.value < 0 || args.value > args.maxvalue { + return Err(vm.new_value_error("invalid value".to_owned())); + } + + let value = args.value as u32; + let (handle, name) = SemHandle::create(&args.name, value, args.unlink, vm)?; + + // return newsemlockobject(type, handle, kind, maxvalue, name_copy); + Ok(SemLock { + handle, + kind: args.kind, + maxvalue: args.maxvalue, + name, + last_tid: AtomicU64::new(0), + count: AtomicI32::new(0), + }) + } + } + + /// Function to unlink semaphore names. + // _PyMp_sem_unlink. + #[pyfunction] + fn sem_unlink(name: String, vm: &VirtualMachine) -> PyResult<()> { + let cname = semaphore_name(vm, &name)?; + let res = unsafe { libc::sem_unlink(cname.as_ptr()) }; + if res < 0 { + return Err(os_error(vm, Errno::last())); + } + Ok(()) + } + + /// Module-level flags dict. + #[pyattr] + fn flags(vm: &VirtualMachine) -> PyRef { + let flags = vm.ctx.new_dict(); + // HAVE_SEM_OPEN is always 1 on Unix (we wouldn't be here otherwise) + flags + .set_item("HAVE_SEM_OPEN", vm.ctx.new_int(1).into(), vm) + .unwrap(); + + #[cfg(not(target_vendor = "apple"))] + { + // Linux: HAVE_SEM_TIMEDWAIT is available + flags + .set_item("HAVE_SEM_TIMEDWAIT", vm.ctx.new_int(1).into(), vm) + .unwrap(); + } + + #[cfg(target_vendor = "apple")] + { + // macOS: sem_getvalue is broken + flags + .set_item("HAVE_BROKEN_SEM_GETVALUE", vm.ctx.new_int(1).into(), vm) + .unwrap(); + } + + flags + } + + fn semaphore_name(vm: &VirtualMachine, name: &str) -> PyResult { + // POSIX semaphore names must start with / + let mut full = String::with_capacity(name.len() + 1); + if !name.starts_with('/') { + full.push('/'); + } + full.push_str(name); + CString::new(full).map_err(|_| vm.new_value_error("embedded null character".to_owned())) + } + + fn os_error(vm: &VirtualMachine, err: Errno) -> PyBaseExceptionRef { + // _PyMp_SetError maps to PyErr_SetFromErrno + let exc_type = match err { + Errno::EEXIST => vm.ctx.exceptions.file_exists_error.to_owned(), + Errno::ENOENT => vm.ctx.exceptions.file_not_found_error.to_owned(), + _ => vm.ctx.exceptions.os_error.to_owned(), + }; + vm.new_os_subtype_error(exc_type, Some(err as i32), err.desc().to_owned()) + .upcast() + } + + /// Get current thread identifier. + /// PyThread_get_thread_ident on Unix (pthread_self). + fn current_thread_id() -> u64 { + unsafe { libc::pthread_self() as u64 } + } +} + +#[cfg(all(not(unix), not(windows)))] #[pymodule] mod _multiprocessing {} diff --git a/crates/stdlib/src/posixshmem.rs b/crates/stdlib/src/posixshmem.rs index 53bf372532d..a52866b7985 100644 --- a/crates/stdlib/src/posixshmem.rs +++ b/crates/stdlib/src/posixshmem.rs @@ -8,25 +8,27 @@ mod _posixshmem { use crate::{ common::os::errno_io_error, - vm::{ - PyResult, VirtualMachine, builtins::PyStrRef, convert::IntoPyException, - function::OptionalArg, - }, + vm::{FromArgs, PyResult, VirtualMachine, builtins::PyStrRef, convert::IntoPyException}, }; - #[pyfunction] - fn shm_open( + #[derive(FromArgs)] + struct ShmOpenArgs { + #[pyarg(any)] name: PyStrRef, + #[pyarg(any)] flags: libc::c_int, - mode: OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { - let name = CString::new(name.as_str()).map_err(|e| e.into_pyexception(vm))?; - let mode: libc::c_uint = mode.unwrap_or(0o600) as _; + #[pyarg(any, default = 0o600)] + mode: libc::mode_t, + } + + #[pyfunction] + fn shm_open(args: ShmOpenArgs, vm: &VirtualMachine) -> PyResult { + let name = CString::new(args.name.as_str()).map_err(|e| e.into_pyexception(vm))?; + let mode: libc::c_uint = args.mode as _; #[cfg(target_os = "freebsd")] let mode = mode.try_into().unwrap(); // SAFETY: `name` is a NUL-terminated string and `shm_open` does not write through it. - let fd = unsafe { libc::shm_open(name.as_ptr(), flags, mode) }; + let fd = unsafe { libc::shm_open(name.as_ptr(), args.flags, mode) }; if fd == -1 { Err(errno_io_error().into_pyexception(vm)) } else { diff --git a/crates/stdlib/src/pystruct.rs b/crates/stdlib/src/pystruct.rs index 34a4905ed9f..8801f0d705e 100644 --- a/crates/stdlib/src/pystruct.rs +++ b/crates/stdlib/src/pystruct.rs @@ -71,7 +71,7 @@ pub(crate) mod _struct { } else { ("unpack_from", "unpacking") }; - if offset >= buffer_len { + if offset + needed > buffer_len { let msg = format!( "{op} requires a buffer of at least {required} bytes for {op_action} {needed} \ bytes at offset {offset} (actual buffer size is {buffer_len})",