@@ -3,8 +3,354 @@ pub(crate) use _multiprocessing::module_def;
33#[ cfg( windows) ]
44#[ pymodule]
55mod _multiprocessing {
6- use crate :: vm:: { PyResult , VirtualMachine , function:: ArgBytesLike } ;
6+ use crate :: vm:: {
7+ Context , FromArgs , Py , PyPayload , PyRef , PyResult , VirtualMachine ,
8+ builtins:: { PyDict , PyType , PyTypeRef } ,
9+ function:: { ArgBytesLike , FuncArgs , KwArgs } ,
10+ types:: Constructor ,
11+ } ;
12+ use core:: sync:: atomic:: { AtomicI32 , AtomicU32 , Ordering } ;
13+ use windows_sys:: Win32 :: Foundation :: {
14+ CloseHandle , HANDLE , INVALID_HANDLE_VALUE , WAIT_EVENT , WAIT_OBJECT_0 ,
15+ } ;
716 use windows_sys:: Win32 :: Networking :: WinSock :: { self , SOCKET } ;
17+ use windows_sys:: Win32 :: System :: Threading :: {
18+ CreateSemaphoreW , GetCurrentThreadId , ReleaseSemaphore , WaitForSingleObjectEx ,
19+ } ;
20+
21+ const INFINITE : u32 = 0xFFFFFFFF ;
22+ const WAIT_TIMEOUT : WAIT_EVENT = 258 ; // 0x102
23+ const WAIT_FAILED : WAIT_EVENT = 0xFFFFFFFF ;
24+ const ERROR_TOO_MANY_POSTS : u32 = 298 ;
25+
26+ // These match the values in Lib/multiprocessing/synchronize.py
27+ const RECURSIVE_MUTEX : i32 = 0 ;
28+ const SEMAPHORE : i32 = 1 ;
29+
30+ macro_rules! ismine {
31+ ( $self: expr) => {
32+ $self. count. load( Ordering :: Acquire ) > 0
33+ && $self. last_tid. load( Ordering :: Acquire ) == unsafe { GetCurrentThreadId ( ) }
34+ } ;
35+ }
36+
37+ #[ derive( FromArgs ) ]
38+ struct SemLockNewArgs {
39+ #[ pyarg( positional) ]
40+ kind : i32 ,
41+ #[ pyarg( positional) ]
42+ value : i32 ,
43+ #[ pyarg( positional) ]
44+ maxvalue : i32 ,
45+ #[ pyarg( positional) ]
46+ name : String ,
47+ #[ pyarg( positional) ]
48+ unlink : bool ,
49+ }
50+
51+ #[ pyattr]
52+ #[ pyclass( name = "SemLock" , module = "_multiprocessing" ) ]
53+ #[ derive( Debug , PyPayload ) ]
54+ struct SemLock {
55+ handle : SemHandle ,
56+ kind : i32 ,
57+ maxvalue : i32 ,
58+ name : Option < String > ,
59+ last_tid : AtomicU32 ,
60+ count : AtomicI32 ,
61+ }
62+
63+ #[ derive( Debug ) ]
64+ struct SemHandle {
65+ raw : HANDLE ,
66+ }
67+
68+ unsafe impl Send for SemHandle { }
69+ unsafe impl Sync for SemHandle { }
70+
71+ impl SemHandle {
72+ fn create ( value : i32 , maxvalue : i32 , vm : & VirtualMachine ) -> PyResult < Self > {
73+ let handle =
74+ unsafe { CreateSemaphoreW ( core:: ptr:: null ( ) , value, maxvalue, core:: ptr:: null ( ) ) } ;
75+ if handle == 0 as HANDLE {
76+ return Err ( vm. new_last_os_error ( ) ) ;
77+ }
78+ // Check ERROR_ALREADY_EXISTS
79+ let last_err = unsafe { windows_sys:: Win32 :: Foundation :: GetLastError ( ) } ;
80+ if last_err != 0 {
81+ unsafe { CloseHandle ( handle) } ;
82+ return Err ( vm. new_last_os_error ( ) ) ;
83+ }
84+ Ok ( SemHandle { raw : handle } )
85+ }
86+
87+ #[ inline]
88+ fn as_raw ( & self ) -> HANDLE {
89+ self . raw
90+ }
91+ }
92+
93+ impl Drop for SemHandle {
94+ fn drop ( & mut self ) {
95+ if self . raw != 0 as HANDLE && self . raw != INVALID_HANDLE_VALUE {
96+ unsafe {
97+ CloseHandle ( self . raw ) ;
98+ }
99+ }
100+ }
101+ }
102+
103+ /// _GetSemaphoreValue - get value of semaphore by briefly acquiring and releasing
104+ fn get_semaphore_value ( handle : HANDLE ) -> Result < i32 , ( ) > {
105+ match unsafe { WaitForSingleObjectEx ( handle, 0 , 0 ) } {
106+ WAIT_OBJECT_0 => {
107+ let mut previous: i32 = 0 ;
108+ if unsafe { ReleaseSemaphore ( handle, 1 , & mut previous) } == 0 {
109+ return Err ( ( ) ) ;
110+ }
111+ Ok ( previous + 1 )
112+ }
113+ WAIT_TIMEOUT => Ok ( 0 ) ,
114+ _ => Err ( ( ) ) ,
115+ }
116+ }
117+
118+ #[ pyclass( with( Constructor ) , flags( BASETYPE ) ) ]
119+ impl SemLock {
120+ #[ pygetset]
121+ fn handle ( & self ) -> isize {
122+ self . handle . as_raw ( ) as isize
123+ }
124+
125+ #[ pygetset]
126+ fn kind ( & self ) -> i32 {
127+ self . kind
128+ }
129+
130+ #[ pygetset]
131+ fn maxvalue ( & self ) -> i32 {
132+ self . maxvalue
133+ }
134+
135+ #[ pygetset]
136+ fn name ( & self ) -> Option < String > {
137+ self . name . clone ( )
138+ }
139+
140+ #[ pymethod]
141+ fn acquire ( & self , args : FuncArgs , vm : & VirtualMachine ) -> PyResult < bool > {
142+ let blocking: bool = args
143+ . kwargs
144+ . get ( "block" )
145+ . or_else ( || args. args . first ( ) )
146+ . map ( |o| o. clone ( ) . try_to_bool ( vm) )
147+ . transpose ( ) ?
148+ . unwrap_or ( true ) ;
149+
150+ let timeout_obj = args
151+ . kwargs
152+ . get ( "timeout" )
153+ . or_else ( || args. args . get ( 1 ) )
154+ . cloned ( ) ;
155+
156+ // Calculate timeout in milliseconds
157+ let full_msecs: u32 = if !blocking {
158+ 0
159+ } else if timeout_obj. as_ref ( ) . is_none_or ( |o| vm. is_none ( o) ) {
160+ INFINITE
161+ } else {
162+ let timeout: f64 = timeout_obj. unwrap ( ) . try_float ( vm) ?. to_f64 ( ) ;
163+ let timeout = timeout * 1000.0 ; // convert to ms
164+ if timeout < 0.0 {
165+ 0
166+ } else if timeout >= 0.5 * INFINITE as f64 {
167+ return Err ( vm. new_overflow_error ( "timeout is too large" . to_owned ( ) ) ) ;
168+ } else {
169+ ( timeout + 0.5 ) as u32
170+ }
171+ } ;
172+
173+ // Check whether we already own the lock
174+ if self . kind == RECURSIVE_MUTEX && ismine ! ( self ) {
175+ self . count . fetch_add ( 1 , Ordering :: Release ) ;
176+ return Ok ( true ) ;
177+ }
178+
179+ // Check whether we can acquire without blocking
180+ if unsafe { WaitForSingleObjectEx ( self . handle . as_raw ( ) , 0 , 0 ) } == WAIT_OBJECT_0 {
181+ self . last_tid
182+ . store ( unsafe { GetCurrentThreadId ( ) } , Ordering :: Release ) ;
183+ self . count . fetch_add ( 1 , Ordering :: Release ) ;
184+ return Ok ( true ) ;
185+ }
186+
187+ // Do the wait
188+ let res = unsafe { WaitForSingleObjectEx ( self . handle . as_raw ( ) , full_msecs, 0 ) } ;
189+
190+ match res {
191+ WAIT_TIMEOUT => Ok ( false ) ,
192+ WAIT_OBJECT_0 => {
193+ self . last_tid
194+ . store ( unsafe { GetCurrentThreadId ( ) } , Ordering :: Release ) ;
195+ self . count . fetch_add ( 1 , Ordering :: Release ) ;
196+ Ok ( true )
197+ }
198+ WAIT_FAILED => Err ( vm. new_last_os_error ( ) ) ,
199+ _ => Err ( vm. new_runtime_error ( format ! (
200+ "WaitForSingleObject() gave unrecognized value {res}"
201+ ) ) ) ,
202+ }
203+ }
204+
205+ #[ pymethod]
206+ fn release ( & self , vm : & VirtualMachine ) -> PyResult < ( ) > {
207+ if self . kind == RECURSIVE_MUTEX {
208+ if !ismine ! ( self ) {
209+ return Err ( vm. new_exception_msg (
210+ vm. ctx . exceptions . assertion_error . to_owned ( ) ,
211+ "attempt to release recursive lock not owned by thread" . to_owned ( ) ,
212+ ) ) ;
213+ }
214+ if self . count . load ( Ordering :: Acquire ) > 1 {
215+ self . count . fetch_sub ( 1 , Ordering :: Release ) ;
216+ return Ok ( ( ) ) ;
217+ }
218+ }
219+
220+ if unsafe { ReleaseSemaphore ( self . handle . as_raw ( ) , 1 , core:: ptr:: null_mut ( ) ) } == 0 {
221+ let err = unsafe { windows_sys:: Win32 :: Foundation :: GetLastError ( ) } ;
222+ if err == ERROR_TOO_MANY_POSTS {
223+ return Err (
224+ vm. new_value_error ( "semaphore or lock released too many times" . to_owned ( ) )
225+ ) ;
226+ }
227+ return Err ( vm. new_last_os_error ( ) ) ;
228+ }
229+
230+ self . count . fetch_sub ( 1 , Ordering :: Release ) ;
231+ Ok ( ( ) )
232+ }
233+
234+ #[ pymethod( name = "__enter__" ) ]
235+ fn enter ( & self , vm : & VirtualMachine ) -> PyResult < bool > {
236+ self . acquire (
237+ FuncArgs :: new :: < Vec < _ > , KwArgs > (
238+ vec ! [ vm. ctx. new_bool( true ) . into( ) ] ,
239+ KwArgs :: default ( ) ,
240+ ) ,
241+ vm,
242+ )
243+ }
244+
245+ #[ pymethod]
246+ fn __exit__ ( & self , _args : FuncArgs , vm : & VirtualMachine ) -> PyResult < ( ) > {
247+ self . release ( vm)
248+ }
249+
250+ #[ pyclassmethod( name = "_rebuild" ) ]
251+ fn rebuild (
252+ cls : PyTypeRef ,
253+ handle : isize ,
254+ kind : i32 ,
255+ maxvalue : i32 ,
256+ name : Option < String > ,
257+ vm : & VirtualMachine ,
258+ ) -> PyResult {
259+ // On Windows, _rebuild receives the handle directly (no sem_open)
260+ let zelf = SemLock {
261+ handle : SemHandle {
262+ raw : handle as HANDLE ,
263+ } ,
264+ kind,
265+ maxvalue,
266+ name,
267+ last_tid : AtomicU32 :: new ( 0 ) ,
268+ count : AtomicI32 :: new ( 0 ) ,
269+ } ;
270+ zelf. into_ref_with_type ( vm, cls) . map ( Into :: into)
271+ }
272+
273+ #[ pymethod]
274+ fn _after_fork ( & self ) {
275+ self . count . store ( 0 , Ordering :: Release ) ;
276+ self . last_tid . store ( 0 , Ordering :: Release ) ;
277+ }
278+
279+ #[ pymethod]
280+ fn __reduce__ ( & self , vm : & VirtualMachine ) -> PyResult {
281+ Err ( vm. new_type_error ( "cannot pickle 'SemLock' object" . to_owned ( ) ) )
282+ }
283+
284+ #[ pymethod]
285+ fn _count ( & self ) -> i32 {
286+ self . count . load ( Ordering :: Acquire )
287+ }
288+
289+ #[ pymethod]
290+ fn _is_mine ( & self ) -> bool {
291+ ismine ! ( self )
292+ }
293+
294+ #[ pymethod]
295+ fn _get_value ( & self , vm : & VirtualMachine ) -> PyResult < i32 > {
296+ get_semaphore_value ( self . handle . as_raw ( ) ) . map_err ( |_| vm. new_last_os_error ( ) )
297+ }
298+
299+ #[ pymethod]
300+ fn _is_zero ( & self , vm : & VirtualMachine ) -> PyResult < bool > {
301+ let val =
302+ get_semaphore_value ( self . handle . as_raw ( ) ) . map_err ( |_| vm. new_last_os_error ( ) ) ?;
303+ Ok ( val == 0 )
304+ }
305+
306+ #[ extend_class]
307+ fn extend_class ( ctx : & Context , class : & Py < PyType > ) {
308+ class. set_attr (
309+ ctx. intern_str ( "RECURSIVE_MUTEX" ) ,
310+ ctx. new_int ( RECURSIVE_MUTEX ) . into ( ) ,
311+ ) ;
312+ class. set_attr ( ctx. intern_str ( "SEMAPHORE" ) , ctx. new_int ( SEMAPHORE ) . into ( ) ) ;
313+ class. set_attr (
314+ ctx. intern_str ( "SEM_VALUE_MAX" ) ,
315+ ctx. new_int ( i32:: MAX ) . into ( ) ,
316+ ) ;
317+ }
318+ }
319+
320+ impl Constructor for SemLock {
321+ type Args = SemLockNewArgs ;
322+
323+ fn py_new ( _cls : & Py < PyType > , args : Self :: Args , vm : & VirtualMachine ) -> PyResult < Self > {
324+ if args. kind != RECURSIVE_MUTEX && args. kind != SEMAPHORE {
325+ return Err ( vm. new_value_error ( "unrecognized kind" . to_owned ( ) ) ) ;
326+ }
327+ if args. value < 0 || args. value > args. maxvalue {
328+ return Err ( vm. new_value_error ( "invalid value" . to_owned ( ) ) ) ;
329+ }
330+
331+ let handle = SemHandle :: create ( args. value , args. maxvalue , vm) ?;
332+ let name = if args. unlink { None } else { Some ( args. name ) } ;
333+
334+ Ok ( SemLock {
335+ handle,
336+ kind : args. kind ,
337+ maxvalue : args. maxvalue ,
338+ name,
339+ last_tid : AtomicU32 :: new ( 0 ) ,
340+ count : AtomicI32 :: new ( 0 ) ,
341+ } )
342+ }
343+ }
344+
345+ // On Windows, sem_unlink is a no-op
346+ #[ pyfunction]
347+ fn sem_unlink ( _name : String ) { }
348+
349+ #[ pyattr]
350+ fn flags ( vm : & VirtualMachine ) -> PyRef < PyDict > {
351+ // On Windows, no HAVE_SEM_OPEN / HAVE_SEM_TIMEDWAIT / HAVE_BROKEN_SEM_GETVALUE
352+ vm. ctx . new_dict ( )
353+ }
8354
9355 #[ pyfunction]
10356 fn closesocket ( socket : usize , vm : & VirtualMachine ) -> PyResult < ( ) > {
0 commit comments