2929
3030#include "py/mpz.h"
3131
32- // this is only needed for mp_not_implemented, which should eventually be removed
33- #include "py/runtime.h"
34-
3532#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
3633
3734#define DIG_SIZE (MPZ_DIG_SIZE)
@@ -199,6 +196,14 @@ STATIC mp_uint_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
199196 return idig + 1 - oidig ;
200197}
201198
199+ STATIC mp_uint_t mpn_remove_trailing_zeros (mpz_dig_t * oidig , mpz_dig_t * idig ) {
200+ for (-- idig ; idig >= oidig && * idig == 0 ; -- idig ) {
201+ }
202+ return idig + 1 - oidig ;
203+ }
204+
205+ #if MICROPY_OPT_MPZ_BITWISE
206+
202207/* computes i = j & k
203208 returns number of digits in i
204209 assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen (jlen argument not needed)
@@ -211,41 +216,46 @@ STATIC mp_uint_t mpn_and(mpz_dig_t *idig, const mpz_dig_t *jdig, const mpz_dig_t
211216 * idig = * jdig & * kdig ;
212217 }
213218
214- // remove trailing zeros
215- for (-- idig ; idig >= oidig && * idig == 0 ; -- idig ) {
216- }
217-
218- return idig + 1 - oidig ;
219+ return mpn_remove_trailing_zeros (oidig , idig );
219220}
220221
221- /* computes i = j & -k = j & (~k + 1)
222+ #endif
223+
224+ /* i = -((-j) & (-k)) = ~((~j + 1) & (~k + 1)) + 1
225+ i = (j & (-k)) = (j & (~k + 1)) = ( j & (~k + 1))
226+ i = ((-j) & k) = ((~j + 1) & k) = ((~j + 1) & k )
227+ computes general form:
228+ i = (im ^ (((j ^ jm) + jc) & ((k ^ km) + kc))) + ic where Xm = Xc == 0 ? 0 : DIG_MASK
222229 returns number of digits in i
223- assumes enough memory in i; assumes normalised j, k
230+ assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
224231 can have i, j, k pointing to same memory
225232*/
226- STATIC mp_uint_t mpn_and_neg (mpz_dig_t * idig , const mpz_dig_t * jdig , mp_uint_t jlen , const mpz_dig_t * kdig , mp_uint_t klen ) {
233+ STATIC mp_uint_t mpn_and_neg (mpz_dig_t * idig , const mpz_dig_t * jdig , mp_uint_t jlen , const mpz_dig_t * kdig , mp_uint_t klen ,
234+ mpz_dbl_dig_t carryi , mpz_dbl_dig_t carryj , mpz_dbl_dig_t carryk ) {
227235 mpz_dig_t * oidig = idig ;
228- mpz_dbl_dig_t carry = 1 ;
236+ mpz_dig_t imask = (0 == carryi ) ? 0 : DIG_MASK ;
237+ mpz_dig_t jmask = (0 == carryj ) ? 0 : DIG_MASK ;
238+ mpz_dig_t kmask = (0 == carryk ) ? 0 : DIG_MASK ;
229239
230- for (; jlen > 0 && klen > 0 ; -- jlen , -- klen , ++ idig , ++ jdig , ++ kdig ) {
231- carry += * kdig ^ DIG_MASK ;
232- * idig = (* jdig & carry ) & DIG_MASK ;
233- carry >>= DIG_SIZE ;
240+ for (; jlen > 0 ; ++ idig , ++ jdig ) {
241+ carryj += * jdig ^ jmask ;
242+ carryk += (-- klen <= -- jlen ) ? (* kdig ++ ^ kmask ) : kmask ;
243+ carryi += ((carryj & carryk ) ^ imask ) & DIG_MASK ;
244+ * idig = carryi & DIG_MASK ;
245+ carryk >>= DIG_SIZE ;
246+ carryj >>= DIG_SIZE ;
247+ carryi >>= DIG_SIZE ;
234248 }
235249
236- for (; jlen > 0 ; -- jlen , ++ idig , ++ jdig ) {
237- carry += DIG_MASK ;
238- * idig = (* jdig & carry ) & DIG_MASK ;
239- carry >>= DIG_SIZE ;
240- }
241-
242- // remove trailing zeros
243- for (-- idig ; idig >= oidig && * idig == 0 ; -- idig ) {
250+ if (0 != carryi ) {
251+ * idig ++ = carryi ;
244252 }
245253
246- return idig + 1 - oidig ;
254+ return mpn_remove_trailing_zeros ( oidig , idig ) ;
247255}
248256
257+ #if MICROPY_OPT_MPZ_BITWISE
258+
249259/* computes i = j | k
250260 returns number of digits in i
251261 assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
@@ -267,6 +277,74 @@ STATIC mp_uint_t mpn_or(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
267277 return idig - oidig ;
268278}
269279
280+ #endif
281+
/* i = -((-j) | (-k)) = ~((~j + 1) | (~k + 1)) + 1
   i = -(j | (-k))    = -(j | (~k + 1))    = ~( j       | (~k + 1)) + 1
   i = -((-j) | k)    = -((~j + 1) | k)    = ~((~j + 1) |  k      ) + 1
   computes general form:
   i = ~(((j ^ jm) + jc) | ((k ^ km) + kc)) + 1  where Xm = Xc == 0 ? 0 : DIG_MASK
   when MICROPY_OPT_MPZ_BITWISE is disabled the function takes an extra ic argument and computes
   i = (im ^ (((j ^ jm) + jc) | ((k ^ km) + kc))) + ic
   which, with ic = jc = kc = 0, also covers plain j | k for non-negative j, k
   returns number of digits in i
   assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
   can have i, j, k pointing to same memory
*/
291+
292+ #if MICROPY_OPT_MPZ_BITWISE
293+
294+ STATIC mp_uint_t mpn_or_neg (mpz_dig_t * idig , const mpz_dig_t * jdig , mp_uint_t jlen , const mpz_dig_t * kdig , mp_uint_t klen ,
295+ mpz_dbl_dig_t carryj , mpz_dbl_dig_t carryk ) {
296+ mpz_dig_t * oidig = idig ;
297+ mpz_dbl_dig_t carryi = 1 ;
298+ mpz_dig_t jmask = (0 == carryj ) ? 0 : DIG_MASK ;
299+ mpz_dig_t kmask = (0 == carryk ) ? 0 : DIG_MASK ;
300+
301+ for (; jlen > 0 ; ++ idig , ++ jdig ) {
302+ carryj += * jdig ^ jmask ;
303+ carryk += (-- klen <= -- jlen ) ? (* kdig ++ ^ kmask ) : kmask ;
304+ carryi += ((carryj | carryk ) ^ DIG_MASK ) & DIG_MASK ;
305+ * idig = carryi & DIG_MASK ;
306+ carryk >>= DIG_SIZE ;
307+ carryj >>= DIG_SIZE ;
308+ carryi >>= DIG_SIZE ;
309+ }
310+
311+ if (0 != carryi ) {
312+ * idig ++ = carryi ;
313+ }
314+
315+ return mpn_remove_trailing_zeros (oidig , idig );
316+ }
317+
318+ #else
319+
320+ STATIC mp_uint_t mpn_or_neg (mpz_dig_t * idig , const mpz_dig_t * jdig , mp_uint_t jlen , const mpz_dig_t * kdig , mp_uint_t klen ,
321+ mpz_dbl_dig_t carryi , mpz_dbl_dig_t carryj , mpz_dbl_dig_t carryk ) {
322+ mpz_dig_t * oidig = idig ;
323+ mpz_dig_t imask = (0 == carryi ) ? 0 : DIG_MASK ;
324+ mpz_dig_t jmask = (0 == carryj ) ? 0 : DIG_MASK ;
325+ mpz_dig_t kmask = (0 == carryk ) ? 0 : DIG_MASK ;
326+
327+ for (; jlen > 0 ; ++ idig , ++ jdig ) {
328+ carryj += * jdig ^ jmask ;
329+ carryk += (-- klen <= -- jlen ) ? (* kdig ++ ^ kmask ) : kmask ;
330+ carryi += ((carryj | carryk ) ^ imask ) & DIG_MASK ;
331+ * idig = carryi & DIG_MASK ;
332+ carryk >>= DIG_SIZE ;
333+ carryj >>= DIG_SIZE ;
334+ carryi >>= DIG_SIZE ;
335+ }
336+
337+ if (0 != carryi ) {
338+ * idig ++ = carryi ;
339+ }
340+
341+ return mpn_remove_trailing_zeros (oidig , idig );
342+ }
343+
344+ #endif
345+
346+ #if MICROPY_OPT_MPZ_BITWISE
347+
270348/* computes i = j ^ k
271349 returns number of digits in i
272350 assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
@@ -285,11 +363,39 @@ STATIC mp_uint_t mpn_xor(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
285363 * idig = * jdig ;
286364 }
287365
288- // remove trailing zeros
289- for (-- idig ; idig >= oidig && * idig == 0 ; -- idig ) {
366+ return mpn_remove_trailing_zeros (oidig , idig );
367+ }
368+
369+ #endif
370+
371+ /* i = (-j) ^ (-k) = ~(j - 1) ^ ~(k - 1) = (j - 1) ^ (k - 1)
372+ i = -(j ^ (-k)) = -(j ^ ~(k - 1)) = ~(j ^ ~(k - 1)) + 1 = (j ^ (k - 1)) + 1
373+ i = -((-j) ^ k) = -(~(j - 1) ^ k) = ~(~(j - 1) ^ k) + 1 = ((j - 1) ^ k) + 1
374+ computes general form:
375+ i = ((j - 1 + jc) ^ (k - 1 + kc)) + ic
376+ returns number of digits in i
377+ assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
378+ can have i, j, k pointing to same memory
379+ */
380+ STATIC mp_uint_t mpn_xor_neg (mpz_dig_t * idig , const mpz_dig_t * jdig , mp_uint_t jlen , const mpz_dig_t * kdig , mp_uint_t klen ,
381+ mpz_dbl_dig_t carryi , mpz_dbl_dig_t carryj , mpz_dbl_dig_t carryk ) {
382+ mpz_dig_t * oidig = idig ;
383+
384+ for (; jlen > 0 ; ++ idig , ++ jdig ) {
385+ carryj += * jdig + DIG_MASK ;
386+ carryk += (-- klen <= -- jlen ) ? (* kdig ++ + DIG_MASK ) : DIG_MASK ;
387+ carryi += (carryj ^ carryk ) & DIG_MASK ;
388+ * idig = carryi & DIG_MASK ;
389+ carryk >>= DIG_SIZE ;
390+ carryj >>= DIG_SIZE ;
391+ carryi >>= DIG_SIZE ;
290392 }
291393
292- return idig + 1 - oidig ;
394+ if (0 != carryi ) {
395+ * idig ++ = carryi ;
396+ }
397+
398+ return mpn_remove_trailing_zeros (oidig , idig );
293399}
294400
295401/* computes i = i * d1 + d2
@@ -1097,81 +1203,106 @@ void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
10971203 can have dest, lhs, rhs the same
10981204*/
10991205void mpz_and_inpl (mpz_t * dest , const mpz_t * lhs , const mpz_t * rhs ) {
1100- if (lhs -> neg == rhs -> neg ) {
1101- if (lhs -> neg == 0 ) {
1102- // make sure lhs has the most digits
1103- if (lhs -> len < rhs -> len ) {
1104- const mpz_t * temp = lhs ;
1105- lhs = rhs ;
1106- rhs = temp ;
1107- }
1108- // do the and'ing
1109- mpz_need_dig (dest , rhs -> len );
1110- dest -> len = mpn_and (dest -> dig , lhs -> dig , rhs -> dig , rhs -> len );
1111- dest -> neg = 0 ;
1112- } else {
1113- // TODO both args are negative
1114- mp_not_implemented ("bignum and with negative args" );
1115- }
1206+ // make sure lhs has the most digits
1207+ if (lhs -> len < rhs -> len ) {
1208+ const mpz_t * temp = lhs ;
1209+ lhs = rhs ;
1210+ rhs = temp ;
1211+ }
1212+
1213+ #if MICROPY_OPT_MPZ_BITWISE
1214+
1215+ if ((0 == lhs -> neg ) && (0 == rhs -> neg )) {
1216+ mpz_need_dig (dest , lhs -> len );
1217+ dest -> len = mpn_and (dest -> dig , lhs -> dig , rhs -> dig , rhs -> len );
1218+ dest -> neg = 0 ;
11161219 } else {
1117- // args have different sign
1118- // make sure lhs is the positive arg
1119- if (rhs -> neg == 0 ) {
1120- const mpz_t * temp = lhs ;
1121- lhs = rhs ;
1122- rhs = temp ;
1123- }
11241220 mpz_need_dig (dest , lhs -> len + 1 );
1125- dest -> len = mpn_and_neg (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len );
1126- assert ( dest -> len <= dest -> alloc );
1127- dest -> neg = 0 ;
1221+ dest -> len = mpn_and_neg (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len ,
1222+ lhs -> neg == rhs -> neg , 0 != lhs -> neg , 0 != rhs -> neg );
1223+ dest -> neg = lhs -> neg & rhs -> neg ;
11281224 }
1225+
1226+ #else
1227+
1228+ mpz_need_dig (dest , lhs -> len + (lhs -> neg || rhs -> neg ));
1229+ dest -> len = mpn_and_neg (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len ,
1230+ (lhs -> neg == rhs -> neg ) ? lhs -> neg : 0 , lhs -> neg , rhs -> neg );
1231+ dest -> neg = lhs -> neg & rhs -> neg ;
1232+
1233+ #endif
11291234}
11301235
11311236/* computes dest = lhs | rhs
11321237 can have dest, lhs, rhs the same
11331238*/
11341239void mpz_or_inpl (mpz_t * dest , const mpz_t * lhs , const mpz_t * rhs ) {
1135- if (mpn_cmp (lhs -> dig , lhs -> len , rhs -> dig , rhs -> len ) < 0 ) {
1240+ // make sure lhs has the most digits
1241+ if (lhs -> len < rhs -> len ) {
11361242 const mpz_t * temp = lhs ;
11371243 lhs = rhs ;
11381244 rhs = temp ;
11391245 }
11401246
1141- if (lhs -> neg == rhs -> neg ) {
1247+ #if MICROPY_OPT_MPZ_BITWISE
1248+
1249+ if ((0 == lhs -> neg ) && (0 == rhs -> neg )) {
11421250 mpz_need_dig (dest , lhs -> len );
11431251 dest -> len = mpn_or (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len );
1252+ dest -> neg = 0 ;
11441253 } else {
1145- mpz_need_dig (dest , lhs -> len );
1146- // TODO
1147- mp_not_implemented ( "bignum or with negative args" );
1148- // dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len) ;
1254+ mpz_need_dig (dest , lhs -> len + 1 );
1255+ dest -> len = mpn_or_neg ( dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len ,
1256+ 0 != lhs -> neg , 0 != rhs -> neg );
1257+ dest -> neg = 1 ;
11491258 }
11501259
1151- dest -> neg = lhs -> neg ;
1260+ #else
1261+
1262+ mpz_need_dig (dest , lhs -> len + (lhs -> neg || rhs -> neg ));
1263+ dest -> len = mpn_or_neg (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len ,
1264+ (lhs -> neg || rhs -> neg ), lhs -> neg , rhs -> neg );
1265+ dest -> neg = lhs -> neg | rhs -> neg ;
1266+
1267+ #endif
11521268}
11531269
11541270/* computes dest = lhs ^ rhs
11551271 can have dest, lhs, rhs the same
11561272*/
11571273void mpz_xor_inpl (mpz_t * dest , const mpz_t * lhs , const mpz_t * rhs ) {
1158- if (mpn_cmp (lhs -> dig , lhs -> len , rhs -> dig , rhs -> len ) < 0 ) {
1274+ // make sure lhs has the most digits
1275+ if (lhs -> len < rhs -> len ) {
11591276 const mpz_t * temp = lhs ;
11601277 lhs = rhs ;
11611278 rhs = temp ;
11621279 }
11631280
1281+ #if MICROPY_OPT_MPZ_BITWISE
1282+
11641283 if (lhs -> neg == rhs -> neg ) {
11651284 mpz_need_dig (dest , lhs -> len );
1166- dest -> len = mpn_xor (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len );
1285+ if (lhs -> neg == 0 ) {
1286+ dest -> len = mpn_xor (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len );
1287+ } else {
1288+ dest -> len = mpn_xor_neg (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len , 0 , 0 , 0 );
1289+ }
1290+ dest -> neg = 0 ;
11671291 } else {
1168- mpz_need_dig (dest , lhs -> len );
1169- // TODO
1170- mp_not_implemented ( "bignum xor with negative args" );
1171- // dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len) ;
1292+ mpz_need_dig (dest , lhs -> len + 1 );
1293+ dest -> len = mpn_xor_neg ( dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len , 1 ,
1294+ 0 == lhs -> neg , 0 == rhs -> neg );
1295+ dest -> neg = 1 ;
11721296 }
11731297
1174- dest -> neg = 0 ;
1298+ #else
1299+
1300+ mpz_need_dig (dest , lhs -> len + (lhs -> neg || rhs -> neg ));
1301+ dest -> len = mpn_xor_neg (dest -> dig , lhs -> dig , lhs -> len , rhs -> dig , rhs -> len ,
1302+ (lhs -> neg != rhs -> neg ), 0 == lhs -> neg , 0 == rhs -> neg );
1303+ dest -> neg = lhs -> neg ^ rhs -> neg ;
1304+
1305+ #endif
11751306}
11761307
11771308/* computes dest = lhs * rhs
0 commit comments