Skip to content

Commit 460b086

Browse files
committed
py/mpz: Fix mpn_div so that it doesn't modify memory of denominator.
Prior to this patch, bignum division and modulo would temporarily modify the RHS argument of the operation (e.g. x/y would modify y), although on return the RHS would be restored to its original value. This is not allowed because arguments to binary operations are const, and in particular might live in ROM. The modification was to normalise the argument (and then unnormalise it before returning); this patch performs the normalisation on the fly, so the argument is now accessed as read-only. This change doesn't increase the order complexity of the operation, and actually reduces code size.
1 parent de5e0ed commit 460b086

1 file changed

Lines changed: 30 additions & 27 deletions

File tree

py/mpz.c

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -454,10 +454,8 @@ STATIC mp_uint_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mpz_d
454454
assumes num_dig has enough memory to be extended by 1 digit
455455
assumes quo_dig has enough memory (as many digits as num)
456456
assumes quo_dig is filled with zeros
457-
modifies den_dig memory, but restors it to original state at end
458457
*/
459-
460-
STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, mp_uint_t den_len, mpz_dig_t *quo_dig, mp_uint_t *quo_len) {
458+
STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, const mpz_dig_t *den_dig, mp_uint_t den_len, mpz_dig_t *quo_dig, mp_uint_t *quo_len) {
461459
mpz_dig_t *orig_num_dig = num_dig;
462460
mpz_dig_t *orig_quo_dig = quo_dig;
463461
mpz_dig_t norm_shift = 0;
@@ -478,6 +476,11 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
478476
}
479477
}
480478

479+
// We need to normalise the denominator (leading bit of leading digit is 1)
480+
// so that the division routine works. Since the denominator memory is
481+
// read-only we do the normalisation on the fly, each time a digit of the
482+
// denominator is needed. All we need to know is how many bits to shift by.
483+
481484
// count number of leading zeros in leading digit of denominator
482485
{
483486
mpz_dig_t d = den_dig[den_len - 1];
@@ -487,13 +490,6 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
487490
}
488491
}
489492

490-
// normalise denomenator (leading bit of leading digit is 1)
491-
for (mpz_dig_t *den = den_dig, carry = 0; den < den_dig + den_len; ++den) {
492-
mpz_dig_t d = *den;
493-
*den = ((d << norm_shift) | carry) & DIG_MASK;
494-
carry = (mpz_dbl_dig_t)d >> (DIG_SIZE - norm_shift);
495-
}
496-
497493
// now need to shift numerator by same amount as denominator
498494
// first, increase length of numerator in case we need more room to shift
499495
num_dig[*num_len] = 0;
@@ -505,7 +501,10 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
505501
}
506502

507503
// cache the leading digit of the denominator
508-
lead_den_digit = den_dig[den_len - 1];
504+
lead_den_digit = (mpz_dbl_dig_t)den_dig[den_len - 1] << norm_shift;
505+
if (den_len >= 2) {
506+
lead_den_digit |= (mpz_dbl_dig_t)den_dig[den_len - 2] >> (DIG_SIZE - norm_shift);
507+
}
509508

510509
// point num_dig to last digit in numerator
511510
num_dig += *num_len - 1;
@@ -540,10 +539,13 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
540539
// round up).
541540

542541
if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) {
542+
const mpz_dig_t *d = den_dig;
543+
mpz_dbl_dig_t d_norm = 0;
543544
mpz_dbl_dig_signed_t borrow = 0;
544545

545-
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
546-
borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)*d; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
546+
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
547+
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
548+
borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
547549
*n = borrow & DIG_MASK;
548550
borrow >>= DIG_SIZE;
549551
}
@@ -553,9 +555,12 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
553555

554556
// adjust quotient if it is too big
555557
for (; borrow != 0; --quo) {
558+
d = den_dig;
559+
d_norm = 0;
556560
mpz_dbl_dig_t carry = 0;
557-
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
558-
carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;
561+
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
562+
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
563+
carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
559564
*n = carry & DIG_MASK;
560565
carry >>= DIG_SIZE;
561566
}
@@ -566,10 +571,13 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
566571
borrow += carry;
567572
}
568573
} else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2
574+
const mpz_dig_t *d = den_dig;
575+
mpz_dbl_dig_t d_norm = 0;
569576
mpz_dbl_dig_t borrow = 0;
570577

571-
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
572-
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)(*d);
578+
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
579+
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
580+
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
573581
if (x >= *n || *n - x <= borrow) {
574582
borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n;
575583
*n = (-borrow) & DIG_MASK;
@@ -590,9 +598,12 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
590598

591599
// adjust quotient if it is too big
592600
for (; borrow != 0; --quo) {
601+
d = den_dig;
602+
d_norm = 0;
593603
mpz_dbl_dig_t carry = 0;
594-
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
595-
carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;
604+
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
605+
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
606+
carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
596607
*n = carry & DIG_MASK;
597608
carry >>= DIG_SIZE;
598609
}
@@ -614,13 +625,6 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
614625
--(*num_len);
615626
}
616627

617-
// unnormalise denomenator
618-
for (mpz_dig_t *den = den_dig + den_len - 1, carry = 0; den >= den_dig; --den) {
619-
mpz_dig_t d = *den;
620-
*den = ((d >> norm_shift) | carry) & DIG_MASK;
621-
carry = (mpz_dbl_dig_t)d << (DIG_SIZE - norm_shift);
622-
}
623-
624628
// unnormalise numerator (remainder now)
625629
for (mpz_dig_t *num = orig_num_dig + *num_len - 1, carry = 0; num >= orig_num_dig; --num) {
626630
mpz_dig_t n = *num;
@@ -1506,7 +1510,6 @@ void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const m
15061510
dest_quo->len = 0;
15071511
mpz_need_dig(dest_rem, lhs->len + 1); // +1 necessary?
15081512
mpz_set(dest_rem, lhs);
1509-
//rhs->dig[rhs->len] = 0;
15101513
mpn_div(dest_rem->dig, &dest_rem->len, rhs->dig, rhs->len, dest_quo->dig, &dest_quo->len);
15111514

15121515
// check signs and do Python style modulo

0 commit comments

Comments
 (0)