|
4 | 4 | use crate::{ |
5 | 5 | AsObject, Py, PyObject, PyObjectRef, PyRef, PyResult, TryFromObject, VirtualMachine, |
6 | 6 | builtins::{ |
7 | | - PyAsyncGen, PyBytes, PyDict, PyDictRef, PyGenericAlias, PyInt, PyList, PyStr, PyTuple, |
8 | | - PyTupleRef, PyType, PyTypeRef, PyUtf8Str, pystr::AsPyStr, |
| 7 | + PyBytes, PyDict, PyDictRef, PyGenericAlias, PyInt, PyList, PyStr, PyTuple, PyTupleRef, |
| 8 | + PyType, PyTypeRef, PyUtf8Str, pystr::AsPyStr, |
9 | 9 | }, |
10 | 10 | common::{hash::PyHash, str::to_ascii}, |
11 | 11 | convert::{ToPyObject, ToPyResult}, |
@@ -87,11 +87,37 @@ impl PyObject { |
87 | 87 |
|
88 | 88 | // PyObject *PyObject_GetAIter(PyObject *o) |
89 | 89 | pub fn get_aiter(&self, vm: &VirtualMachine) -> PyResult { |
90 | | - if self.downcastable::<PyAsyncGen>() { |
91 | | - vm.call_special_method(self, identifier!(vm, __aiter__), ()) |
92 | | - } else { |
93 | | - Err(vm.new_type_error("wrong argument type")) |
| 90 | + use crate::builtins::PyCoroutine; |
| 91 | + |
| 92 | + // Check if object has __aiter__ method |
| 93 | + let aiter_method = self.class().get_attr(identifier!(vm, __aiter__)); |
| 94 | + let Some(_aiter_method) = aiter_method else { |
| 95 | + return Err(vm.new_type_error(format!( |
| 96 | + "'{}' object is not an async iterable", |
| 97 | + self.class().name() |
| 98 | + ))); |
| 99 | + }; |
| 100 | + |
| 101 | + // Call __aiter__ |
| 102 | + let iterator = vm.call_special_method(self, identifier!(vm, __aiter__), ())?; |
| 103 | + |
| 104 | + // Check that __aiter__ did not return a coroutine |
| 105 | + if iterator.downcast_ref::<PyCoroutine>().is_some() { |
| 106 | + return Err(vm.new_type_error( |
| 107 | + "'async_iterator' object cannot be interpreted as an async iterable; \ |
| 108 | + perhaps you forgot to call aiter()?", |
| 109 | + )); |
| 110 | + } |
| 111 | + |
| 112 | + // Check that the result is an async iterator (has __anext__) |
| 113 | + if !iterator.class().has_attr(identifier!(vm, __anext__)) { |
| 114 | + return Err(vm.new_type_error(format!( |
| 115 | + "'{}' object is not an async iterator", |
| 116 | + iterator.class().name() |
| 117 | + ))); |
94 | 118 | } |
| 119 | + |
| 120 | + Ok(iterator) |
95 | 121 | } |
96 | 122 |
|
97 | 123 | pub fn has_attr<'a>(&self, attr_name: impl AsPyStr<'a>, vm: &VirtualMachine) -> PyResult<bool> { |
|
0 commit comments