3737#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
3838
3939#define DIG_SIZE (MPZ_DIG_SIZE)
40- #define DIG_MASK ((1 << DIG_SIZE) - 1)
40+ #define DIG_MASK ((1L << DIG_SIZE) - 1)
41+ #define DIG_MSB (1L << (DIG_SIZE - 1))
42+ #define DIG_BASE (1L << DIG_SIZE)
4143
4244/*
4345 mpz is an arbitrary precision integer type with a public API.
@@ -61,7 +63,7 @@ STATIC mp_int_t mpn_cmp(const mpz_dig_t *idig, mp_uint_t ilen, const mpz_dig_t *
6163 if (ilen > jlen ) { return 1 ; }
6264
6365 for (idig += ilen , jdig += ilen ; ilen > 0 ; -- ilen ) {
64- mp_int_t cmp = * (-- idig ) - * (-- jdig );
66+ mpz_dbl_dig_signed_t cmp = ( mpz_dbl_dig_t ) * (-- idig ) - ( mpz_dbl_dig_t ) * (-- jdig );
6567 if (cmp < 0 ) { return -1 ; }
6668 if (cmp > 0 ) { return 1 ; }
6769 }
@@ -127,7 +129,7 @@ STATIC mp_uint_t mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mp_ui
127129 for (mp_uint_t i = jlen ; i > 0 ; i -- , idig ++ , jdig ++ ) {
128130 mpz_dbl_dig_t d = * jdig ;
129131 if (i > 1 ) {
130- d |= jdig [1 ] << DIG_SIZE ;
132+ d |= ( mpz_dbl_dig_t ) jdig [1 ] << DIG_SIZE ;
131133 }
132134 d >>= n_part ;
133135 * idig = d & DIG_MASK ;
@@ -152,7 +154,7 @@ STATIC mp_uint_t mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
152154 jlen -= klen ;
153155
154156 for (; klen > 0 ; -- klen , ++ idig , ++ jdig , ++ kdig ) {
155- carry += * jdig + * kdig ;
157+ carry += ( mpz_dbl_dig_t ) * jdig + ( mpz_dbl_dig_t ) * kdig ;
156158 * idig = carry & DIG_MASK ;
157159 carry >>= DIG_SIZE ;
158160 }
@@ -182,7 +184,7 @@ STATIC mp_uint_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
182184 jlen -= klen ;
183185
184186 for (; klen > 0 ; -- klen , ++ idig , ++ jdig , ++ kdig ) {
185- borrow += * jdig - * kdig ;
187+ borrow += ( mpz_dbl_dig_t ) * jdig - ( mpz_dbl_dig_t ) * kdig ;
186188 * idig = borrow & DIG_MASK ;
187189 borrow >>= DIG_SIZE ;
188190 }
@@ -301,7 +303,7 @@ STATIC mp_uint_t mpn_mul_dig_add_dig(mpz_dig_t *idig, mp_uint_t ilen, mpz_dig_t
301303 mpz_dbl_dig_t carry = dadd ;
302304
303305 for (; ilen > 0 ; -- ilen , ++ idig ) {
304- carry += * idig * dmul ; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2
306+ carry += ( mpz_dbl_dig_t ) * idig * ( mpz_dbl_dig_t ) dmul ; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/ 2
305307 * idig = carry & DIG_MASK ;
306308 carry >>= DIG_SIZE ;
307309 }
@@ -328,7 +330,7 @@ STATIC mp_uint_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mpz_d
328330
329331 mp_uint_t jl = jlen ;
330332 for (mpz_dig_t * jd = jdig ; jl > 0 ; -- jl , ++ jd , ++ id ) {
331- carry += * id + * jd * * kdig ; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2
333+ carry += ( mpz_dbl_dig_t ) * id + ( mpz_dbl_dig_t ) * jd * ( mpz_dbl_dig_t ) * kdig ; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/ 2
332334 * id = carry & DIG_MASK ;
333335 carry >>= DIG_SIZE ;
334336 }
@@ -375,7 +377,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
375377 // count number of leading zeros in leading digit of denominator
376378 {
377379 mpz_dig_t d = den_dig [den_len - 1 ];
378- while ((d & ( 1 << ( DIG_SIZE - 1 )) ) == 0 ) {
380+ while ((d & DIG_MSB ) == 0 ) {
379381 d <<= 1 ;
380382 ++ norm_shift ;
381383 }
@@ -412,29 +414,44 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
412414
413415 // keep going while we have enough digits to divide
414416 while (* num_len > den_len ) {
415- mpz_dbl_dig_t quo = (* num_dig << DIG_SIZE ) | num_dig [-1 ];
417+ mpz_dbl_dig_t quo = (( mpz_dbl_dig_t ) * num_dig << DIG_SIZE ) | num_dig [-1 ];
416418
417419 // get approximate quotient
418420 quo /= lead_den_digit ;
419421
420- // multiply quo by den and subtract from num get remainder
421- {
422+ // Multiply quo by den and subtract from num to get remainder.
423+ // We have different code here to handle different compile-time
424+ // configurations of mpz:
425+ //
426+ // 1. DIG_SIZE is stricly less than half the number of bits
427+ // available in mpz_dbl_dig_t. In this case we can use a
428+ // slightly more optimal (in time and space) routine that
429+ // uses the extra bits in mpz_dbl_dig_signed_t to store a
430+ // sign bit.
431+ //
432+ // 2. DIG_SIZE is exactly half the number of bits available in
433+ // mpz_dbl_dig_t. In this (common) case we need to be careful
434+ // not to overflow the borrow variable. And the shifting of
435+ // borrow needs some special logic (it's a shift right with
436+ // round up).
437+
438+ if (DIG_SIZE < 8 * sizeof (mpz_dbl_dig_t ) / 2 ) {
422439 mpz_dbl_dig_signed_t borrow = 0 ;
423440
424441 for (mpz_dig_t * n = num_dig - den_len , * d = den_dig ; n < num_dig ; ++ n , ++ d ) {
425- borrow += * n - quo * * d ; // will overflow if DIG_SIZE >= 16
442+ borrow += ( mpz_dbl_dig_t ) * n - ( mpz_dbl_dig_t ) quo * ( mpz_dbl_dig_t ) * d ; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
426443 * n = borrow & DIG_MASK ;
427444 borrow >>= DIG_SIZE ;
428445 }
429- borrow += * num_dig ; // will overflow if DIG_SIZE >= 16
446+ borrow += * num_dig ; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
430447 * num_dig = borrow & DIG_MASK ;
431448 borrow >>= DIG_SIZE ;
432449
433450 // adjust quotient if it is too big
434451 for (; borrow != 0 ; -- quo ) {
435452 mpz_dbl_dig_t carry = 0 ;
436453 for (mpz_dig_t * n = num_dig - den_len , * d = den_dig ; n < num_dig ; ++ n , ++ d ) {
437- carry += * n + * d ;
454+ carry += ( mpz_dbl_dig_t ) * n + ( mpz_dbl_dig_t ) * d ;
438455 * n = carry & DIG_MASK ;
439456 carry >>= DIG_SIZE ;
440457 }
@@ -444,6 +461,44 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
444461
445462 borrow += carry ;
446463 }
464+ } else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2
465+ mpz_dbl_dig_t borrow = 0 ;
466+
467+ for (mpz_dig_t * n = num_dig - den_len , * d = den_dig ; n < num_dig ; ++ n , ++ d ) {
468+ mpz_dbl_dig_t x = (mpz_dbl_dig_t )quo * (mpz_dbl_dig_t )(* d );
469+ if (x >= * n || * n - x <= borrow ) {
470+ borrow += (mpz_dbl_dig_t )x - (mpz_dbl_dig_t )* n ;
471+ * n = (- borrow ) & DIG_MASK ;
472+ borrow = (borrow >> DIG_SIZE ) + ((borrow & DIG_MASK ) == 0 ? 0 : 1 ); // shift-right with round-up
473+ } else {
474+ * n = ((mpz_dbl_dig_t )* n - (mpz_dbl_dig_t )x - (mpz_dbl_dig_t )borrow ) & DIG_MASK ;
475+ borrow = 0 ;
476+ }
477+ }
478+ if (borrow >= * num_dig ) {
479+ borrow -= (mpz_dbl_dig_t )* num_dig ;
480+ * num_dig = (- borrow ) & DIG_MASK ;
481+ borrow = (borrow >> DIG_SIZE ) + ((borrow & DIG_MASK ) == 0 ? 0 : 1 ); // shift-right with round-up
482+ } else {
483+ * num_dig = (* num_dig - borrow ) & DIG_MASK ;
484+ borrow = 0 ;
485+ }
486+
487+ // adjust quotient if it is too big
488+ for (; borrow != 0 ; -- quo ) {
489+ mpz_dbl_dig_t carry = 0 ;
490+ for (mpz_dig_t * n = num_dig - den_len , * d = den_dig ; n < num_dig ; ++ n , ++ d ) {
491+ carry += (mpz_dbl_dig_t )* n + (mpz_dbl_dig_t )* d ;
492+ * n = carry & DIG_MASK ;
493+ carry >>= DIG_SIZE ;
494+ }
495+ carry += (mpz_dbl_dig_t )* num_dig ;
496+ * num_dig = carry & DIG_MASK ;
497+ carry >>= DIG_SIZE ;
498+
499+ //assert(borrow >= carry); // enable this to check the logic
500+ borrow -= carry ;
501+ }
447502 }
448503
449504 // store this digit of the quotient
@@ -1256,7 +1311,7 @@ bool mpz_as_uint_checked(const mpz_t *i, mp_uint_t *value) {
12561311 mpz_dig_t * d = i -> dig + i -> len ;
12571312
12581313 while (-- d >= i -> dig ) {
1259- if (val > ((~ 0 ) >> DIG_SIZE )) {
1314+ if (val > (~( WORD_MSBIT_HIGH ) >> ( DIG_SIZE - 1 ) )) {
12601315 // will overflow
12611316 return false;
12621317 }
@@ -1273,7 +1328,7 @@ mp_float_t mpz_as_float(const mpz_t *i) {
12731328 mpz_dig_t * d = i -> dig + i -> len ;
12741329
12751330 while (-- d >= i -> dig ) {
1276- val = val * ( 1 << DIG_SIZE ) + * d ;
1331+ val = val * DIG_BASE + * d ;
12771332 }
12781333
12791334 if (i -> neg != 0 ) {
0 commit comments