Skip to content

Commit 495d781

Browse files
committed
py: implement UNPACK_EX byte code (for: a, *b, c = d)
1 parent e753d91 commit 495d781

7 files changed

Lines changed: 151 additions & 0 deletions

File tree

py/compile.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,6 +1933,11 @@ void compile_expr_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) {
19331933
// optimisation for a, b = c, d; to match CPython's optimisation
19341934
mp_parse_node_struct_t* pns10 = (mp_parse_node_struct_t*)pns1->nodes[0];
19351935
mp_parse_node_struct_t* pns0 = (mp_parse_node_struct_t*)pns->nodes[0];
1936+
if (MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[0], PN_star_expr)
1937+
|| MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[1], PN_star_expr)) {
1938+
// can't optimise when it's a star expression on the lhs
1939+
goto no_optimisation;
1940+
}
19361941
compile_node(comp, pns10->nodes[0]); // rhs
19371942
compile_node(comp, pns10->nodes[1]); // rhs
19381943
EMIT(rot_two);
@@ -1945,6 +1950,12 @@ void compile_expr_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) {
19451950
// optimisation for a, b, c = d, e, f; to match CPython's optimisation
19461951
mp_parse_node_struct_t* pns10 = (mp_parse_node_struct_t*)pns1->nodes[0];
19471952
mp_parse_node_struct_t* pns0 = (mp_parse_node_struct_t*)pns->nodes[0];
1953+
if (MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[0], PN_star_expr)
1954+
|| MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[1], PN_star_expr)
1955+
|| MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[2], PN_star_expr)) {
1956+
// can't optimise when it's a star expression on the lhs
1957+
goto no_optimisation;
1958+
}
19481959
compile_node(comp, pns10->nodes[0]); // rhs
19491960
compile_node(comp, pns10->nodes[1]); // rhs
19501961
compile_node(comp, pns10->nodes[2]); // rhs
@@ -1954,6 +1965,7 @@ void compile_expr_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) {
19541965
c_assign(comp, pns0->nodes[1], ASSIGN_STORE); // lhs store
19551966
c_assign(comp, pns0->nodes[2], ASSIGN_STORE); // lhs store
19561967
} else {
1968+
no_optimisation:
19571969
compile_node(comp, pns1->nodes[0]); // rhs
19581970
c_assign(comp, pns->nodes[0], ASSIGN_STORE); // lhs store
19591971
}

py/obj.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ mp_obj_t mp_obj_tuple_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const m
440440
// list
441441
mp_obj_t mp_obj_list_append(mp_obj_t self_in, mp_obj_t arg);
442442
void mp_obj_list_get(mp_obj_t self_in, uint *len, mp_obj_t **items);
443+
void mp_obj_list_set_len(mp_obj_t self_in, uint len);
443444
void mp_obj_list_store(mp_obj_t self_in, mp_obj_t index, mp_obj_t value);
444445
mp_obj_t mp_obj_list_sort(uint n_args, const mp_obj_t *args, mp_map_t *kwargs);
445446

py/objlist.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,13 @@ void mp_obj_list_get(mp_obj_t self_in, uint *len, mp_obj_t **items) {
378378
*items = self->items;
379379
}
380380

381+
// Overwrite the recorded length of the list object self_in with len.
// Only the length field is written; no validation of len is performed, so
// the caller is responsible for passing a value that makes sense.
// TODO shrink the allocation if len ends up much smaller than alloc
void mp_obj_list_set_len(mp_obj_t self_in, uint len) {
    mp_obj_list_t *list = self_in;
    list->len = len;
}
387+
381388
void mp_obj_list_store(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) {
382389
mp_obj_list_t *self = self_in;
383390
uint i = mp_get_index(self->base.type, self->len, index, false);

py/runtime.c

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,70 @@ void mp_unpack_sequence(mp_obj_t seq_in, uint num, mp_obj_t *items) {
672672
nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError, "too many values to unpack (expected %d)", num));
673673
}
674674

675+
// unpacked items are stored in reverse order into the array pointed to by items
676+
// Unpack seq_in for an assignment with a starred target, eg: a, *b, c = seq_in
//
// num_in encodes the number of plain targets: bits 0-7 hold the count to the
// left of the star (num_left), bits 8-15 the count to its right (num_right).
//
// num_left + num_right + 1 objects are written to items, in reverse order:
//   items[0 .. num_right-1]   values for the targets right of the star, last first
//   items[num_right]          a new list holding the values taken by the star
//   items[num_right+1 ..]     values for the targets left of the star, last first
//
// Raises ValueError if seq_in provides fewer than num_left + num_right values;
// the message reports how many values were actually obtained.
void mp_unpack_ex(mp_obj_t seq_in, uint num_in, mp_obj_t *items) {
    uint num_left = num_in & 0xff;
    uint num_right = (num_in >> 8) & 0xff;
    DEBUG_OP_printf("unpack ex %d %d\n", num_left, num_right);
    uint seq_len;
    if (MP_OBJ_IS_TYPE(seq_in, &mp_type_tuple) || MP_OBJ_IS_TYPE(seq_in, &mp_type_list)) {
        // Tuples and lists give direct access to their item array.
        mp_obj_t *seq_items;
        if (MP_OBJ_IS_TYPE(seq_in, &mp_type_tuple)) {
            mp_obj_tuple_get(seq_in, &seq_len, &seq_items);
        } else {
            // Note: even for "*a, = b" the source list is not handed back as the
            // starred value.  The star target must be a distinct list, otherwise
            // mutating a afterwards would also mutate b.
            mp_obj_list_get(seq_in, &seq_len, &seq_items);
        }
        if (seq_len < num_left + num_right) {
            goto too_short;
        }
        // right-hand targets, taken from the end of the sequence
        for (uint i = 0; i < num_right; i++) {
            items[i] = seq_items[seq_len - 1 - i];
        }
        // starred target gets everything in the middle
        // (assumes mp_obj_new_list copies the given items -- the tuple case already relies on this)
        items[num_right] = mp_obj_new_list(seq_len - num_left - num_right, seq_items + num_left);
        // left-hand targets, taken from the start of the sequence
        for (uint i = 0; i < num_left; i++) {
            items[num_right + 1 + i] = seq_items[num_left - 1 - i];
        }
    } else {
        // Generic iterable; this gets a bit messy: we unpack known left length to the
        // items destination array, then the rest to a dynamically created list.  Once the
        // iterable is exhausted, we take from this list for the right part of the items.
        // TODO Improve to waste less memory in the dynamically created list.
        mp_obj_t iterable = mp_getiter(seq_in);
        mp_obj_t item;
        for (seq_len = 0; seq_len < num_left; seq_len++) {
            item = mp_iternext(iterable);
            if (item == MP_OBJ_NULL) {
                // seq_len is the number of values obtained so far
                goto too_short;
            }
            items[num_left + num_right + 1 - 1 - seq_len] = item;
        }
        mp_obj_t rest = mp_obj_new_list(0, NULL);
        while ((item = mp_iternext(iterable)) != MP_OBJ_NULL) {
            mp_obj_list_append(rest, item);
        }
        uint rest_len;
        mp_obj_t *rest_items;
        mp_obj_list_get(rest, &rest_len, &rest_items);
        if (rest_len < num_right) {
            // account for the values collected in rest so that the error
            // message reports the total number of values that were available
            seq_len += rest_len;
            goto too_short;
        }
        items[num_right] = rest;
        for (uint i = 0; i < num_right; i++) {
            items[num_right - 1 - i] = rest_items[rest_len - num_right + i];
        }
        // the trailing num_right values now belong to the right-hand targets,
        // so drop them from the starred list
        mp_obj_list_set_len(rest, rest_len - num_right);
    }
    return;

too_short:
    nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError, "need more than %d values to unpack", seq_len));
}
738+
675739
mp_obj_t mp_load_attr(mp_obj_t base, qstr attr) {
676740
DEBUG_OP_printf("load attr %p.%s\n", base, qstr_str(attr));
677741
// use load_method

py/runtime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ mp_obj_t mp_call_method_n_kw(uint n_args, uint n_kw, const mp_obj_t *args);
4444
mp_obj_t mp_call_method_n_kw_var(bool have_self, uint n_args_n_kw, const mp_obj_t *args);
4545

4646
void mp_unpack_sequence(mp_obj_t seq, uint num, mp_obj_t *items);
47+
void mp_unpack_ex(mp_obj_t seq, uint num, mp_obj_t *items);
4748
mp_obj_t mp_store_map(mp_obj_t map, mp_obj_t key, mp_obj_t value);
4849
mp_obj_t mp_load_attr(mp_obj_t base, qstr attr);
4950
void mp_load_method(mp_obj_t base, qstr attr, mp_obj_t *dest);

py/vm.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,12 @@ mp_vm_return_kind_t mp_execute_byte_code_2(const byte *code_info, const byte **i
653653
sp += unum - 1;
654654
break;
655655

656+
case MP_BC_UNPACK_EX:
657+
DECODE_UINT;
658+
mp_unpack_ex(sp[0], unum, sp);
659+
sp += (unum & 0xff) + ((unum >> 8) & 0xff);
660+
break;
661+
656662
case MP_BC_MAKE_FUNCTION:
657663
DECODE_UINT;
658664
PUSH(mp_make_function_from_id(unum, MP_OBJ_NULL, MP_OBJ_NULL));

tests/basics/unpack1.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# unpack sequences
# Output is compared against a reference run, so every target that gets
# assigned is also printed; otherwise a wrong value would go unnoticed.

a, = 1,     ; print(a)
a, b = 2, 3 ; print(a, b)

a, b = range(2); print(a, b)

# with star

*a, = ()            ; print(a)
*a, = 4,            ; print(a)
*a, = 5, 6          ; print(a)

*a, b = 7,          ; print(a, b)
*a, b = 8, 9        ; print(a, b)
*a, b = 10, 11, 12  ; print(a, b)

a, *b = 13,         ; print(a, b)
a, *b = 14, 15      ; print(a, b)
a, *b = 16, 17, 18  ; print(a, b)

# c is printed too so the right-hand target is checked for tuple sources
a, *b, c = 19, 20           ; print(a, b, c)
a, *b, c = 21, 22, 23       ; print(a, b, c)
a, *b, c = 24, 25, 26, 27   ; print(a, b, c)

a = [28, 29]
*b, = a
print(a, b, a == b)

try:
    a, *b, c = (30,)
except ValueError:
    print("ValueError")

# with star and generic iterator

*a, = range(5)              ; print(a)
*a, b = range(5)            ; print(a, b)
*a, b, c = range(5)         ; print(a, b, c)
a, *b = range(5)            ; print(a, b)
a, *b, c = range(5)         ; print(a, b, c)
a, *b, c, d = range(5)      ; print(a, b, c, d)
a, b, *c = range(5)         ; print(a, b, c)
a, b, *c, d = range(5)      ; print(a, b, c, d)
a, b, *c, d, e = range(5)   ; print(a, b, c, d, e)

*a, = [x * 2 for x in [1, 2, 3, 4]]         ; print(a)
*a, b = [x * 2 for x in [1, 2, 3, 4]]       ; print(a, b)
a, *b = [x * 2 for x in [1, 2, 3, 4]]       ; print(a, b)
a, *b, c = [x * 2 for x in [1, 2, 3, 4]]    ; print(a, b, c)

try:
    a, *b, c = range(0)
except ValueError:
    print("ValueError")

try:
    a, *b, c = range(1)
except ValueError:
    print("ValueError")

0 commit comments

Comments
 (0)