@@ -994,11 +994,40 @@ impl ExecutingFrame<'_> {
994994 }
995995 bytecode:: Instruction :: GetANext => {
996996 let aiter = self . top_value ( ) ;
997- let awaitable = vm. call_special_method ( aiter, identifier ! ( vm, __anext__) , ( ) ) ?;
998- let awaitable = if awaitable. payload_is :: < PyCoroutine > ( ) {
999- awaitable
997+ let awaitable = if aiter. class ( ) . is ( vm. ctx . types . async_generator ) {
998+ vm. call_special_method ( aiter, identifier ! ( vm, __anext__) , ( ) ) ?
1000999 } else {
1001- vm. call_special_method ( & awaitable, identifier ! ( vm, __await__) , ( ) ) ?
1000+ if !aiter. has_attr ( "__anext__" , vm) . unwrap_or ( false ) {
1001+ // TODO: __anext__ must be protocol
1002+ let msg = format ! (
1003+ "'async for' requires an iterator with __anext__ method, got {:.100}" ,
1004+ aiter. class( ) . name( )
1005+ ) ;
1006+ return Err ( vm. new_type_error ( msg) ) ;
1007+ }
1008+ let next_iter =
1009+ vm. call_special_method ( aiter, identifier ! ( vm, __anext__) , ( ) ) ?;
1010+
1011+ // _PyCoro_GetAwaitableIter in CPython
1012+ fn get_awaitable_iter ( next_iter : & PyObject , vm : & VirtualMachine ) -> PyResult {
1013+ let gen_is_coroutine = |_| {
1014+ // TODO: cpython gen_is_coroutine
1015+ true
1016+ } ;
1017+ if next_iter. class ( ) . is ( vm. ctx . types . coroutine_type )
1018+ || gen_is_coroutine ( next_iter)
1019+ {
1020+ return Ok ( next_iter. to_owned ( ) ) ;
1021+ }
1022+ // TODO: error handling
1023+ vm. call_special_method ( next_iter, identifier ! ( vm, __await__) , ( ) )
1024+ }
1025+ get_awaitable_iter ( & next_iter, vm) . map_err ( |_| {
1026+ vm. new_type_error ( format ! (
1027+ "'async for' received an invalid object from __anext__: {:.200}" ,
1028+ next_iter. class( ) . name( )
1029+ ) )
1030+ } ) ?
10021031 } ;
10031032 self . push_value ( awaitable) ;
10041033 Ok ( None )
0 commit comments