Skip to content

Commit 084ef37

Browse files
committed
py: Fix math.{ceil,floor,trunc} to return int.
1 parent e3e0500 commit 084ef37

2 files changed

Lines changed: 41 additions & 37 deletions

File tree

py/modmath.c

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@
1717
mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj, mp_obj_t y_obj) { return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj), mp_obj_get_float(y_obj))); } \
1818
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mp_math_## py_name ## _obj, mp_math_ ## py_name);
1919

20-
#define MATH_FUN_BOOL1(py_name, c_name) \
20+
#define MATH_FUN_1_TO_BOOL(py_name, c_name) \
2121
mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return MP_BOOL(c_name(mp_obj_get_float(x_obj))); } \
2222
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name);
2323

24+
#define MATH_FUN_1_TO_INT(py_name, c_name) \
25+
mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return mp_obj_new_int((machine_int_t)MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj))); } \
26+
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name);
27+
2428
STATIC const mp_obj_float_t mp_math_e_obj = {{&mp_type_float}, M_E};
2529
STATIC const mp_obj_float_t mp_math_pi_obj = {{&mp_type_float}, M_PI};
2630

@@ -44,15 +48,15 @@ MATH_FUN_1(acos, acos)
4448
MATH_FUN_1(asin, asin)
4549
MATH_FUN_1(atan, atan)
4650
MATH_FUN_2(atan2, atan2)
47-
MATH_FUN_1(ceil, ceil)
51+
MATH_FUN_1_TO_INT(ceil, ceil)
4852
MATH_FUN_2(copysign, copysign)
4953
MATH_FUN_1(fabs, fabs)
50-
MATH_FUN_1(floor, floor) //TODO: delegate to x.__floor__() if x is not a float
54+
MATH_FUN_1_TO_INT(floor, floor) //TODO: delegate to x.__floor__() if x is not a float
5155
MATH_FUN_2(fmod, fmod)
52-
MATH_FUN_BOOL1(isfinite, isfinite)
53-
MATH_FUN_BOOL1(isinf, isinf)
54-
MATH_FUN_BOOL1(isnan, isnan)
55-
MATH_FUN_1(trunc, trunc)
56+
MATH_FUN_1_TO_BOOL(isfinite, isfinite)
57+
MATH_FUN_1_TO_BOOL(isinf, isinf)
58+
MATH_FUN_1_TO_BOOL(isnan, isnan)
59+
MATH_FUN_1_TO_INT(trunc, trunc)
5660
MATH_FUN_2(ldexp, ldexp)
5761
MATH_FUN_1(erf, erf)
5862
MATH_FUN_1(erfc, erfc)

tests/basics/math-fun.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,41 +6,41 @@
66
p_test_values = [0.1, 0.5, 1.23456]
77
unit_range_test_values = [-1., -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75, 1.]
88

9-
functions = [(sqrt, p_test_values),
10-
(exp, test_values),
11-
(expm1, test_values),
12-
(log, p_test_values),
13-
(log2, p_test_values),
14-
(log10, p_test_values),
15-
(cosh, test_values),
16-
(sinh, test_values),
17-
(tanh, test_values),
18-
(acosh, [1.0, 5.0, 1.0]),
19-
(asinh, test_values),
20-
(atanh, [-0.99, -0.5, 0.0, 0.5, 0.99]),
21-
(cos, test_values),
22-
(sin, test_values),
23-
(tan, test_values),
24-
(acos, unit_range_test_values),
25-
(asin, unit_range_test_values),
26-
(atan, test_values),
27-
(ceil, test_values),
28-
(fabs, test_values),
29-
(floor, test_values),
30-
#(frexp, test_values),
31-
(trunc, test_values)
9+
functions = [('sqrt', sqrt, p_test_values),
10+
('exp', exp, test_values),
11+
('expm1', expm1, test_values),
12+
('log', log, p_test_values),
13+
('log2', log2, p_test_values),
14+
('log10', log10, p_test_values),
15+
('cosh', cosh, test_values),
16+
('sinh', sinh, test_values),
17+
('tanh', tanh, test_values),
18+
('acosh', acosh, [1.0, 5.0, 1.0]),
19+
('asinh', asinh, test_values),
20+
('atanh', atanh, [-0.99, -0.5, 0.0, 0.5, 0.99]),
21+
('cos', cos, test_values),
22+
('sin', sin, test_values),
23+
('tan', tan, test_values),
24+
('acos', acos, unit_range_test_values),
25+
('asin', asin, unit_range_test_values),
26+
('atan', atan, test_values),
27+
('ceil', ceil, test_values),
28+
('fabs', fabs, test_values),
29+
('floor', floor, test_values),
30+
#('frexp', frexp, test_values),
31+
('trunc', trunc, test_values)
3232
]
3333

34-
for function, test_vals in functions:
34+
for function_name, function, test_vals in functions:
35+
print(function_name)
3536
for value in test_vals:
36-
print("{:8.7f}".format(function(value)))
37+
print("{:8.7g}".format(function(value)))
3738

38-
binary_functions = [(copysign, [(23., 42.), (-23., 42.), (23., -42.),
39+
binary_functions = [('copysign', copysign, [(23., 42.), (-23., 42.), (23., -42.),
3940
(-23., -42.), (1., 0.0), (1., -0.0)])
4041
]
4142

42-
for function, test_vals in binary_functions:
43+
for function_name, function, test_vals in binary_functions:
44+
print(function_name)
4345
for value1, value2 in test_vals:
44-
print("{:8.7f}".format(function(value1, value2)))
45-
46-
46+
print("{:8.7g}".format(function(value1, value2)))

0 commit comments

Comments
 (0)