Skip to content

Commit 067ca98

Browse files
committed
Implement Windows SemLock in _multiprocessing module
Add SemLock class using Windows semaphore APIs (CreateSemaphoreW, WaitForSingleObjectEx, ReleaseSemaphore) so test_multiprocessing suites are no longer skipped with "lacks a functioning sem_open". Also add sem_unlink as no-op and flags dict for Windows.
1 parent b48f72d commit 067ca98

File tree

1 file changed

+347
-1
lines changed

1 file changed

+347
-1
lines changed

crates/stdlib/src/multiprocessing.rs

Lines changed: 347 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,354 @@ pub(crate) use _multiprocessing::module_def;
33
#[cfg(windows)]
44
#[pymodule]
55
mod _multiprocessing {
6-
use crate::vm::{PyResult, VirtualMachine, function::ArgBytesLike};
6+
use crate::vm::{
7+
Context, FromArgs, Py, PyPayload, PyRef, PyResult, VirtualMachine,
8+
builtins::{PyDict, PyType, PyTypeRef},
9+
function::{ArgBytesLike, FuncArgs, KwArgs},
10+
types::Constructor,
11+
};
12+
use core::sync::atomic::{AtomicI32, AtomicU32, Ordering};
13+
use windows_sys::Win32::Foundation::{
14+
CloseHandle, HANDLE, INVALID_HANDLE_VALUE, WAIT_EVENT, WAIT_OBJECT_0,
15+
};
716
use windows_sys::Win32::Networking::WinSock::{self, SOCKET};
17+
use windows_sys::Win32::System::Threading::{
18+
CreateSemaphoreW, GetCurrentThreadId, ReleaseSemaphore, WaitForSingleObjectEx,
19+
};
20+
21+
const INFINITE: u32 = 0xFFFFFFFF;
22+
const WAIT_TIMEOUT: WAIT_EVENT = 258; // 0x102
23+
const WAIT_FAILED: WAIT_EVENT = 0xFFFFFFFF;
24+
const ERROR_TOO_MANY_POSTS: u32 = 298;
25+
26+
// These match the values in Lib/multiprocessing/synchronize.py
27+
const RECURSIVE_MUTEX: i32 = 0;
28+
const SEMAPHORE: i32 = 1;
29+
30+
macro_rules! ismine {
31+
($self:expr) => {
32+
$self.count.load(Ordering::Acquire) > 0
33+
&& $self.last_tid.load(Ordering::Acquire) == unsafe { GetCurrentThreadId() }
34+
};
35+
}
36+
37+
#[derive(FromArgs)]
38+
struct SemLockNewArgs {
39+
#[pyarg(positional)]
40+
kind: i32,
41+
#[pyarg(positional)]
42+
value: i32,
43+
#[pyarg(positional)]
44+
maxvalue: i32,
45+
#[pyarg(positional)]
46+
name: String,
47+
#[pyarg(positional)]
48+
unlink: bool,
49+
}
50+
51+
#[pyattr]
52+
#[pyclass(name = "SemLock", module = "_multiprocessing")]
53+
#[derive(Debug, PyPayload)]
54+
struct SemLock {
55+
handle: SemHandle,
56+
kind: i32,
57+
maxvalue: i32,
58+
name: Option<String>,
59+
last_tid: AtomicU32,
60+
count: AtomicI32,
61+
}
62+
63+
#[derive(Debug)]
64+
struct SemHandle {
65+
raw: HANDLE,
66+
}
67+
68+
unsafe impl Send for SemHandle {}
69+
unsafe impl Sync for SemHandle {}
70+
71+
impl SemHandle {
72+
fn create(value: i32, maxvalue: i32, vm: &VirtualMachine) -> PyResult<Self> {
73+
let handle =
74+
unsafe { CreateSemaphoreW(core::ptr::null(), value, maxvalue, core::ptr::null()) };
75+
if handle == 0 as HANDLE {
76+
return Err(vm.new_last_os_error());
77+
}
78+
// Check ERROR_ALREADY_EXISTS
79+
let last_err = unsafe { windows_sys::Win32::Foundation::GetLastError() };
80+
if last_err != 0 {
81+
unsafe { CloseHandle(handle) };
82+
return Err(vm.new_last_os_error());
83+
}
84+
Ok(SemHandle { raw: handle })
85+
}
86+
87+
#[inline]
88+
fn as_raw(&self) -> HANDLE {
89+
self.raw
90+
}
91+
}
92+
93+
impl Drop for SemHandle {
94+
fn drop(&mut self) {
95+
if self.raw != 0 as HANDLE && self.raw != INVALID_HANDLE_VALUE {
96+
unsafe {
97+
CloseHandle(self.raw);
98+
}
99+
}
100+
}
101+
}
102+
103+
/// _GetSemaphoreValue - get value of semaphore by briefly acquiring and releasing
104+
fn get_semaphore_value(handle: HANDLE) -> Result<i32, ()> {
105+
match unsafe { WaitForSingleObjectEx(handle, 0, 0) } {
106+
WAIT_OBJECT_0 => {
107+
let mut previous: i32 = 0;
108+
if unsafe { ReleaseSemaphore(handle, 1, &mut previous) } == 0 {
109+
return Err(());
110+
}
111+
Ok(previous + 1)
112+
}
113+
WAIT_TIMEOUT => Ok(0),
114+
_ => Err(()),
115+
}
116+
}
117+
118+
#[pyclass(with(Constructor), flags(BASETYPE))]
119+
impl SemLock {
120+
#[pygetset]
121+
fn handle(&self) -> isize {
122+
self.handle.as_raw() as isize
123+
}
124+
125+
#[pygetset]
126+
fn kind(&self) -> i32 {
127+
self.kind
128+
}
129+
130+
#[pygetset]
131+
fn maxvalue(&self) -> i32 {
132+
self.maxvalue
133+
}
134+
135+
#[pygetset]
136+
fn name(&self) -> Option<String> {
137+
self.name.clone()
138+
}
139+
140+
#[pymethod]
141+
fn acquire(&self, args: FuncArgs, vm: &VirtualMachine) -> PyResult<bool> {
142+
let blocking: bool = args
143+
.kwargs
144+
.get("block")
145+
.or_else(|| args.args.first())
146+
.map(|o| o.clone().try_to_bool(vm))
147+
.transpose()?
148+
.unwrap_or(true);
149+
150+
let timeout_obj = args
151+
.kwargs
152+
.get("timeout")
153+
.or_else(|| args.args.get(1))
154+
.cloned();
155+
156+
// Calculate timeout in milliseconds
157+
let full_msecs: u32 = if !blocking {
158+
0
159+
} else if timeout_obj.as_ref().is_none_or(|o| vm.is_none(o)) {
160+
INFINITE
161+
} else {
162+
let timeout: f64 = timeout_obj.unwrap().try_float(vm)?.to_f64();
163+
let timeout = timeout * 1000.0; // convert to ms
164+
if timeout < 0.0 {
165+
0
166+
} else if timeout >= 0.5 * INFINITE as f64 {
167+
return Err(vm.new_overflow_error("timeout is too large".to_owned()));
168+
} else {
169+
(timeout + 0.5) as u32
170+
}
171+
};
172+
173+
// Check whether we already own the lock
174+
if self.kind == RECURSIVE_MUTEX && ismine!(self) {
175+
self.count.fetch_add(1, Ordering::Release);
176+
return Ok(true);
177+
}
178+
179+
// Check whether we can acquire without blocking
180+
if unsafe { WaitForSingleObjectEx(self.handle.as_raw(), 0, 0) } == WAIT_OBJECT_0 {
181+
self.last_tid
182+
.store(unsafe { GetCurrentThreadId() }, Ordering::Release);
183+
self.count.fetch_add(1, Ordering::Release);
184+
return Ok(true);
185+
}
186+
187+
// Do the wait
188+
let res = unsafe { WaitForSingleObjectEx(self.handle.as_raw(), full_msecs, 0) };
189+
190+
match res {
191+
WAIT_TIMEOUT => Ok(false),
192+
WAIT_OBJECT_0 => {
193+
self.last_tid
194+
.store(unsafe { GetCurrentThreadId() }, Ordering::Release);
195+
self.count.fetch_add(1, Ordering::Release);
196+
Ok(true)
197+
}
198+
WAIT_FAILED => Err(vm.new_last_os_error()),
199+
_ => Err(vm.new_runtime_error(format!(
200+
"WaitForSingleObject() gave unrecognized value {res}"
201+
))),
202+
}
203+
}
204+
205+
#[pymethod]
206+
fn release(&self, vm: &VirtualMachine) -> PyResult<()> {
207+
if self.kind == RECURSIVE_MUTEX {
208+
if !ismine!(self) {
209+
return Err(vm.new_exception_msg(
210+
vm.ctx.exceptions.assertion_error.to_owned(),
211+
"attempt to release recursive lock not owned by thread".to_owned(),
212+
));
213+
}
214+
if self.count.load(Ordering::Acquire) > 1 {
215+
self.count.fetch_sub(1, Ordering::Release);
216+
return Ok(());
217+
}
218+
}
219+
220+
if unsafe { ReleaseSemaphore(self.handle.as_raw(), 1, core::ptr::null_mut()) } == 0 {
221+
let err = unsafe { windows_sys::Win32::Foundation::GetLastError() };
222+
if err == ERROR_TOO_MANY_POSTS {
223+
return Err(
224+
vm.new_value_error("semaphore or lock released too many times".to_owned())
225+
);
226+
}
227+
return Err(vm.new_last_os_error());
228+
}
229+
230+
self.count.fetch_sub(1, Ordering::Release);
231+
Ok(())
232+
}
233+
234+
#[pymethod(name = "__enter__")]
235+
fn enter(&self, vm: &VirtualMachine) -> PyResult<bool> {
236+
self.acquire(
237+
FuncArgs::new::<Vec<_>, KwArgs>(
238+
vec![vm.ctx.new_bool(true).into()],
239+
KwArgs::default(),
240+
),
241+
vm,
242+
)
243+
}
244+
245+
#[pymethod]
246+
fn __exit__(&self, _args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> {
247+
self.release(vm)
248+
}
249+
250+
#[pyclassmethod(name = "_rebuild")]
251+
fn rebuild(
252+
cls: PyTypeRef,
253+
handle: isize,
254+
kind: i32,
255+
maxvalue: i32,
256+
name: Option<String>,
257+
vm: &VirtualMachine,
258+
) -> PyResult {
259+
// On Windows, _rebuild receives the handle directly (no sem_open)
260+
let zelf = SemLock {
261+
handle: SemHandle {
262+
raw: handle as HANDLE,
263+
},
264+
kind,
265+
maxvalue,
266+
name,
267+
last_tid: AtomicU32::new(0),
268+
count: AtomicI32::new(0),
269+
};
270+
zelf.into_ref_with_type(vm, cls).map(Into::into)
271+
}
272+
273+
#[pymethod]
274+
fn _after_fork(&self) {
275+
self.count.store(0, Ordering::Release);
276+
self.last_tid.store(0, Ordering::Release);
277+
}
278+
279+
#[pymethod]
280+
fn __reduce__(&self, vm: &VirtualMachine) -> PyResult {
281+
Err(vm.new_type_error("cannot pickle 'SemLock' object".to_owned()))
282+
}
283+
284+
#[pymethod]
285+
fn _count(&self) -> i32 {
286+
self.count.load(Ordering::Acquire)
287+
}
288+
289+
#[pymethod]
290+
fn _is_mine(&self) -> bool {
291+
ismine!(self)
292+
}
293+
294+
#[pymethod]
295+
fn _get_value(&self, vm: &VirtualMachine) -> PyResult<i32> {
296+
get_semaphore_value(self.handle.as_raw()).map_err(|_| vm.new_last_os_error())
297+
}
298+
299+
#[pymethod]
300+
fn _is_zero(&self, vm: &VirtualMachine) -> PyResult<bool> {
301+
let val =
302+
get_semaphore_value(self.handle.as_raw()).map_err(|_| vm.new_last_os_error())?;
303+
Ok(val == 0)
304+
}
305+
306+
#[extend_class]
307+
fn extend_class(ctx: &Context, class: &Py<PyType>) {
308+
class.set_attr(
309+
ctx.intern_str("RECURSIVE_MUTEX"),
310+
ctx.new_int(RECURSIVE_MUTEX).into(),
311+
);
312+
class.set_attr(ctx.intern_str("SEMAPHORE"), ctx.new_int(SEMAPHORE).into());
313+
class.set_attr(
314+
ctx.intern_str("SEM_VALUE_MAX"),
315+
ctx.new_int(i32::MAX).into(),
316+
);
317+
}
318+
}
319+
320+
impl Constructor for SemLock {
321+
type Args = SemLockNewArgs;
322+
323+
fn py_new(_cls: &Py<PyType>, args: Self::Args, vm: &VirtualMachine) -> PyResult<Self> {
324+
if args.kind != RECURSIVE_MUTEX && args.kind != SEMAPHORE {
325+
return Err(vm.new_value_error("unrecognized kind".to_owned()));
326+
}
327+
if args.value < 0 || args.value > args.maxvalue {
328+
return Err(vm.new_value_error("invalid value".to_owned()));
329+
}
330+
331+
let handle = SemHandle::create(args.value, args.maxvalue, vm)?;
332+
let name = if args.unlink { None } else { Some(args.name) };
333+
334+
Ok(SemLock {
335+
handle,
336+
kind: args.kind,
337+
maxvalue: args.maxvalue,
338+
name,
339+
last_tid: AtomicU32::new(0),
340+
count: AtomicI32::new(0),
341+
})
342+
}
343+
}
344+
345+
// On Windows, sem_unlink is a no-op
346+
#[pyfunction]
347+
fn sem_unlink(_name: String) {}
348+
349+
#[pyattr]
350+
fn flags(vm: &VirtualMachine) -> PyRef<PyDict> {
351+
// On Windows, no HAVE_SEM_OPEN / HAVE_SEM_TIMEDWAIT / HAVE_BROKEN_SEM_GETVALUE
352+
vm.ctx.new_dict()
353+
}
8354

9355
#[pyfunction]
10356
fn closesocket(socket: usize, vm: &VirtualMachine) -> PyResult<()> {

0 commit comments

Comments
 (0)