Skip to content

Commit dc3faea

Browse files
committed
py/mpz: Fix bug with overflowing C-shift in division routine.
When DIG_SIZE=32, a uint32_t is used to store limbs, and no normalisation is needed because the MSB is already set, then there will be left and right shifts (in C) by 32 of a 32-bit variable, leading to undefined behaviour. This patch fixes this bug.
1 parent d59c2e5 commit dc3faea

3 files changed

Lines changed: 18 additions & 4 deletions

File tree

py/mpz.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
491491
for (mpz_dig_t *den = den_dig, carry = 0; den < den_dig + den_len; ++den) {
492492
mpz_dig_t d = *den;
493493
*den = ((d << norm_shift) | carry) & DIG_MASK;
494-
carry = d >> (DIG_SIZE - norm_shift);
494+
carry = (mpz_dbl_dig_t)d >> (DIG_SIZE - norm_shift);
495495
}
496496

497497
// now need to shift numerator by same amount as denominator
@@ -501,7 +501,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
501501
for (mpz_dig_t *num = num_dig, carry = 0; num < num_dig + *num_len; ++num) {
502502
mpz_dig_t n = *num;
503503
*num = ((n << norm_shift) | carry) & DIG_MASK;
504-
carry = n >> (DIG_SIZE - norm_shift);
504+
carry = (mpz_dbl_dig_t)n >> (DIG_SIZE - norm_shift);
505505
}
506506

507507
// cache the leading digit of the denominator
@@ -618,14 +618,14 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
618618
for (mpz_dig_t *den = den_dig + den_len - 1, carry = 0; den >= den_dig; --den) {
619619
mpz_dig_t d = *den;
620620
*den = ((d >> norm_shift) | carry) & DIG_MASK;
621-
carry = d << (DIG_SIZE - norm_shift);
621+
carry = (mpz_dbl_dig_t)d << (DIG_SIZE - norm_shift);
622622
}
623623

624624
// unnormalise numerator (remainder now)
625625
for (mpz_dig_t *num = orig_num_dig + *num_len - 1, carry = 0; num >= orig_num_dig; --num) {
626626
mpz_dig_t n = *num;
627627
*num = ((n >> norm_shift) | carry) & DIG_MASK;
628-
carry = n << (DIG_SIZE - norm_shift);
628+
carry = (mpz_dbl_dig_t)n << (DIG_SIZE - norm_shift);
629629
}
630630

631631
// strip trailing zeros

tests/basics/int_big_div.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
11
for lhs in (1000000000000000000000000, 10000000000100000000000000, 10012003400000000000000007, 12349083434598210349871029923874109871234789):
22
for rhs in range(1, 555):
33
print(lhs // rhs)
4+
5+
# these check an edge case on 64-bit machines where two mpz limbs
6+
# are used and the most significant one has the MSB set
7+
x = 0x8000000000000000
8+
print((x + 1) // x)
9+
x = 0x86c60128feff5330
10+
print((x + 1) // x)

tests/basics/int_big_mod.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,10 @@
88
y = delta * (j)# - 5) # TODO reinstate negative number test when % is working with sign correctly
99
if y != 0:
1010
print(x % y)
11+
12+
# these check an edge case on 64-bit machines where two mpz limbs
13+
# are used and the most significant one has the MSB set
14+
x = 0x8000000000000000
15+
print((x + 1) % x)
16+
x = 0x86c60128feff5330
17+
print((x + 1) % x)

0 commit comments

Comments
 (0)