Skip to content

Commit 29e9db0

Browse files
committed
py: Fix compiler to handle lambdas used as default arguments.
Addresses issue adafruit#1709.
1 parent bb7f5b5 commit 29e9db0

3 files changed

Lines changed: 28 additions & 0 deletions

File tree

py/compile.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,13 @@ STATIC void compile_funcdef_lambdef_param(compiler_t *comp, mp_parse_node_t pn)
662662
}
663663

664664
STATIC void compile_funcdef_lambdef(compiler_t *comp, scope_t *scope, mp_parse_node_t pn_params, pn_kind_t pn_list_kind) {
665+
// When we call compile_funcdef_lambdef_param below it can compile an arbitrary
666+
// expression for default arguments, which may contain a lambda. The lambda will
667+
// call here in a nested way, so we must save and restore the relevant state.
668+
bool orig_have_star = comp->have_star;
669+
uint16_t orig_num_dict_params = comp->num_dict_params;
670+
uint16_t orig_num_default_params = comp->num_default_params;
671+
665672
// compile default parameters
666673
comp->have_star = false;
667674
comp->num_dict_params = 0;
@@ -681,6 +688,11 @@ STATIC void compile_funcdef_lambdef(compiler_t *comp, scope_t *scope, mp_parse_n
681688

682689
// make the function
683690
close_over_variables_etc(comp, scope, comp->num_default_params, comp->num_dict_params);
691+
692+
// restore state
693+
comp->have_star = orig_have_star;
694+
comp->num_dict_params = orig_num_dict_params;
695+
comp->num_default_params = orig_num_default_params;
684696
}
685697

686698
// leaves function object on stack

tests/basics/fun_defargs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# testing default args to a function
2+
13
def fun1(val=5):
24
print(val)
35

@@ -18,3 +20,10 @@ def fun2(p1, p2=100, p3="foo"):
1820
fun2(1, 2, 3, 4)
1921
except TypeError:
2022
print("TypeError")
23+
24+
# lambda as default arg (exposes nested behaviour in compiler)
25+
def f(x=lambda:1):
26+
return x()
27+
print(f())
28+
print(f(f))
29+
print(f(lambda:2))

tests/basics/fun_kwonly.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,10 @@ def f(a, *b, c):
5757
f(1, c=2)
5858
f(1, 2, c=3)
5959
f(a=1, c=3)
60+
61+
# lambda as kw-only arg (exposes nested behaviour in compiler)
62+
def f(*, x=lambda:1):
63+
return x()
64+
print(f())
65+
print(f(x=f))
66+
print(f(x=lambda:2))

0 commit comments

Comments
 (0)