Skip to content

Commit 9f911d8

Browse files
committed
py/objcomplex: Add mp_obj_get_complex_maybe for use in complex bin-op.
This allows complex binary operations to fail gracefully with unsupported operation rather than raising an exception, so that special methods work correctly. Signed-off-by: Damien George <damien@micropython.org>
1 parent 41fa8b5 commit 9f911d8

File tree

6 files changed

+35
-2
lines changed

6 files changed

+35
-2
lines changed

py/obj.c

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ mp_float_t mp_obj_get_float(mp_obj_t arg) {
371371
}
372372

373373
#if MICROPY_PY_BUILTINS_COMPLEX
374-
void mp_obj_get_complex(mp_obj_t arg, mp_float_t *real, mp_float_t *imag) {
374+
bool mp_obj_get_complex_maybe(mp_obj_t arg, mp_float_t *real, mp_float_t *imag) {
375375
if (arg == mp_const_false) {
376376
*real = 0;
377377
*imag = 0;
@@ -392,6 +392,13 @@ void mp_obj_get_complex(mp_obj_t arg, mp_float_t *real, mp_float_t *imag) {
392392
} else if (mp_obj_is_type(arg, &mp_type_complex)) {
393393
mp_obj_complex_get(arg, real, imag);
394394
} else {
395+
return false;
396+
}
397+
return true;
398+
}
399+
400+
void mp_obj_get_complex(mp_obj_t arg, mp_float_t *real, mp_float_t *imag) {
401+
if (!mp_obj_get_complex_maybe(arg, real, imag)) {
395402
#if MICROPY_ERROR_REPORTING == MICROPY_ERROR_REPORTING_TERSE
396403
mp_raise_TypeError(MP_ERROR_TEXT("can't convert to complex"));
397404
#else

py/obj.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,7 @@ bool mp_obj_get_int_maybe(mp_const_obj_t arg, mp_int_t *value);
778778
mp_float_t mp_obj_get_float(mp_obj_t self_in);
779779
bool mp_obj_get_float_maybe(mp_obj_t arg, mp_float_t *value);
780780
void mp_obj_get_complex(mp_obj_t self_in, mp_float_t *real, mp_float_t *imag);
781+
bool mp_obj_get_complex_maybe(mp_obj_t self_in, mp_float_t *real, mp_float_t *imag);
781782
#endif
782783
void mp_obj_get_array(mp_obj_t o, size_t *len, mp_obj_t **items); // *items may point inside a GC block
783784
void mp_obj_get_array_fixed_n(mp_obj_t o, size_t len, mp_obj_t **items); // *items may point inside a GC block

py/objcomplex.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,10 @@ void mp_obj_complex_get(mp_obj_t self_in, mp_float_t *real, mp_float_t *imag) {
178178

179179
mp_obj_t mp_obj_complex_binary_op(mp_binary_op_t op, mp_float_t lhs_real, mp_float_t lhs_imag, mp_obj_t rhs_in) {
180180
mp_float_t rhs_real, rhs_imag;
181-
mp_obj_get_complex(rhs_in, &rhs_real, &rhs_imag); // can be any type, this function will convert to float (if possible)
181+
if (!mp_obj_get_complex_maybe(rhs_in, &rhs_real, &rhs_imag)) {
182+
return MP_OBJ_NULL; // op not supported
183+
}
184+
182185
switch (op) {
183186
case MP_BINARY_OP_ADD:
184187
case MP_BINARY_OP_INPLACE_ADD:

tests/float/cmath_fun.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,9 @@
5757
if abs(real) < 1e-6:
5858
real = 0.0
5959
print("complex(%.5g, %.5g)" % (real, ret.imag))
60+
61+
# test invalid type passed to cmath function
62+
try:
63+
log([])
64+
except TypeError:
65+
print("TypeError")
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# test complex interacting with special methods
2+
3+
4+
class A:
5+
def __add__(self, x):
6+
print("__add__")
7+
return 1
8+
9+
def __radd__(self, x):
10+
print("__radd__")
11+
return 2
12+
13+
14+
print(A() + 1j)
15+
print(1j + A())

tests/run-tests

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def run_tests(pyb, tests, args, base_path="."):
355355
if not has_complex:
356356
skip_tests.add('float/complex1.py')
357357
skip_tests.add('float/complex1_intbig.py')
358+
skip_tests.add('float/complex_special_mehods.py')
358359
skip_tests.add('float/int_big_float.py')
359360
skip_tests.add('float/true_value.py')
360361
skip_tests.add('float/types.py')

0 commit comments

Comments
 (0)