Skip to content

Commit 9075002

Browse files
committed
Implement default function arguments (for Python functions).
TODO: Decide if we really need separate bytecode for creating functions with default arguments - we would need same for closures, then there're keywords arguments too. Having all combinations is a small exponential explosion, likely we need just 2 cases - simplest (no defaults, no kw), and full - defaults & kw.
1 parent 532f2c3 commit 9075002

10 files changed

Lines changed: 73 additions & 14 deletions

File tree

py/bc0.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@
9494
#define MP_BC_CALL_METHOD_VAR (0x97) // uint
9595
#define MP_BC_CALL_METHOD_KW (0x98) // uint
9696
#define MP_BC_CALL_METHOD_VAR_KW (0x99) // uint
97+
#define MP_BC_MAKE_FUNCTION_DEFARGS (0x9a) // uint
9798

9899
#define MP_BC_IMPORT_NAME (0xe0) // qstr
99100
#define MP_BC_IMPORT_FROM (0xe1) // qstr
100101
#define MP_BC_IMPORT_STAR (0xe2)
102+

py/compile.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3212,7 +3212,7 @@ mp_obj_t mp_compile(mp_parse_node_t pn, qstr source_file, bool is_repl) {
32123212
return mp_const_true;
32133213
#else
32143214
// return function that executes the outer module
3215-
return rt_make_function_from_id(unique_code_id);
3215+
return rt_make_function_from_id(unique_code_id, MP_OBJ_NULL);
32163216
#endif
32173217
}
32183218
}

py/emitbc.c

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -664,9 +664,15 @@ static void emit_bc_unpack_ex(emit_t *emit, int n_left, int n_right) {
664664
}
665665

666666
static void emit_bc_make_function(emit_t *emit, scope_t *scope, int n_dict_params, int n_default_params) {
667-
assert(n_default_params == 0 && n_dict_params == 0);
668-
emit_pre(emit, 1);
669-
emit_write_byte_code_byte_uint(emit, MP_BC_MAKE_FUNCTION, scope->unique_code_id);
667+
assert(n_dict_params == 0);
668+
if (n_default_params != 0) {
669+
emit_bc_build_tuple(emit, n_default_params);
670+
emit_pre(emit, 0);
671+
emit_write_byte_code_byte_uint(emit, MP_BC_MAKE_FUNCTION_DEFARGS, scope->unique_code_id);
672+
} else {
673+
emit_pre(emit, 1);
674+
emit_write_byte_code_byte_uint(emit, MP_BC_MAKE_FUNCTION, scope->unique_code_id);
675+
}
670676
}
671677

672678
static void emit_bc_make_closure(emit_t *emit, scope_t *scope, int n_dict_params, int n_default_params) {

py/obj.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ mp_obj_t mp_obj_new_exception_msg_2_args(qstr id, const char *fmt, const char *a
221221
mp_obj_t mp_obj_new_exception_msg_varg(qstr id, const char *fmt, ...); // counts args by number of % symbols in fmt, excluding %%; can only handle void* sizes (ie no float/double!)
222222
mp_obj_t mp_obj_new_range(int start, int stop, int step);
223223
mp_obj_t mp_obj_new_range_iterator(int cur, int stop, int step);
224-
mp_obj_t mp_obj_new_fun_bc(int n_args, uint n_state, const byte *code);
224+
mp_obj_t mp_obj_new_fun_bc(int n_args, mp_obj_t def_args, uint n_state, const byte *code);
225225
mp_obj_t mp_obj_new_fun_asm(uint n_args, void *fun);
226226
mp_obj_t mp_obj_new_gen_wrap(mp_obj_t fun);
227227
mp_obj_t mp_obj_new_gen_instance(const byte *bytecode, uint n_state, int n_args, const mp_obj_t *args);

py/objfun.c

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "mpconfig.h"
99
#include "qstr.h"
1010
#include "obj.h"
11+
#include "objtuple.h"
1112
#include "map.h"
1213
#include "runtime.h"
1314
#include "bc.h"
@@ -136,21 +137,32 @@ mp_obj_t rt_make_function_var_between(int n_args_min, int n_args_max, mp_fun_var
136137
typedef struct _mp_obj_fun_bc_t {
137138
mp_obj_base_t base;
138139
mp_map_t *globals; // the context within which this function was defined
139-
int n_args; // number of arguments this function takes
140+
short n_args; // number of arguments this function takes
141+
short n_def_args; // number of default arguments
140142
uint n_state; // total state size for the executing function (incl args, locals, stack)
141143
const byte *bytecode; // bytecode for the function
144+
mp_obj_t def_args[]; // values of default args, if any
142145
} mp_obj_fun_bc_t;
143146

144147
mp_obj_t fun_bc_call(mp_obj_t self_in, uint n_args, uint n_kw, const mp_obj_t *args) {
145148
mp_obj_fun_bc_t *self = self_in;
146149

147-
if (n_args != self->n_args) {
150+
if (n_args < self->n_args - self->n_def_args || n_args > self->n_args) {
148151
nlr_jump(mp_obj_new_exception_msg_2_args(MP_QSTR_TypeError, "function takes %d positional arguments but %d were given", (const char*)(machine_int_t)self->n_args, (const char*)(machine_int_t)n_args));
149152
}
150153
if (n_kw != 0) {
151154
nlr_jump(mp_obj_new_exception_msg(MP_QSTR_TypeError, "function does not take keyword arguments"));
152155
}
153156

157+
mp_obj_t full_args[n_args];
158+
if (n_args < self->n_args) {
159+
memcpy(full_args, args, n_args * sizeof(*args));
160+
int use_def_args = self->n_args - n_args;
161+
memcpy(full_args + n_args, self->def_args + self->n_def_args - use_def_args, use_def_args * sizeof(*args));
162+
args = full_args;
163+
n_args = self->n_args;
164+
}
165+
154166
// optimisation: allow the compiler to optimise this tail call for
155167
// the common case when the globals don't need to be changed
156168
mp_map_t *old_globals = rt_globals_get();
@@ -170,13 +182,22 @@ const mp_obj_type_t fun_bc_type = {
170182
.call = fun_bc_call,
171183
};
172184

173-
mp_obj_t mp_obj_new_fun_bc(int n_args, uint n_state, const byte *code) {
174-
mp_obj_fun_bc_t *o = m_new_obj(mp_obj_fun_bc_t);
185+
mp_obj_t mp_obj_new_fun_bc(int n_args, mp_obj_t def_args_in, uint n_state, const byte *code) {
186+
int n_def_args = 0;
187+
mp_obj_tuple_t *def_args = def_args_in;
188+
if (def_args != MP_OBJ_NULL) {
189+
n_def_args = def_args->len;
190+
}
191+
mp_obj_fun_bc_t *o = m_new_obj_var(mp_obj_fun_bc_t, mp_obj_t, n_def_args);
175192
o->base.type = &fun_bc_type;
176193
o->globals = rt_globals_get();
177194
o->n_args = n_args;
195+
o->n_def_args = n_def_args;
178196
o->n_state = n_state;
179197
o->bytecode = code;
198+
if (def_args != MP_OBJ_NULL) {
199+
memcpy(o->def_args, def_args->items, n_def_args * sizeof(*o->def_args));
200+
}
180201
return o;
181202
}
182203

py/runtime.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ mp_obj_t rt_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
674674
return mp_const_none;
675675
}
676676

677-
mp_obj_t rt_make_function_from_id(int unique_code_id) {
677+
mp_obj_t rt_make_function_from_id(int unique_code_id, mp_obj_t def_args) {
678678
DEBUG_OP_printf("make_function_from_id %d\n", unique_code_id);
679679
if (unique_code_id < 1 || unique_code_id >= next_unique_code_id) {
680680
// illegal code id
@@ -686,7 +686,7 @@ mp_obj_t rt_make_function_from_id(int unique_code_id) {
686686
mp_obj_t fun;
687687
switch (c->kind) {
688688
case MP_CODE_BYTE:
689-
fun = mp_obj_new_fun_bc(c->n_args, c->n_state, c->u_byte.code);
689+
fun = mp_obj_new_fun_bc(c->n_args, def_args, c->n_state, c->u_byte.code);
690690
break;
691691
case MP_CODE_NATIVE:
692692
fun = rt_make_function_n(c->n_args, c->u_native.fun);
@@ -710,7 +710,7 @@ mp_obj_t rt_make_function_from_id(int unique_code_id) {
710710
mp_obj_t rt_make_closure_from_id(int unique_code_id, mp_obj_t closure_tuple) {
711711
DEBUG_OP_printf("make_closure_from_id %d\n", unique_code_id);
712712
// make function object
713-
mp_obj_t ffun = rt_make_function_from_id(unique_code_id);
713+
mp_obj_t ffun = rt_make_function_from_id(unique_code_id, MP_OBJ_NULL);
714714
// wrap function in closure object
715715
return mp_obj_new_closure(ffun, closure_tuple);
716716
}

py/runtime.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ void rt_store_name(qstr qstr, mp_obj_t obj);
1212
void rt_store_global(qstr qstr, mp_obj_t obj);
1313
mp_obj_t rt_unary_op(int op, mp_obj_t arg);
1414
mp_obj_t rt_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs);
15-
mp_obj_t rt_make_function_from_id(int unique_code_id);
15+
mp_obj_t rt_make_function_from_id(int unique_code_id, mp_obj_t def_args);
1616
mp_obj_t rt_make_function_n(int n_args, void *fun); // fun must have the correct signature for n_args fixed arguments
1717
mp_obj_t rt_make_function_var(int n_args_min, mp_fun_var_t fun);
1818
mp_obj_t rt_make_function_var_between(int n_args_min, int n_args_max, mp_fun_var_t fun); // min and max are inclusive

py/showbc.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,11 @@ void mp_byte_code_print(const byte *ip, int len) {
339339
printf("MAKE_FUNCTION " UINT_FMT, unum);
340340
break;
341341

342+
case MP_BC_MAKE_FUNCTION_DEFARGS:
343+
DECODE_UINT;
344+
printf("MAKE_FUNCTION_DEFARGS " UINT_FMT, unum);
345+
break;
346+
342347
case MP_BC_MAKE_CLOSURE:
343348
DECODE_UINT;
344349
printf("MAKE_CLOSURE " UINT_FMT, unum);

py/vm.c

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,12 @@ bool mp_execute_byte_code_2(const byte *code_info, const byte **ip_in_out, mp_ob
476476

477477
case MP_BC_MAKE_FUNCTION:
478478
DECODE_UINT;
479-
PUSH(rt_make_function_from_id(unum));
479+
PUSH(rt_make_function_from_id(unum, MP_OBJ_NULL));
480+
break;
481+
482+
case MP_BC_MAKE_FUNCTION_DEFARGS:
483+
DECODE_UINT;
484+
SET_TOP(rt_make_function_from_id(unum, TOP()));
480485
break;
481486

482487
case MP_BC_MAKE_CLOSURE:

tests/basics/fun-defargs.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
def fun1(val=5):
2+
print(5)
3+
4+
fun1()
5+
fun1(10)
6+
7+
def fun2(p1, p2=100, p3="foo"):
8+
print(p1, p2, p3)
9+
10+
fun2(1)
11+
fun2(1, None)
12+
fun2(0, "bar", 200)
13+
try:
14+
fun2()
15+
except TypeError:
16+
print("TypeError")
17+
try:
18+
fun2(1, 2, 3, 4)
19+
except TypeError:
20+
print("TypeError")

0 commit comments

Comments
 (0)