Skip to content

Commit 17298af

Browse files
mbueschdpgeorge
authored andcommitted
py/modmath: Add domain error checking to sqrt, log, log2, log10.
These functions will raise 'ValueError: math domain error' on invalid input.
1 parent f7c4f9a commit 17298af

File tree

2 files changed

+41
-14
lines changed

2 files changed

+41
-14
lines changed

py/modmath.c

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
*/
2626

2727
#include "py/builtin.h"
28+
#include "py/nlr.h"
2829

2930
#if MICROPY_PY_BUILTINS_FLOAT && MICROPY_PY_MATH
3031

@@ -35,7 +36,10 @@
3536
/// The `math` module provides some basic mathematical funtions for
3637
/// working with floating-point numbers.
3738

38-
//TODO: Change macros to check for overflow and raise OverflowError or RangeError
39+
STATIC NORETURN void math_error(void) {
40+
nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError, "math domain error"));
41+
}
42+
3943
#define MATH_FUN_1(py_name, c_name) \
4044
STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj))); } \
4145
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name);
@@ -52,14 +56,24 @@
5256
STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { mp_int_t x = MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj)); return mp_obj_new_int(x); } \
5357
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name);
5458

59+
#define MATH_FUN_1_ERRCOND(py_name, c_name, error_condition) \
60+
STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { \
61+
mp_float_t x = mp_obj_get_float(x_obj); \
62+
if (error_condition) { \
63+
math_error(); \
64+
} \
65+
return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(x)); \
66+
} \
67+
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name);
68+
5569
#if MP_NEED_LOG2
5670
// 1.442695040888963407354163704 is 1/_M_LN2
5771
#define log2(x) (log(x) * 1.442695040888963407354163704)
5872
#endif
5973

6074
/// \function sqrt(x)
6175
/// Returns the square root of `x`.
62-
MATH_FUN_1(sqrt, sqrt)
76+
MATH_FUN_1_ERRCOND(sqrt, sqrt, (x < (mp_float_t)0.0))
6377
/// \function pow(x, y)
6478
/// Returns `x` to the power of `y`.
6579
MATH_FUN_2(pow, pow)
@@ -69,9 +83,9 @@ MATH_FUN_1(exp, exp)
6983
/// \function expm1(x)
7084
MATH_FUN_1(expm1, expm1)
7185
/// \function log2(x)
72-
MATH_FUN_1(log2, log2)
86+
MATH_FUN_1_ERRCOND(log2, log2, (x <= (mp_float_t)0.0))
7387
/// \function log10(x)
74-
MATH_FUN_1(log10, log10)
88+
MATH_FUN_1_ERRCOND(log10, log10, (x <= (mp_float_t)0.0))
7589
/// \function cosh(x)
7690
MATH_FUN_1(cosh, cosh)
7791
/// \function sinh(x)
@@ -139,11 +153,19 @@ MATH_FUN_1(lgamma, lgamma)
139153

140154
// log(x[, base])
141155
STATIC mp_obj_t mp_math_log(mp_uint_t n_args, const mp_obj_t *args) {
142-
mp_float_t l = MICROPY_FLOAT_C_FUN(log)(mp_obj_get_float(args[0]));
156+
mp_float_t x = mp_obj_get_float(args[0]);
157+
if (x <= (mp_float_t)0.0) {
158+
math_error();
159+
}
160+
mp_float_t l = MICROPY_FLOAT_C_FUN(log)(x);
143161
if (n_args == 1) {
144162
return mp_obj_new_float(l);
145163
} else {
146-
return mp_obj_new_float(l / MICROPY_FLOAT_C_FUN(log)(mp_obj_get_float(args[1])));
164+
mp_float_t base = mp_obj_get_float(args[1]);
165+
if (base <= (mp_float_t)0.0) {
166+
math_error();
167+
}
168+
return mp_obj_new_float(l / MICROPY_FLOAT_C_FUN(log)(base));
147169
}
148170
}
149171
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mp_math_log_obj, 1, 2, mp_math_log);

tests/float/math_fun.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@
99

1010
test_values = [-100., -1.23456, -1, -0.5, 0.0, 0.5, 1.23456, 100.]
1111
test_values_small = [-10., -1.23456, -1, -0.5, 0.0, 0.5, 1.23456, 10.] # so we don't overflow 32-bit precision
12-
p_test_values = [0.1, 0.5, 1.23456]
1312
unit_range_test_values = [-1., -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75, 1.]
1413

15-
functions = [('sqrt', sqrt, p_test_values),
14+
functions = [('sqrt', sqrt, test_values),
1615
('exp', exp, test_values_small),
1716
('expm1', expm1, test_values_small),
18-
('log', log, p_test_values),
19-
('log2', log2, p_test_values),
20-
('log10', log10, p_test_values),
17+
('log', log, test_values),
18+
('log2', log2, test_values),
19+
('log10', log10, test_values),
2120
('cosh', cosh, test_values_small),
2221
('sinh', sinh, test_values_small),
2322
('tanh', tanh, test_values_small),
@@ -41,7 +40,10 @@
4140
for function_name, function, test_vals in functions:
4241
print(function_name)
4342
for value in test_vals:
44-
print("{:.5g}".format(function(value)))
43+
try:
44+
print("{:.5g}".format(function(value)))
45+
except ValueError as e:
46+
print(str(e))
4547

4648
tuple_functions = [('frexp', frexp, test_values),
4749
('modf', modf, test_values),
@@ -59,10 +61,13 @@
5961
('atan2', atan2, ((1., 0.), (0., 1.), (2., 0.5), (-3., 5.), (-3., -4.),)),
6062
('fmod', fmod, ((1., 1.), (0., 1.), (2., 0.5), (-3., 5.), (-3., -4.),)),
6163
('ldexp', ldexp, ((1., 0), (0., 1), (2., 2), (3., -2), (-3., -4),)),
62-
('log', log, ((2., 2.), (3., 2.), (4., 5.))),
64+
('log', log, ((2., 2.), (3., 2.), (4., 5.), (0., 1.), (1., 0.), (-1., 1.), (1., -1.))),
6365
]
6466

6567
for function_name, function, test_vals in binary_functions:
6668
print(function_name)
6769
for value1, value2 in test_vals:
68-
print("{:.5g}".format(function(value1, value2)))
70+
try:
71+
print("{:.5g}".format(function(value1, value2)))
72+
except ValueError as e:
73+
print(str(e))

0 commit comments

Comments
 (0)