Skip to content

Commit df0117c

Browse files
nickovspfalcon
authored andcommitted
py: Added optimised support for 3-argument calls to builtin.pow()
Updated modbuiltin.c to add conditional support for 3-arg calls to pow() using MICROPY_PY_BUILTINS_POW3 config parameter. Added support in objint_mpz.c for for optimised implementation.
1 parent 2486c4f commit df0117c

8 files changed

Lines changed: 80 additions & 4 deletions

File tree

py/modbuiltins.c

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,14 @@ MP_DEFINE_CONST_FUN_OBJ_1(mp_builtin_ord_obj, mp_builtin_ord);
378378
STATIC mp_obj_t mp_builtin_pow(size_t n_args, const mp_obj_t *args) {
379379
switch (n_args) {
380380
case 2: return mp_binary_op(MP_BINARY_OP_POWER, args[0], args[1]);
381-
default: return mp_binary_op(MP_BINARY_OP_MODULO, mp_binary_op(MP_BINARY_OP_POWER, args[0], args[1]), args[2]); // TODO optimise...
381+
default:
382+
#if !MICROPY_PY_BUILTINS_POW3
383+
mp_raise_msg(&mp_type_NotImplementedError, "3-arg pow() not supported");
384+
#elif MICROPY_LONGINT_IMPL != MICROPY_LONGINT_IMPL_MPZ
385+
return mp_binary_op(MP_BINARY_OP_MODULO, mp_binary_op(MP_BINARY_OP_POWER, args[0], args[1]), args[2]);
386+
#else
387+
return mp_obj_int_pow3(args[0], args[1], args[2]);
388+
#endif
382389
}
383390
}
384391
MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mp_builtin_pow_obj, 2, 3, mp_builtin_pow);

py/mpconfig.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,11 @@
490490
#define MICROPY_LONGINT_IMPL (MICROPY_LONGINT_IMPL_NONE)
491491
#endif
492492

493+
// Support for calls to pow() with 3 integer arguments
494+
#ifndef MICROPY_PY_BUILTINS_POW3
495+
#define MICROPY_PY_BUILTINS_POW3 (0)
496+
#endif
497+
493498
#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_LONGLONG
494499
typedef long long mp_longint_impl_t;
495500
#endif

py/mpz.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,9 +1395,6 @@ void mpz_pow_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
13951395
mpz_free(n);
13961396
}
13971397

1398-
#if 0
1399-
these functions are unused
1400-
14011398
/* computes dest = (lhs ** rhs) % mod
14021399
can have dest, lhs, rhs the same; mod can't be the same as dest
14031400
*/
@@ -1436,6 +1433,9 @@ void mpz_pow3_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs, const mpz_t
14361433
mpz_free(n);
14371434
}
14381435

1436+
#if 0
1437+
these functions are unused
1438+
14391439
/* computes gcd(z1, z2)
14401440
based on Knuth's modified gcd algorithm (I think?)
14411441
gcd(z1, z2) >= 0

py/mpz.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);
123123
void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);
124124
void mpz_mul_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);
125125
void mpz_pow_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);
126+
void mpz_pow3_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs, const mpz_t *mod);
126127
void mpz_and_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);
127128
void mpz_or_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);
128129
void mpz_xor_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs);

py/objint.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,6 @@ mp_obj_t mp_obj_int_abs(mp_obj_t self_in);
6666
mp_obj_t mp_obj_int_unary_op(mp_uint_t op, mp_obj_t o_in);
6767
mp_obj_t mp_obj_int_binary_op(mp_uint_t op, mp_obj_t lhs_in, mp_obj_t rhs_in);
6868
mp_obj_t mp_obj_int_binary_op_extra_cases(mp_uint_t op, mp_obj_t lhs_in, mp_obj_t rhs_in);
69+
mp_obj_t mp_obj_int_pow3(mp_obj_t base, mp_obj_t exponent, mp_obj_t modulus);
6970

7071
#endif // __MICROPY_INCLUDED_PY_OBJINT_H__

py/objint_mpz.c

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,39 @@ mp_obj_t mp_obj_int_binary_op(mp_uint_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
326326
}
327327
}
328328

329+
#if MICROPY_PY_BUILTINS_POW3
330+
STATIC mpz_t *mp_mpz_for_int(mp_obj_t arg, mpz_t *temp) {
331+
if (MP_OBJ_IS_SMALL_INT(arg)) {
332+
mpz_init_from_int(temp, MP_OBJ_SMALL_INT_VALUE(arg));
333+
return temp;
334+
} else {
335+
mp_obj_int_t *arp_p = MP_OBJ_TO_PTR(arg);
336+
return &(arp_p->mpz);
337+
}
338+
}
339+
340+
mp_obj_t mp_obj_int_pow3(mp_obj_t base, mp_obj_t exponent, mp_obj_t modulus) {
341+
if (!MP_OBJ_IS_INT(base) || !MP_OBJ_IS_INT(exponent) || !MP_OBJ_IS_INT(modulus)) {
342+
mp_raise_TypeError("pow() with 3 arguments requires integers");
343+
} else {
344+
mp_obj_t result = mp_obj_new_int_from_ull(0); // Use the _from_ull version as this forces an mpz int
345+
mp_obj_int_t *res_p = (mp_obj_int_t *) MP_OBJ_TO_PTR(result);
346+
347+
mpz_t l_temp, r_temp, m_temp;
348+
mpz_t *lhs = mp_mpz_for_int(base, &l_temp);
349+
mpz_t *rhs = mp_mpz_for_int(exponent, &r_temp);
350+
mpz_t *mod = mp_mpz_for_int(modulus, &m_temp);
351+
352+
mpz_pow3_inpl(&(res_p->mpz), lhs, rhs, mod);
353+
354+
if (lhs == &l_temp) { mpz_deinit(lhs); }
355+
if (rhs == &r_temp) { mpz_deinit(rhs); }
356+
if (mod == &m_temp) { mpz_deinit(mod); }
357+
return result;
358+
}
359+
}
360+
#endif
361+
329362
mp_obj_t mp_obj_new_int(mp_int_t value) {
330363
if (MP_SMALL_INT_FITS(value)) {
331364
return MP_OBJ_NEW_SMALL_INT(value);

tests/basics/builtin_pow.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,32 @@
88

99
# 3 arg version
1010
print(pow(3, 4, 7))
11+
print(pow(555557, 1000002, 1000003))
1112

13+
# 3 arg pow is defined to only work on integers
14+
try:
15+
print(pow("x", 5, 6))
16+
except TypeError:
17+
print("TypeError expected")
18+
19+
try:
20+
print(pow(4, "y", 6))
21+
except TypeError:
22+
print("TypeError expected")
23+
24+
try:
25+
print(pow(4, 5, "z"))
26+
except TypeError:
27+
print("TypeError expected")
28+
29+
# Tests for 3 arg pow with large values
30+
31+
# This value happens to be prime
32+
x = 0xd48a1e2a099b1395895527112937a391d02d4a208bce5d74b281cf35a57362502726f79a632f063a83c0eba66196712d963aa7279ab8a504110a668c0fc38a7983c51e6ee7a85cae87097686ccdc359ee4bbf2c583bce524e3f7836bded1c771a4efcb25c09460a862fc98e18f7303df46aaeb34da46b0c4d61d5cd78350f3edb60e6bc4befa712a849
33+
y = 0x3accf60bb1a5365e4250d1588eb0fe6cd81ad495e9063f90880229f2a625e98c59387238670936afb2cafc5b79448e4414d6cd5e9901aa845aa122db58ddd7b9f2b17414600a18c47494ed1f3d49d005a5
34+
35+
print(hex(pow(2, 200, x))) # Should not overflow, just 1 << 200
36+
print(hex(pow(2, x-1, x))) # Should be 1, since x is prime
37+
print(hex(pow(y, x-1, x))) # Should be 1, since x is prime
38+
print(hex(pow(y, y-1, x))) # Should be a 'big value'
39+
print(hex(pow(y, y-1, y))) # Should be a 'big value'

unix/mpconfigport.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
#define MICROPY_PY_BUILTINS_FROZENSET (1)
8181
#define MICROPY_PY_BUILTINS_COMPILE (1)
8282
#define MICROPY_PY_BUILTINS_NOTIMPLEMENTED (1)
83+
#define MICROPY_PY_BUILTINS_POW3 (1)
8384
#define MICROPY_PY_MICROPYTHON_MEM_INFO (1)
8485
#define MICROPY_PY_ALL_SPECIAL_METHODS (1)
8586
#define MICROPY_PY_ARRAY_SLICE_ASSIGN (1)

0 commit comments

Comments
 (0)