Skip to content

Commit 9a21d2e

Browse files
committed
py: Make mpz able to use 16 bits per digit; and 32 on 64-bit arch.
Previously, mpz was restricted to using at most 15 bits in each digit, where a digit was a uint16_t. With this patch, mpz can use all 16 bits in the uint16_t (improvement to mpn_div was required). This gives small improvements in speed and RAM usage. It also yields savings in ROM code size because all of the digit masking operations become no-ops. Also, mpz can now use a uint32_t as the digit type, and hence use 32 bits per digit. This will give decent improvements in mpz speed on 64-bit machines. Test for big integer division added.
1 parent afb1cf7 commit 9a21d2e

3 files changed

Lines changed: 99 additions & 20 deletions

File tree

py/mpz.c

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737
#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
3838

3939
#define DIG_SIZE (MPZ_DIG_SIZE)
40-
#define DIG_MASK ((1 << DIG_SIZE) - 1)
40+
#define DIG_MASK ((1L << DIG_SIZE) - 1)
41+
#define DIG_MSB (1L << (DIG_SIZE - 1))
42+
#define DIG_BASE (1L << DIG_SIZE)
4143

4244
/*
4345
mpz is an arbitrary precision integer type with a public API.
@@ -61,7 +63,7 @@ STATIC mp_int_t mpn_cmp(const mpz_dig_t *idig, mp_uint_t ilen, const mpz_dig_t *
6163
if (ilen > jlen) { return 1; }
6264

6365
for (idig += ilen, jdig += ilen; ilen > 0; --ilen) {
64-
mp_int_t cmp = *(--idig) - *(--jdig);
66+
mpz_dbl_dig_signed_t cmp = (mpz_dbl_dig_t)*(--idig) - (mpz_dbl_dig_t)*(--jdig);
6567
if (cmp < 0) { return -1; }
6668
if (cmp > 0) { return 1; }
6769
}
@@ -127,7 +129,7 @@ STATIC mp_uint_t mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mp_ui
127129
for (mp_uint_t i = jlen; i > 0; i--, idig++, jdig++) {
128130
mpz_dbl_dig_t d = *jdig;
129131
if (i > 1) {
130-
d |= jdig[1] << DIG_SIZE;
132+
d |= (mpz_dbl_dig_t)jdig[1] << DIG_SIZE;
131133
}
132134
d >>= n_part;
133135
*idig = d & DIG_MASK;
@@ -152,7 +154,7 @@ STATIC mp_uint_t mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
152154
jlen -= klen;
153155

154156
for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
155-
carry += *jdig + *kdig;
157+
carry += (mpz_dbl_dig_t)*jdig + (mpz_dbl_dig_t)*kdig;
156158
*idig = carry & DIG_MASK;
157159
carry >>= DIG_SIZE;
158160
}
@@ -182,7 +184,7 @@ STATIC mp_uint_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
182184
jlen -= klen;
183185

184186
for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
185-
borrow += *jdig - *kdig;
187+
borrow += (mpz_dbl_dig_t)*jdig - (mpz_dbl_dig_t)*kdig;
186188
*idig = borrow & DIG_MASK;
187189
borrow >>= DIG_SIZE;
188190
}
@@ -301,7 +303,7 @@ STATIC mp_uint_t mpn_mul_dig_add_dig(mpz_dig_t *idig, mp_uint_t ilen, mpz_dig_t
301303
mpz_dbl_dig_t carry = dadd;
302304

303305
for (; ilen > 0; --ilen, ++idig) {
304-
carry += *idig * dmul; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2
306+
carry += (mpz_dbl_dig_t)*idig * (mpz_dbl_dig_t)dmul; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2
305307
*idig = carry & DIG_MASK;
306308
carry >>= DIG_SIZE;
307309
}
@@ -328,7 +330,7 @@ STATIC mp_uint_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mpz_d
328330

329331
mp_uint_t jl = jlen;
330332
for (mpz_dig_t *jd = jdig; jl > 0; --jl, ++jd, ++id) {
331-
carry += *id + *jd * *kdig; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2
333+
carry += (mpz_dbl_dig_t)*id + (mpz_dbl_dig_t)*jd * (mpz_dbl_dig_t)*kdig; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2
332334
*id = carry & DIG_MASK;
333335
carry >>= DIG_SIZE;
334336
}
@@ -375,7 +377,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
375377
// count number of leading zeros in leading digit of denominator
376378
{
377379
mpz_dig_t d = den_dig[den_len - 1];
378-
while ((d & (1 << (DIG_SIZE - 1))) == 0) {
380+
while ((d & DIG_MSB) == 0) {
379381
d <<= 1;
380382
++norm_shift;
381383
}
@@ -412,29 +414,44 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
412414

413415
// keep going while we have enough digits to divide
414416
while (*num_len > den_len) {
415-
mpz_dbl_dig_t quo = (*num_dig << DIG_SIZE) | num_dig[-1];
417+
mpz_dbl_dig_t quo = ((mpz_dbl_dig_t)*num_dig << DIG_SIZE) | num_dig[-1];
416418

417419
// get approximate quotient
418420
quo /= lead_den_digit;
419421

420-
// multiply quo by den and subtract from num get remainder
421-
{
422+
// Multiply quo by den and subtract from num to get remainder.
423+
// We have different code here to handle different compile-time
424+
// configurations of mpz:
425+
//
426+
// 1. DIG_SIZE is strictly less than half the number of bits
427+
// available in mpz_dbl_dig_t. In this case we can use a
428+
// slightly more optimal (in time and space) routine that
429+
// uses the extra bits in mpz_dbl_dig_signed_t to store a
430+
// sign bit.
431+
//
432+
// 2. DIG_SIZE is exactly half the number of bits available in
433+
// mpz_dbl_dig_t. In this (common) case we need to be careful
434+
// not to overflow the borrow variable. And the shifting of
435+
// borrow needs some special logic (it's a shift right with
436+
// round up).
437+
438+
if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) {
422439
mpz_dbl_dig_signed_t borrow = 0;
423440

424441
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
425-
borrow += *n - quo * *d; // will overflow if DIG_SIZE >= 16
442+
borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)*d; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
426443
*n = borrow & DIG_MASK;
427444
borrow >>= DIG_SIZE;
428445
}
429-
borrow += *num_dig; // will overflow if DIG_SIZE >= 16
446+
borrow += *num_dig; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
430447
*num_dig = borrow & DIG_MASK;
431448
borrow >>= DIG_SIZE;
432449

433450
// adjust quotient if it is too big
434451
for (; borrow != 0; --quo) {
435452
mpz_dbl_dig_t carry = 0;
436453
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
437-
carry += *n + *d;
454+
carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;
438455
*n = carry & DIG_MASK;
439456
carry >>= DIG_SIZE;
440457
}
@@ -444,6 +461,44 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
444461

445462
borrow += carry;
446463
}
464+
} else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2
465+
mpz_dbl_dig_t borrow = 0;
466+
467+
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
468+
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)(*d);
469+
if (x >= *n || *n - x <= borrow) {
470+
borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n;
471+
*n = (-borrow) & DIG_MASK;
472+
borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
473+
} else {
474+
*n = ((mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)borrow) & DIG_MASK;
475+
borrow = 0;
476+
}
477+
}
478+
if (borrow >= *num_dig) {
479+
borrow -= (mpz_dbl_dig_t)*num_dig;
480+
*num_dig = (-borrow) & DIG_MASK;
481+
borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
482+
} else {
483+
*num_dig = (*num_dig - borrow) & DIG_MASK;
484+
borrow = 0;
485+
}
486+
487+
// adjust quotient if it is too big
488+
for (; borrow != 0; --quo) {
489+
mpz_dbl_dig_t carry = 0;
490+
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
491+
carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;
492+
*n = carry & DIG_MASK;
493+
carry >>= DIG_SIZE;
494+
}
495+
carry += (mpz_dbl_dig_t)*num_dig;
496+
*num_dig = carry & DIG_MASK;
497+
carry >>= DIG_SIZE;
498+
499+
//assert(borrow >= carry); // enable this to check the logic
500+
borrow -= carry;
501+
}
447502
}
448503

449504
// store this digit of the quotient
@@ -1256,7 +1311,7 @@ bool mpz_as_uint_checked(const mpz_t *i, mp_uint_t *value) {
12561311
mpz_dig_t *d = i->dig + i->len;
12571312

12581313
while (--d >= i->dig) {
1259-
if (val > ((~0) >> DIG_SIZE)) {
1314+
if (val > (~(WORD_MSBIT_HIGH) >> (DIG_SIZE - 1))) {
12601315
// will overflow
12611316
return false;
12621317
}
@@ -1273,7 +1328,7 @@ mp_float_t mpz_as_float(const mpz_t *i) {
12731328
mpz_dig_t *d = i->dig + i->len;
12741329

12751330
while (--d >= i->dig) {
1276-
val = val * (1 << DIG_SIZE) + *d;
1331+
val = val * DIG_BASE + *d;
12771332
}
12781333

12791334
if (i->neg != 0) {

py/mpz.h

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,34 @@
2424
* THE SOFTWARE.
2525
*/
2626

27+
// This mpz module implements arbitrary precision integers.
28+
//
29+
// The storage for each digit is defined by mpz_dig_t. The actual number of
30+
// bits in mpz_dig_t that are used is defined by MPZ_DIG_SIZE. The machine must
31+
// also provide a type that is twice as wide as mpz_dig_t, in both signed and
32+
// unsigned versions.
33+
//
34+
// MPZ_DIG_SIZE can be between 4 and 8*sizeof(mpz_dig_t), but it makes most
35+
// sense to have it as large as possible. Below, the type is auto-detected
36+
// depending on the machine, but it (and MPZ_DIG_SIZE) can be freely changed so
37+
// long as the constraints mentioned above are met.
38+
39+
#if defined(__x86_64__)
40+
// 64-bit machine, using 32-bit storage for digits
41+
typedef uint32_t mpz_dig_t;
42+
typedef uint64_t mpz_dbl_dig_t;
43+
typedef int64_t mpz_dbl_dig_signed_t;
44+
#define MPZ_DIG_SIZE (32)
45+
#else
46+
// any other target (e.g. a 32-bit machine), using 16-bit storage for digits
2747
typedef uint16_t mpz_dig_t;
2848
typedef uint32_t mpz_dbl_dig_t;
2949
typedef int32_t mpz_dbl_dig_signed_t;
50+
#define MPZ_DIG_SIZE (16)
51+
#endif
52+
53+
#define MPZ_NUM_DIG_FOR_INT (sizeof(mp_int_t) * 8 / MPZ_DIG_SIZE + 1)
54+
#define MPZ_NUM_DIG_FOR_LL (sizeof(long long) * 8 / MPZ_DIG_SIZE + 1)
3055

3156
typedef struct _mpz_t {
3257
mp_uint_t neg : 1;
@@ -36,10 +61,6 @@ typedef struct _mpz_t {
3661
mpz_dig_t *dig;
3762
} mpz_t;
3863

39-
#define MPZ_DIG_SIZE (15) // see mpn_div for why this needs to be at most 15
40-
#define MPZ_NUM_DIG_FOR_INT (sizeof(mp_int_t) * 8 / MPZ_DIG_SIZE + 1)
41-
#define MPZ_NUM_DIG_FOR_LL (sizeof(long long) * 8 / MPZ_DIG_SIZE + 1)
42-
4364
// convenience macro to declare an mpz with a digit array from the stack, initialised by an integer
4465
#define MPZ_CONST_INT(z, val) mpz_t z; mpz_dig_t z ## _digits[MPZ_NUM_DIG_FOR_INT]; mpz_init_fixed_from_int(&z, z_digits, MPZ_NUM_DIG_FOR_INT, val);
4566

tests/basics/int_big_div.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
for lhs in (1000000000000000000000000, 10000000000100000000000000, 10012003400000000000000007, 12349083434598210349871029923874109871234789):
2+
for rhs in range(1, 555):
3+
print(lhs // rhs)

0 commit comments

Comments
 (0)