Skip to content

Commit f39d3b9

Browse files
committed
py: Implement support for generalized generator protocol.
Iterators and ducktype objects can now be arguments of yield from.
1 parent a30cf9f commit f39d3b9

7 files changed

Lines changed: 126 additions & 9 deletions

File tree

py/bc.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
typedef enum {
2-
MP_VM_RETURN_NORMAL,
3-
MP_VM_RETURN_YIELD,
4-
MP_VM_RETURN_EXCEPTION,
5-
} mp_vm_return_kind_t;
6-
71
// Exception stack entry
82
typedef struct _mp_exc_stack {
93
const byte *handler;

py/objgenerator.c

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,14 @@ STATIC mp_obj_t gen_resume_and_raise(mp_obj_t self_in, mp_obj_t send_value, mp_o
135135
return ret;
136136

137137
case MP_VM_RETURN_EXCEPTION:
138-
nlr_jump(ret);
138+
// TODO: Optimization of returning MP_OBJ_NULL is really part
139+
// of mp_iternext() protocol, but this function is called by other methods
140+
// too, which may not handled MP_OBJ_NULL.
141+
if (mp_obj_is_subclass_fast(mp_obj_get_type(ret), &mp_type_StopIteration)) {
142+
return MP_OBJ_NULL;
143+
} else {
144+
nlr_jump(ret);
145+
}
139146

140147
default:
141148
assert(0);

py/runtime.c

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "builtintables.h"
1818
#include "bc.h"
1919
#include "intdivmod.h"
20+
#include "objgenerator.h"
2021

2122
#if 0 // print debugging info
2223
#define DEBUG_PRINT (1)
@@ -903,6 +904,62 @@ mp_obj_t mp_iternext(mp_obj_t o_in) {
903904
}
904905
}
905906

907+
// TODO: Unclear what to do with StopIterarion exception here.
908+
mp_vm_return_kind_t mp_resume(mp_obj_t self_in, mp_obj_t send_value, mp_obj_t throw_value, mp_obj_t *ret_val) {
909+
mp_obj_type_t *type = mp_obj_get_type(self_in);
910+
911+
if (type == &mp_type_gen_instance) {
912+
return mp_obj_gen_resume(self_in, send_value, throw_value, ret_val);
913+
}
914+
915+
if (type->iternext != NULL && send_value == mp_const_none) {
916+
mp_obj_t ret = type->iternext(self_in);
917+
if (ret != MP_OBJ_NULL) {
918+
*ret_val = ret;
919+
return MP_VM_RETURN_YIELD;
920+
} else {
921+
// Emulate raise StopIteration()
922+
// Special case, handled in vm.c
923+
*ret_val = MP_OBJ_NULL;
924+
return MP_VM_RETURN_NORMAL;
925+
}
926+
}
927+
928+
mp_obj_t dest[3]; // Reserve slot for send() arg
929+
930+
if (send_value == mp_const_none) {
931+
mp_load_method_maybe(self_in, MP_QSTR___next__, dest);
932+
if (dest[0] != MP_OBJ_NULL) {
933+
*ret_val = mp_call_method_n_kw(0, 0, dest);
934+
return MP_VM_RETURN_YIELD;
935+
}
936+
}
937+
938+
if (send_value != MP_OBJ_NULL) {
939+
mp_load_method(self_in, MP_QSTR_send, dest);
940+
dest[2] = send_value;
941+
*ret_val = mp_call_method_n_kw(1, 0, dest);
942+
return MP_VM_RETURN_YIELD;
943+
}
944+
945+
if (throw_value != MP_OBJ_NULL) {
946+
if (mp_obj_is_subclass_fast(mp_obj_get_type(throw_value), &mp_type_GeneratorExit)) {
947+
mp_load_method_maybe(self_in, MP_QSTR_close, dest);
948+
if (dest[0] != MP_OBJ_NULL) {
949+
*ret_val = mp_call_method_n_kw(0, 0, dest);
950+
// We assume one can't "yield" from close()
951+
return MP_VM_RETURN_NORMAL;
952+
}
953+
}
954+
mp_load_method(self_in, MP_QSTR_throw, dest);
955+
*ret_val = mp_call_method_n_kw(1, 0, &throw_value);
956+
return MP_VM_RETURN_YIELD;
957+
}
958+
959+
assert(0);
960+
return MP_VM_RETURN_NORMAL; // Should be unreachable
961+
}
962+
906963
mp_obj_t mp_make_raise_obj(mp_obj_t o) {
907964
DEBUG_printf("raise %p\n", o);
908965
if (mp_obj_is_exception_type(o)) {

py/runtime.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
typedef enum {
2+
MP_VM_RETURN_NORMAL,
3+
MP_VM_RETURN_YIELD,
4+
MP_VM_RETURN_EXCEPTION,
5+
} mp_vm_return_kind_t;
6+
17
void mp_init(void);
28
void mp_deinit(void);
39

@@ -55,6 +61,7 @@ void mp_store_subscr(mp_obj_t base, mp_obj_t index, mp_obj_t val);
5561
mp_obj_t mp_getiter(mp_obj_t o);
5662
mp_obj_t mp_iternext_allow_raise(mp_obj_t o); // may return MP_OBJ_NULL instead of raising StopIteration()
5763
mp_obj_t mp_iternext(mp_obj_t o); // will always return MP_OBJ_NULL instead of raising StopIteration(...)
64+
mp_vm_return_kind_t mp_resume(mp_obj_t self_in, mp_obj_t send_value, mp_obj_t throw_value, mp_obj_t *ret_val);
5865

5966
mp_obj_t mp_make_raise_obj(mp_obj_t o);
6067

py/vm.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -808,9 +808,9 @@ mp_vm_return_kind_t mp_execute_byte_code_2(const byte *code_info, const byte **i
808808
if (inject_exc != MP_OBJ_NULL) {
809809
t_exc = inject_exc;
810810
inject_exc = MP_OBJ_NULL;
811-
ret_kind = mp_obj_gen_resume(TOP(), mp_const_none, t_exc, &obj2);
811+
ret_kind = mp_resume(TOP(), mp_const_none, t_exc, &obj2);
812812
} else {
813-
ret_kind = mp_obj_gen_resume(TOP(), obj1, MP_OBJ_NULL, &obj2);
813+
ret_kind = mp_resume(TOP(), obj1, MP_OBJ_NULL, &obj2);
814814
}
815815

816816
if (ret_kind == MP_VM_RETURN_YIELD) {
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
class MyGen:
2+
3+
def __init__(self):
4+
self.v = 0
5+
6+
def __iter__(self):
7+
return self
8+
9+
def __next__(self):
10+
self.v += 1
11+
if self.v > 5:
12+
raise StopIteration
13+
return self.v
14+
15+
def gen():
16+
yield from MyGen()
17+
18+
def gen2():
19+
yield from gen()
20+
21+
print(list(gen()))
22+
print(list(gen2()))
23+
24+
25+
class Incrementer:
26+
27+
def __iter__(self):
28+
return self
29+
30+
def __next__(self):
31+
return self.send(None)
32+
33+
def send(self, val):
34+
if val is None:
35+
return "Incrementer initialized"
36+
return val + 1
37+
38+
def gen3():
39+
yield from Incrementer()
40+
41+
g = gen3()
42+
print(next(g))
43+
print(g.send(5))
44+
print(g.send(100))
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
def gen():
2+
yield from (1, 2, 3)
3+
4+
def gen2():
5+
yield from gen()
6+
7+
print(list(gen()))
8+
print(list(gen2()))

0 commit comments

Comments
 (0)