33pub ( crate ) use _socket:: module_def;
44
55#[ cfg( feature = "ssl" ) ]
6- pub ( super ) use _socket:: { PySocket , SelectKind , sock_select , timeout_error_msg} ;
6+ pub ( super ) use _socket:: { PySocket , SockWaitKind , sock_wait , timeout_error_msg} ;
77
88#[ pymodule]
99mod _socket {
@@ -1060,20 +1060,20 @@ mod _socket {
10601060 fn sock_op < F , R > (
10611061 & self ,
10621062 vm : & VirtualMachine ,
1063- select : SelectKind ,
1063+ wait_kind : SockWaitKind ,
10641064 f : F ,
10651065 ) -> Result < R , IoOrPyException >
10661066 where
10671067 F : FnMut ( ) -> io:: Result < R > ,
10681068 {
10691069 let timeout = self . get_timeout ( ) . ok ( ) ;
1070- self . sock_op_timeout_err ( vm, select , timeout, f)
1070+ self . sock_op_timeout_err ( vm, wait_kind , timeout, f)
10711071 }
10721072
10731073 fn sock_op_timeout_err < F , R > (
10741074 & self ,
10751075 vm : & VirtualMachine ,
1076- select : SelectKind ,
1076+ wait_kind : SockWaitKind ,
10771077 timeout : Option < Duration > ,
10781078 mut f : F ,
10791079 ) -> Result < R , IoOrPyException >
@@ -1083,19 +1083,9 @@ mod _socket {
10831083 let deadline = timeout. map ( Deadline :: new) ;
10841084
10851085 loop {
1086- if deadline. is_some ( ) || matches ! ( select, SelectKind :: Connect ) {
1087- let interval = deadline. as_ref ( ) . map ( |d| d. time_until ( ) ) . transpose ( ) ?;
1086+ if deadline. is_some ( ) || matches ! ( wait_kind, SockWaitKind :: Connect ) {
10881087 let sock = self . sock ( ) ?;
1089- let res = vm. allow_threads ( || sock_select ( & sock, select, interval) ) ;
1090- match res {
1091- Ok ( true ) => return Err ( IoOrPyException :: Timeout ) ,
1092- Err ( e) if e. kind ( ) == io:: ErrorKind :: Interrupted => {
1093- vm. check_signals ( ) ?;
1094- continue ;
1095- }
1096- Err ( e) => return Err ( e. into ( ) ) ,
1097- Ok ( false ) => { } // no timeout, continue as normal
1098- }
1088+ sock_wait_deadline ( & sock, wait_kind, & deadline, vm) ?;
10991089 }
11001090
11011091 let err = loop {
@@ -1339,16 +1329,11 @@ mod _socket {
13391329 } ;
13401330
13411331 if wait_connect {
1342- // basically, connect() is async, and it registers an "error" on the socket when it's
1343- // done connecting. SelectKind::Connect fills the errorfds fd_set, so if we wake up
1344- // from poll and the error is EISCONN then we know that the connect is done
1345- self . sock_op ( vm, SelectKind :: Connect , || {
1332+ self . sock_op ( vm, SockWaitKind :: Connect , || {
13461333 let sock = self . sock ( ) ?;
13471334 let err = sock. take_error ( ) ?;
13481335 match err {
1349- Some ( e) if e. posix_errno ( ) == libc:: EISCONN => Ok ( ( ) ) ,
13501336 Some ( e) => Err ( e) ,
1351- // TODO: is this accurate?
13521337 None => Ok ( ( ) ) ,
13531338 }
13541339 } )
@@ -1587,7 +1572,8 @@ mod _socket {
15871572 ) -> Result < ( RawSocket , PyObjectRef ) , IoOrPyException > {
15881573 // Use accept_raw() instead of accept() to avoid socket2's set_common_flags()
15891574 // which tries to set SO_NOSIGPIPE and fails with EINVAL on Unix domain sockets on macOS
1590- let ( sock, addr) = self . sock_op ( vm, SelectKind :: Read , || self . sock ( ) ?. accept_raw ( ) ) ?;
1575+ let ( sock, addr) =
1576+ self . sock_op ( vm, SockWaitKind :: Read , || self . sock ( ) ?. accept_raw ( ) ) ?;
15911577 let fd = into_sock_fileno ( sock) ;
15921578 Ok ( ( fd, get_addr_tuple ( & addr, vm) ) )
15931579 }
@@ -1602,7 +1588,7 @@ mod _socket {
16021588 let flags = flags. unwrap_or ( 0 ) ;
16031589 let mut buffer = Vec :: with_capacity ( bufsize) ;
16041590 let sock = self . sock ( ) ?;
1605- let n = self . sock_op ( vm, SelectKind :: Read , || {
1591+ let n = self . sock_op ( vm, SockWaitKind :: Read , || {
16061592 sock. recv_with_flags ( buffer. spare_capacity_mut ( ) , flags)
16071593 } ) ?;
16081594 unsafe { buffer. set_len ( n) } ;
@@ -1633,7 +1619,7 @@ mod _socket {
16331619 } ;
16341620
16351621 let buf = & mut buf[ ..read_len] ;
1636- self . sock_op ( vm, SelectKind :: Read , || {
1622+ self . sock_op ( vm, SockWaitKind :: Read , || {
16371623 sock. recv_with_flags ( unsafe { slice_as_uninit ( buf) } , flags)
16381624 } )
16391625 }
@@ -1650,7 +1636,7 @@ mod _socket {
16501636 . to_usize ( )
16511637 . ok_or_else ( || vm. new_value_error ( "negative buffersize in recvfrom" ) ) ?;
16521638 let mut buffer = Vec :: with_capacity ( bufsize) ;
1653- let ( n, addr) = self . sock_op ( vm, SelectKind :: Read , || {
1639+ let ( n, addr) = self . sock_op ( vm, SockWaitKind :: Read , || {
16541640 self . sock ( ) ?
16551641 . recv_from_with_flags ( buffer. spare_capacity_mut ( ) , flags)
16561642 } ) ?;
@@ -1681,7 +1667,7 @@ mod _socket {
16811667 } ;
16821668 let flags = flags. unwrap_or ( 0 ) ;
16831669 let sock = self . sock ( ) ?;
1684- let ( n, addr) = self . sock_op ( vm, SelectKind :: Read , || {
1670+ let ( n, addr) = self . sock_op ( vm, SockWaitKind :: Read , || {
16851671 sock. recv_from_with_flags ( unsafe { slice_as_uninit ( buf) } , flags)
16861672 } ) ?;
16871673 Ok ( ( n, get_addr_tuple ( & addr, vm) ) )
@@ -1697,7 +1683,7 @@ mod _socket {
16971683 let flags = flags. unwrap_or ( 0 ) ;
16981684 let buf = bytes. borrow_buf ( ) ;
16991685 let buf = & * buf;
1700- self . sock_op ( vm, SelectKind :: Write , || {
1686+ self . sock_op ( vm, SockWaitKind :: Write , || {
17011687 self . sock ( ) ?. send_with_flags ( buf, flags)
17021688 } )
17031689 }
@@ -1721,7 +1707,7 @@ mod _socket {
17211707 // now we have like 3 layers of interrupt loop :)
17221708 while buf_offset < buf. len ( ) {
17231709 let interval = deadline. as_ref ( ) . map ( |d| d. time_until ( ) ) . transpose ( ) ?;
1724- self . sock_op_timeout_err ( vm, SelectKind :: Write , interval, || {
1710+ self . sock_op_timeout_err ( vm, SockWaitKind :: Write , interval, || {
17251711 let subbuf = & buf[ buf_offset..] ;
17261712 buf_offset += self . sock ( ) ?. send_with_flags ( subbuf, flags) ?;
17271713 Ok ( ( ) )
@@ -1754,7 +1740,7 @@ mod _socket {
17541740 let addr = self . extract_address ( address, "sendto" , vm) ?;
17551741 let buf = bytes. borrow_buf ( ) ;
17561742 let buf = & * buf;
1757- self . sock_op ( vm, SelectKind :: Write , || {
1743+ self . sock_op ( vm, SockWaitKind :: Write , || {
17581744 self . sock ( ) ?. send_to_with_flags ( buf, & addr, flags)
17591745 } )
17601746 }
@@ -1812,7 +1798,7 @@ mod _socket {
18121798 }
18131799 }
18141800
1815- self . sock_op ( vm, SelectKind :: Write , || {
1801+ self . sock_op ( vm, SockWaitKind :: Write , || {
18161802 let sock = self . sock ( ) ?;
18171803 sock. sendmsg ( & msg, flags)
18181804 } )
@@ -1848,7 +1834,7 @@ mod _socket {
18481834 . collect :: < Vec < _ > > ( ) ;
18491835 let iv = iv. map ( |iv| iv. borrow_buf ( ) . to_vec ( ) ) ;
18501836
1851- self . sock_op ( vm, SelectKind :: Write , || {
1837+ self . sock_op ( vm, SockWaitKind :: Write , || {
18521838 let sock = self . sock ( ) ?;
18531839 let fd = unsafe { BorrowedFd :: borrow_raw ( sock_fileno ( & sock) ) } ;
18541840 host_socket:: sendmsg_afalg ( fd, & buffers, op, iv. as_deref ( ) , assoclen, flags)
@@ -1881,7 +1867,7 @@ mod _socket {
18811867 let flags = flags. unwrap_or ( 0 ) ;
18821868
18831869 let msg = self
1884- . sock_op ( vm, SelectKind :: Read , || {
1870+ . sock_op ( vm, SockWaitKind :: Read , || {
18851871 let sock = self . sock ( ) ?;
18861872 let fd = unsafe { std:: os:: fd:: BorrowedFd :: borrow_raw ( sock_fileno ( & sock) ) } ;
18871873 host_socket:: recvmsg ( fd, bufsize, ancbufsize, flags)
@@ -2436,61 +2422,135 @@ mod _socket {
24362422 }
24372423
24382424 #[ derive( Copy , Clone ) ]
2439- pub ( crate ) enum SelectKind {
2425+ pub ( crate ) enum SockWaitKind {
24402426 Read ,
24412427 Write ,
24422428 Connect ,
24432429 }
24442430
2445- /// returns true if timed out
2446- pub ( crate ) fn sock_select (
2431+ /// returns Ok(true) on timeout
2432+ pub ( crate ) fn sock_wait (
2433+ sock : & Socket ,
2434+ wait_kind : SockWaitKind ,
2435+ timeout : Option < Duration > ,
2436+ vm : & VirtualMachine ,
2437+ ) -> PyResult < bool > {
2438+ match sock_wait_deadline ( sock, wait_kind, & timeout. map ( Deadline :: new) , vm) {
2439+ Ok ( ( ) ) => Ok ( false ) ,
2440+ Err ( IoOrPyException :: Timeout ) => Ok ( true ) ,
2441+ Err ( e) => Err ( e. into_pyexception ( vm) ) ,
2442+ }
2443+ }
2444+
2445+ /// returns Err(IoOrPyException::Timeout) on timeout
2446+ fn sock_wait_deadline (
24472447 sock : & Socket ,
2448- kind : SelectKind ,
2449- interval : Option < Duration > ,
2450- ) -> io:: Result < bool > {
2448+ wait_kind : SockWaitKind ,
2449+ deadline : & Option < Deadline > ,
2450+ vm : & VirtualMachine ,
2451+ ) -> Result < ( ) , IoOrPyException > {
24512452 #[ cfg( unix) ]
24522453 {
2453- use std:: os:: fd:: AsFd ;
2454- let kind = match kind {
2455- SelectKind :: Read => host_socket:: PollKind :: Read ,
2456- SelectKind :: Write => host_socket:: PollKind :: Write ,
2457- SelectKind :: Connect => host_socket:: PollKind :: Connect ,
2458- } ;
2459- host_socket:: poll_socket ( sock. as_fd ( ) , kind, interval)
2454+ use rustpython_host_env:: select:: { PollFd , poll_fds} ;
2455+
2456+ let mut events = 0 ;
2457+ if matches ! ( wait_kind, SockWaitKind :: Read ) {
2458+ events |= libc:: POLLIN | libc:: POLLPRI ;
2459+ }
2460+ if matches ! ( wait_kind, SockWaitKind :: Write | SockWaitKind :: Connect ) {
2461+ events |= libc:: POLLOUT ;
2462+ }
2463+ let mut fds = [ PollFd {
2464+ fd : sock_fileno ( sock) ,
2465+ events,
2466+ revents : 0 ,
2467+ } ; 1 ] ;
2468+
2469+ loop {
2470+ let ( timeout, is_capped) = deadline
2471+ . as_ref ( )
2472+ . map ( |d| {
2473+ d. time_until ( ) . map ( |t| {
2474+ let timeout_ms = t. as_millis ( ) ;
2475+ let is_capped = timeout_ms > i32:: MAX as u128 ;
2476+ let timeout = if is_capped {
2477+ i32:: MAX
2478+ } else {
2479+ timeout_ms as i32
2480+ } ;
2481+ ( timeout, is_capped)
2482+ } )
2483+ } )
2484+ . transpose ( ) ?
2485+ . unwrap_or ( ( -1 , false ) ) ;
2486+
2487+ match vm. allow_threads ( || poll_fds ( & mut fds, timeout) ) {
2488+ Ok ( 0 ) => {
2489+ if is_capped {
2490+ continue ;
2491+ }
2492+ break Err ( IoOrPyException :: Timeout ) ;
2493+ }
2494+
2495+ Ok ( _) => {
2496+ if fds[ 0 ] . revents & libc:: POLLNVAL != 0 {
2497+ break Err ( io:: Error :: from_raw_os_error ( libc:: EBADF ) . into ( ) ) ;
2498+ }
2499+ break Ok ( ( ) ) ;
2500+ }
2501+
2502+ Err ( e) => {
2503+ if e. kind ( ) == io:: ErrorKind :: Interrupted {
2504+ vm. check_signals ( ) ?;
2505+ continue ;
2506+ }
2507+ break Err ( e. into ( ) ) ;
2508+ }
2509+ }
2510+ }
24602511 }
24612512 #[ cfg( windows) ]
24622513 {
2463- use rustpython_host_env:: select as host_select ;
2514+ use rustpython_host_env:: select:: { FdSet , select , timeval } ;
24642515
2465- let fd = sock_fileno ( sock) ;
2516+ let fd = sock_fileno ( sock) as usize ;
24662517
2467- let mut reads = host_select :: FdSet :: new ( ) ;
2468- let mut writes = host_select :: FdSet :: new ( ) ;
2469- let mut errs = host_select :: FdSet :: new ( ) ;
2518+ let mut reads = FdSet :: new ( ) ;
2519+ let mut writes = FdSet :: new ( ) ;
2520+ let mut errs = FdSet :: new ( ) ;
24702521
2471- let fd = fd as usize ;
2472- match kind {
2473- SelectKind :: Read => reads. insert ( fd) ,
2474- SelectKind :: Write => writes. insert ( fd) ,
2475- SelectKind :: Connect => {
2476- writes. insert ( fd) ;
2477- errs. insert ( fd) ;
2478- }
2522+ if matches ! ( wait_kind, SockWaitKind :: Read ) {
2523+ reads. insert ( fd) ;
2524+ errs. insert ( fd) ;
2525+ }
2526+ if matches ! ( wait_kind, SockWaitKind :: Write | SockWaitKind :: Connect ) {
2527+ writes. insert ( fd) ;
2528+ errs. insert ( fd) ;
24792529 }
24802530
2481- let mut interval = interval. map ( |dur| host_select:: timeval {
2482- tv_sec : dur. as_secs ( ) as _ ,
2483- tv_usec : dur. subsec_micros ( ) as _ ,
2484- } ) ;
2485-
2486- host_select:: select (
2487- fd as i32 + 1 ,
2488- & mut reads,
2489- & mut writes,
2490- & mut errs,
2491- interval. as_mut ( ) ,
2492- )
2493- . map ( |ret| ret == 0 )
2531+ let mut timeout = deadline
2532+ . as_ref ( )
2533+ . map ( |d| {
2534+ d. time_until ( ) . map ( |dur| timeval {
2535+ tv_sec : dur. as_secs ( ) as _ ,
2536+ tv_usec : dur. subsec_micros ( ) as _ ,
2537+ } )
2538+ } )
2539+ . transpose ( ) ?;
2540+
2541+ match vm. allow_threads ( || {
2542+ select (
2543+ 0 , // nfds is ignored on windows
2544+ & mut reads,
2545+ & mut writes,
2546+ & mut errs,
2547+ timeout. as_mut ( ) ,
2548+ )
2549+ } ) {
2550+ Ok ( 0 ) => Err ( IoOrPyException :: Timeout ) ,
2551+ Ok ( _) => Ok ( ( ) ) ,
2552+ Err ( e) => Err ( e. into ( ) ) ,
2553+ }
24942554 }
24952555 }
24962556
0 commit comments