@@ -325,12 +325,59 @@ mod decl {
325325 pub ( super ) mod poll {
326326 use super :: * ;
327327 use crate :: vm:: {
328- builtins:: PyFloat , common:: lock:: PyMutex , convert:: ToPyObject , function:: OptionalArg ,
329- stdlib:: io:: Fildes , AsObject , PyPayload ,
328+ builtins:: PyFloat ,
329+ common:: lock:: PyMutex ,
330+ convert:: { IntoPyException , ToPyObject } ,
331+ function:: OptionalArg ,
332+ stdlib:: io:: Fildes ,
333+ AsObject , PyPayload ,
330334 } ;
331335 use libc:: pollfd;
332- use num_traits:: ToPrimitive ;
333- use std:: time;
336+ use num_traits:: { Signed , ToPrimitive } ;
337+ use std:: time:: { Duration , Instant } ;
338+
339+ #[ derive( Default ) ]
340+ pub ( super ) struct TimeoutArg < const MILLIS : bool > ( pub Option < Duration > ) ;
341+
342+ impl < const MILLIS : bool > TryFromObject for TimeoutArg < MILLIS > {
343+ fn try_from_object ( vm : & VirtualMachine , obj : PyObjectRef ) -> PyResult < Self > {
344+ let timeout = if vm. is_none ( & obj) {
345+ None
346+ } else if let Some ( float) = obj. payload :: < PyFloat > ( ) {
347+ let float = float. to_f64 ( ) ;
348+ if float. is_nan ( ) {
349+ return Err (
350+ vm. new_value_error ( "Invalid value NaN (not a number)" . to_owned ( ) )
351+ ) ;
352+ }
353+ if float. is_sign_negative ( ) {
354+ None
355+ } else {
356+ let secs = if MILLIS { float * 1000.0 } else { float } ;
357+ Some ( Duration :: from_secs_f64 ( secs) )
358+ }
359+ } else if let Some ( int) = obj. try_index_opt ( vm) . transpose ( ) ? {
360+ if int. as_bigint ( ) . is_negative ( ) {
361+ None
362+ } else {
363+ let n = int. as_bigint ( ) . to_u64 ( ) . ok_or_else ( || {
364+ vm. new_overflow_error ( "value out of range" . to_owned ( ) )
365+ } ) ?;
366+ Some ( if MILLIS {
367+ Duration :: from_millis ( n)
368+ } else {
369+ Duration :: from_secs ( n)
370+ } )
371+ }
372+ } else {
373+ return Err ( vm. new_type_error ( format ! (
374+ "expected an int or float for duration, got {}" ,
375+ obj. class( )
376+ ) ) ) ;
377+ } ;
378+ Ok ( Self ( timeout) )
379+ }
380+ }
334381
335382 #[ pyclass( module = "select" , name = "poll" ) ]
336383 #[ derive( Default , Debug , PyPayload ) ]
@@ -399,50 +446,31 @@ mod decl {
399446 #[ pymethod]
400447 fn poll (
401448 & self ,
402- timeout : OptionalOption ,
449+ timeout : OptionalArg < TimeoutArg < true > > ,
403450 vm : & VirtualMachine ,
404451 ) -> PyResult < Vec < PyObjectRef > > {
405452 let mut fds = self . fds . lock ( ) ;
406- let timeout_ms = match timeout. flatten ( ) {
407- Some ( ms) => {
408- let ms = if let Some ( float) = ms. payload :: < PyFloat > ( ) {
409- float. to_f64 ( ) . to_i32 ( )
410- } else if let Some ( int) = ms. try_index_opt ( vm) {
411- int?. as_bigint ( ) . to_i32 ( )
412- } else {
413- return Err ( vm. new_type_error ( format ! (
414- "expected an int or float for duration, got {}" ,
415- ms. class( )
416- ) ) ) ;
417- } ;
418- ms. ok_or_else ( || vm. new_value_error ( "value out of range" . to_owned ( ) ) ) ?
419- }
420- None => -1 ,
453+ let TimeoutArg ( timeout) = timeout. unwrap_or_default ( ) ;
454+ let timeout_ms = match timeout {
455+ Some ( d) => i32:: try_from ( d. as_millis ( ) )
456+ . map_err ( |_| vm. new_overflow_error ( "value out of range" . to_owned ( ) ) ) ?,
457+ None => -1i32 ,
421458 } ;
422- let timeout_ms = if timeout_ms < 0 { -1 } else { timeout_ms } ;
423- let deadline = ( timeout_ms >= 0 )
424- . then ( || time:: Instant :: now ( ) + time:: Duration :: from_millis ( timeout_ms as u64 ) ) ;
459+ let deadline = timeout. map ( |d| Instant :: now ( ) + d) ;
425460 let mut poll_timeout = timeout_ms;
426461 loop {
427462 let res = unsafe { libc:: poll ( fds. as_mut_ptr ( ) , fds. len ( ) as _ , poll_timeout) } ;
428- let res = if res < 0 {
429- Err ( io:: Error :: last_os_error ( ) )
430- } else {
431- Ok ( ( ) )
432- } ;
433- match res {
434- Ok ( ( ) ) => break ,
435- Err ( e) if e. kind ( ) == io:: ErrorKind :: Interrupted => {
436- vm. check_signals ( ) ?;
437- if let Some ( d) = deadline {
438- match d. checked_duration_since ( time:: Instant :: now ( ) ) {
439- Some ( remaining) => poll_timeout = remaining. as_millis ( ) as i32 ,
440- // we've timed out
441- None => break ,
442- }
443- }
463+ match nix:: Error :: result ( res) {
464+ Ok ( _) => break ,
465+ Err ( nix:: Error :: EINTR ) => vm. check_signals ( ) ?,
466+ Err ( e) => return Err ( e. into_pyexception ( vm) ) ,
467+ }
468+ if let Some ( d) = deadline {
469+ if let Some ( remaining) = d. checked_duration_since ( Instant :: now ( ) ) {
470+ poll_timeout = remaining. as_millis ( ) as i32 ;
471+ } else {
472+ break ;
444473 }
445- Err ( e) => return Err ( e. to_pyexception ( vm) ) ,
446474 }
447475 }
448476 Ok ( fds
@@ -453,4 +481,216 @@ mod decl {
453481 }
454482 }
455483 }
484+
485+ #[ cfg( any( target_os = "linux" , target_os = "android" , target_os = "redox" ) ) ]
486+ #[ pyattr( name = "epoll" , once) ]
487+ fn epoll ( vm : & VirtualMachine ) -> PyTypeRef {
488+ use crate :: vm:: class:: PyClassImpl ;
489+ epoll:: PyEpoll :: make_class ( & vm. ctx )
490+ }
491+
492+ #[ cfg( any( target_os = "linux" , target_os = "android" , target_os = "redox" ) ) ]
493+ #[ pyattr]
494+ use libc:: {
495+ EPOLLERR , EPOLLEXCLUSIVE , EPOLLHUP , EPOLLIN , EPOLLMSG , EPOLLONESHOT , EPOLLOUT , EPOLLPRI ,
496+ EPOLLRDBAND , EPOLLRDHUP , EPOLLRDNORM , EPOLLWAKEUP , EPOLLWRBAND , EPOLLWRNORM , EPOLL_CLOEXEC ,
497+ } ;
498+ #[ cfg( any( target_os = "linux" , target_os = "android" , target_os = "redox" ) ) ]
499+ #[ pyattr]
500+ const EPOLLET : u32 = libc:: EPOLLET as u32 ;
501+
502+ #[ cfg( any( target_os = "linux" , target_os = "android" , target_os = "redox" ) ) ]
503+ pub ( super ) mod epoll {
504+ use super :: * ;
505+ use crate :: vm:: {
506+ builtins:: PyTypeRef ,
507+ common:: lock:: { PyRwLock , PyRwLockReadGuard } ,
508+ convert:: { IntoPyException , ToPyObject } ,
509+ function:: OptionalArg ,
510+ stdlib:: io:: Fildes ,
511+ types:: Constructor ,
512+ PyPayload ,
513+ } ;
514+ use rustix:: event:: epoll:: { self , EventData , EventFlags } ;
515+ use std:: ops:: Deref ;
516+ use std:: os:: fd:: { AsRawFd , IntoRawFd , OwnedFd } ;
517+ use std:: time:: { Duration , Instant } ;
518+
519+ #[ pyclass( module = "select" , name = "epoll" ) ]
520+ #[ derive( Debug , rustpython_vm:: PyPayload ) ]
521+ pub struct PyEpoll {
522+ epoll_fd : PyRwLock < Option < OwnedFd > > ,
523+ }
524+
525+ #[ derive( FromArgs ) ]
526+ pub struct EpollNewArgs {
527+ #[ pyarg( any, default = "-1" ) ]
528+ sizehint : i32 ,
529+ #[ pyarg( any, default = "0" ) ]
530+ flags : i32 ,
531+ }
532+
533+ impl Constructor for PyEpoll {
534+ type Args = EpollNewArgs ;
535+ fn py_new ( cls : PyTypeRef , args : EpollNewArgs , vm : & VirtualMachine ) -> PyResult {
536+ if let ..=-2 | 0 = args. sizehint {
537+ return Err ( vm. new_value_error ( "negative sizehint" . to_owned ( ) ) ) ;
538+ }
539+ if !matches ! ( args. flags, 0 | libc:: EPOLL_CLOEXEC ) {
540+ return Err ( vm. new_os_error ( "invalid flags" . to_owned ( ) ) ) ;
541+ }
542+ Self :: new ( )
543+ . map_err ( |e| e. into_pyexception ( vm) ) ?
544+ . into_ref_with_type ( vm, cls)
545+ . map ( Into :: into)
546+ }
547+ }
548+
549+ #[ derive( FromArgs ) ]
550+ struct EpollPollArgs {
551+ #[ pyarg( any, default ) ]
552+ timeout : poll:: TimeoutArg < false > ,
553+ #[ pyarg( any, default = "-1" ) ]
554+ maxevents : i32 ,
555+ }
556+
557+ #[ pyclass( with( Constructor ) ) ]
558+ impl PyEpoll {
559+ fn new ( ) -> std:: io:: Result < Self > {
560+ let epoll_fd = epoll:: create ( epoll:: CreateFlags :: CLOEXEC ) ?;
561+ let epoll_fd = Some ( epoll_fd) . into ( ) ;
562+ Ok ( PyEpoll { epoll_fd } )
563+ }
564+
565+ #[ pymethod]
566+ fn close ( & self ) -> std:: io:: Result < ( ) > {
567+ let fd = self . epoll_fd . write ( ) . take ( ) ;
568+ if let Some ( fd) = fd {
569+ nix:: unistd:: close ( fd. into_raw_fd ( ) ) ?;
570+ }
571+ Ok ( ( ) )
572+ }
573+
574+ #[ pygetset]
575+ fn closed ( & self ) -> bool {
576+ self . epoll_fd . read ( ) . is_none ( )
577+ }
578+
579+ fn get_epoll (
580+ & self ,
581+ vm : & VirtualMachine ,
582+ ) -> PyResult < impl Deref < Target = OwnedFd > + ' _ > {
583+ PyRwLockReadGuard :: try_map ( self . epoll_fd . read ( ) , |x| x. as_ref ( ) ) . map_err ( |_| {
584+ vm. new_value_error ( "I/O operation on closed epoll object" . to_owned ( ) )
585+ } )
586+ }
587+
588+ #[ pymethod]
589+ fn fileno ( & self , vm : & VirtualMachine ) -> PyResult < i32 > {
590+ self . get_epoll ( vm) . map ( |epoll_fd| epoll_fd. as_raw_fd ( ) )
591+ }
592+
593+ #[ pyclassmethod]
594+ fn fromfd ( cls : PyTypeRef , fd : OwnedFd , vm : & VirtualMachine ) -> PyResult < PyRef < Self > > {
595+ let epoll_fd = Some ( fd) . into ( ) ;
596+ Self { epoll_fd } . into_ref_with_type ( vm, cls)
597+ }
598+
599+ #[ pymethod]
600+ fn register (
601+ & self ,
602+ fd : Fildes ,
603+ eventmask : OptionalArg < u32 > ,
604+ vm : & VirtualMachine ,
605+ ) -> PyResult < ( ) > {
606+ let events = match eventmask {
607+ OptionalArg :: Present ( mask) => EventFlags :: from_bits_retain ( mask) ,
608+ OptionalArg :: Missing => EventFlags :: IN | EventFlags :: PRI | EventFlags :: OUT ,
609+ } ;
610+ let epoll_fd = & * self . get_epoll ( vm) ?;
611+ let data = EventData :: new_u64 ( fd. as_raw_fd ( ) as u64 ) ;
612+ epoll:: add ( epoll_fd, fd, data, events) . map_err ( |e| e. into_pyexception ( vm) )
613+ }
614+
615+ #[ pymethod]
616+ fn modify ( & self , fd : Fildes , eventmask : u32 , vm : & VirtualMachine ) -> PyResult < ( ) > {
617+ let events = EventFlags :: from_bits_retain ( eventmask) ;
618+ let epoll_fd = & * self . get_epoll ( vm) ?;
619+ let data = EventData :: new_u64 ( fd. as_raw_fd ( ) as u64 ) ;
620+ epoll:: modify ( epoll_fd, fd, data, events) . map_err ( |e| e. into_pyexception ( vm) )
621+ }
622+
623+ #[ pymethod]
624+ fn unregister ( & self , fd : Fildes , vm : & VirtualMachine ) -> PyResult < ( ) > {
625+ let epoll_fd = & * self . get_epoll ( vm) ?;
626+ epoll:: delete ( epoll_fd, fd) . map_err ( |e| e. into_pyexception ( vm) )
627+ }
628+
629+ #[ pymethod]
630+ fn poll ( & self , args : EpollPollArgs , vm : & VirtualMachine ) -> PyResult < PyListRef > {
631+ let poll:: TimeoutArg ( timeout) = args. timeout ;
632+ let maxevents = args. maxevents ;
633+
634+ let make_poll_timeout = |d : Duration | i32:: try_from ( d. as_millis ( ) ) ;
635+ let mut poll_timeout = match timeout {
636+ Some ( d) => make_poll_timeout ( d)
637+ . map_err ( |_| vm. new_overflow_error ( "timeout is too large" . to_owned ( ) ) ) ?,
638+ None => -1 ,
639+ } ;
640+
641+ let deadline = timeout. map ( |d| Instant :: now ( ) + d) ;
642+ let maxevents = match maxevents {
643+ ..-1 => {
644+ return Err ( vm. new_value_error ( format ! (
645+ "maxevents must be greater than 0, got {maxevents}"
646+ ) ) )
647+ }
648+ -1 => libc:: FD_SETSIZE - 1 ,
649+ _ => maxevents as usize ,
650+ } ;
651+
652+ let mut events = epoll:: EventVec :: with_capacity ( maxevents) ;
653+
654+ let epoll = & * self . get_epoll ( vm) ?;
655+
656+ loop {
657+ match epoll:: wait ( epoll, & mut events, poll_timeout) {
658+ Ok ( ( ) ) => break ,
659+ Err ( rustix:: io:: Errno :: INTR ) => vm. check_signals ( ) ?,
660+ Err ( e) => return Err ( e. into_pyexception ( vm) ) ,
661+ }
662+ if let Some ( deadline) = deadline {
663+ if let Some ( new_timeout) = deadline. checked_duration_since ( Instant :: now ( ) ) {
664+ poll_timeout = make_poll_timeout ( new_timeout) . unwrap ( ) ;
665+ } else {
666+ break ;
667+ }
668+ }
669+ }
670+
671+ let ret = events
672+ . iter ( )
673+ . map ( |ev| ( ev. data . u64 ( ) as i32 , { ev. flags } . bits ( ) ) . to_pyobject ( vm) )
674+ . collect ( ) ;
675+
676+ Ok ( vm. ctx . new_list ( ret) )
677+ }
678+
679+ #[ pymethod( magic) ]
680+ fn enter ( zelf : PyRef < Self > , vm : & VirtualMachine ) -> PyResult < PyRef < Self > > {
681+ zelf. get_epoll ( vm) ?;
682+ Ok ( zelf)
683+ }
684+
685+ #[ pymethod( magic) ]
686+ fn exit (
687+ & self ,
688+ _exc_type : OptionalArg ,
689+ _exc_value : OptionalArg ,
690+ _exc_tb : OptionalArg ,
691+ ) -> std:: io:: Result < ( ) > {
692+ self . close ( )
693+ }
694+ }
695+ }
456696}
0 commit comments