diff --git a/gmp-impl.h b/gmp-impl.h index 4c1663ce..b7963807 100644 --- a/gmp-impl.h +++ b/gmp-impl.h @@ -1064,11 +1064,14 @@ void mpn_toom3_mul _PROTO ((mp_ptr, mp_srcptr, mp_size_t, mp_srcptr, mp_size_t,mp_ptr)); #define mpn_toom3_interpolate __MPN(toom3_interpolate) -void -mpn_toom3_interpolate _PROTO ((mp_ptr c, mp_ptr v1, mp_ptr v2, mp_ptr vm1, +void mpn_toom3_interpolate _PROTO ((mp_ptr c, mp_ptr v1, mp_ptr v2, mp_ptr vm1, mp_ptr vinf, mp_size_t k, mp_size_t rr2, int sa, mp_limb_t vinf0, mp_ptr ws)); +#define mpn_toom32_mul __MPN(toom32_mul) +void mpn_toom32_mul _PROTO ((mp_ptr c, mp_srcptr a, mp_size_t an, mp_srcptr b, + mp_size_t bn, mp_ptr t)); + #define mpn_toom42_mul __MPN(toom42_mul) void mpn_toom42_mul _PROTO ((mp_ptr, mp_srcptr, mp_size_t, mp_srcptr, mp_size_t,mp_ptr)); diff --git a/mpn/generic/mul.c b/mpn/generic/mul.c index 6dbaf89b..25bc75ed 100644 --- a/mpn/generic/mul.c +++ b/mpn/generic/mul.c @@ -144,31 +144,33 @@ mpn_mul (mp_ptr prodp, } k = (un + 3)/4; // ceil(un/3) - - if ((un + vn >= 2*MUL_TOOM3_THRESHOLD) && (vn > k)) - { - mp_ptr ws; - - if (vn < 2*k) // un/2 >= vn > un/4 - { - TMP_DECL; - TMP_MARK; - ws = TMP_ALLOC_LIMBS (MPN_TOOM3_MUL_TSIZE(un)); - mpn_toom42_mul(prodp, up, un, vp, vn, ws); - TMP_FREE; - return prodp[un + vn - 1]; - } - k = (un+2)/3; //ceil(u/3) - if (vn > 2*k) // un >= vn > 2un/3 - { - TMP_DECL; - TMP_MARK; + if ((un + vn >= 2*MUL_TOOM3_THRESHOLD) && (vn > k)) + { + mp_ptr ws; + TMP_DECL; + TMP_MARK; + + if (vn < 2*k) // un/2 >= vn > un/4 + { ws = TMP_ALLOC_LIMBS (MPN_TOOM3_MUL_TSIZE(un)); - mpn_toom3_mul(prodp, up, un, vp, vn, ws); + mpn_toom42_mul(prodp, up, un, vp, vn, ws); TMP_FREE; return prodp[un + vn - 1]; - } + } + + k = (un+2)/3; //ceil(u/3) + if (vn > 2*k) // un >= vn > 2un/3 + { + ws = TMP_ALLOC_LIMBS (MPN_TOOM3_MUL_TSIZE(un)); + mpn_toom3_mul(prodp, up, un, vp, vn, ws); + } else + { + ws = TMP_ALLOC_LIMBS (MPN_TOOM3_MUL_TSIZE(un)); + mpn_toom32_mul(prodp, up, un, vp, vn, ws); + } + TMP_FREE; + return prodp[un + vn - 1]; } mpn_mul_n (prodp, up, vp, vn); diff --git a/mpn/generic/toom3_mul.c b/mpn/generic/toom3_mul.c index 7e0c4b10..0eaff8a9 100644 --- a/mpn/generic/toom3_mul.c +++ b/mpn/generic/toom3_mul.c @@ -604,3 +604,207 @@ mpn_toom42_mul (mp_ptr c, mp_srcptr a, mp_size_t an, mp_srcptr b, mp_size_t bn, #undef v2 #undef vinf } + +/* + We have a 3x2 blocked multiplication and therefore the output is of length + 4 blocks. Therefore we evaluate at the 4 points 0, inf, -1, 1, i.e. we need + (a0*b0), (a2*b1), (a0-a1+a2)*(b0-b1), (a0+a1+a2)*(b0+b1). + The multiplication will be (2k+r) x (k + r2) and therefore the output has + space for 3k + rr2 limbs. +*/ +void +mpn_toom32_mul (mp_ptr c, mp_srcptr a, mp_size_t an, mp_srcptr b, mp_size_t bn, mp_ptr t) +{ + mp_size_t k, k1, kk1, r, r2, twok, threek, rr2, n1, n2; + mp_limb_t cy, cc, saved; + mp_ptr trec; + int sa, sb; + mp_ptr c1, c2, c3, c4, c5, t1, t2, t3, t4; + + ASSERT(GMP_NUMB_BITS >= 6); + + k = (an + 2) / 3; /* ceil(an/3) */ + ASSERT(bn > k); + ASSERT(an >= 20); + + twok = 2 * k; + threek = 3 * k; + k1 = k + 1; + kk1 = k + k1; + r = an - twok; /* last chunk */ + r2 = bn - k; /* last chunk */ + rr2 = r + r2; + + c1 = c + k; + c2 = c1 + k; + c3 = c2 + k; + c4 = c3 + k; + c5 = c4 + k; + + t1 = t + k; + t2 = t1 + k; + t3 = t2 + k; + t4 = t3 + k; + + trec = t + 3 * k + 3; + + /* put a0+a2 in {t, k+1}, and b0+b1 in {t2 + 2, k+1}; + put a0+a1+a2 in {t1 + 1, k+1} + */ + cy = mpn_add_n (t, a, a + twok, r); + t3[2] = mpn_add_n (t2 + 2, b, b + k, r2); + if (r < k) + { + cy = mpn_add_1 (t + r, a + r, k - r, cy); + } + if (r2 < k) + { + t3[2] = mpn_add_1 (t2 + 2 + r2, b + r2, k - r2, t3[2]); + } + t2[1] = (t1[0] = cy) + mpn_add_n (t1 + 1, t, a + k, k); + + /* compute v1 := (a0+a1+a2)*(b0+b1) in {c1, 2k+1}; + since v1 < 6*B^(2k), v1 uses only 2k+1 words if GMP_NUMB_BITS >= 3 */ + TOOM3_MUL_REC (c1, t1 + 1, t2 + 2, k1, trec); + + saved = c1[0]; + + /* {c,2k} {c+2k,2k+1} {c+4k+1,r+r2-1} + v1 + */ + + /* put |a0-a1+a2| in {c0, k+1} and |b0-b1| in {t2 + 2,k+1} */ + /* sa = sign(a0-a1+a2) */ + /* sb = sign(b0-b1) */ + sa = (t[k] != 0) ? 1 : mpn_cmp (t, a + k, k); + if (sa >= 0) c[k] = t[k] - mpn_sub_n (c, t, a + k, k); + else c[k] = -mpn_sub_n (c, a + k, t, k); + + n1 = k; + n2 = r2; + MPN_NORMALIZE(b, n1); + MPN_NORMALIZE(b+k, n2); + if (n1 != n2) sb = (n1 > n2) ? 1 : -1; + else sb = mpn_cmp (b, b + k, n2); + + if (sb >= 0) + { + t3[2] = mpn_sub_n (t2 + 2, b, b + k, r2); + if (k > r2) t3[2] = -mpn_sub_1(t2 + 2 + r2, b + r2, k - r2, t3[2]); + } else + { + mpn_sub_n (t2 + 2, b + k, b, r2); + MPN_ZERO(t2 + r2 + 2, k1 - r2); + } + + sa *= sb; /* sign of vm1 */ + + /* compute vm1 := (a0-a1+a2)*(b0-b1) in {t, 2k+1}; + since |vm1| < 2*B^(2k), vm1 uses only 2k+1 limbs */ + TOOM3_MUL_REC (t, t2 + 2, c, k1, trec); + + /* {c,2k} {c+2k,2k+1} {c+4k+1,r+r2-1} + v1 + + {t, 2k+1} {t+2k+1, 2k + 1} + vm1 + */ + + c1[0] = saved; + + + /* {c,k} {c+k,2k+1} {c+3k+1,r+r2-1} + v1 + + {t, 2k+1} {t+2k+1, 2k + 2} + vm1 + */ + + /* Compute vm1 <-- (vm1 + v1)/2 (note vm1 + v1 is positive) */ + + if (sa > 0) + { +#if HAVE_NATIVE_mpn_rsh1add_n + mpn_rsh1add_n(t, t, c1, kk1); +#else + mpn_add_n(t, t, c1, kk1); + mpn_rshift1(t, t, kk1); +#endif + } else + { +#if HAVE_NATIVE_mpn_rsh1sub_n + mpn_rsh1sub_n(t, c1, t, kk1); +#else + mpn_sub_n(t, c1, t, kk1); + mpn_rshift1(t, t, kk1); +#endif + } + + /* Compute v1 <-- v1 - vm1 */ + + mpn_sub_n(c1, c1, t, kk1); + + /* Note we could technically overflow + the end of the output if we add + everything in place without subtracting + the right things first. We get around + this by throwing away any high limbs + and carries, which must of necessity + cancel. + + First we add vm1 in its place... + */ + + n1 = kk1; + MPN_NORMALIZE(t, n1); + + if (n1 >= k + rr2) /* if > here, high limb of vm1 and carry may be discarded */ + { + cy = mpn_add_n(c2, c2, t, k1); + mpn_add_1(c3 + 1, t + k1, rr2 - 1, cy); + n2 = threek + rr2; + } else + { + c2[k1] = mpn_add_n(c2, c2, t, k1); + if (n1 > k1) c2[n1] = mpn_add_1(c3 + 1, t + k1, n1 - k1, c2[k1]); + n2 = twok + MAX(n1, k1) + 1; + } + + /* Compute vinf := a2*b1 in {t, rr2} */ + + if (r == r2) TOOM3_MUL_REC (t, a + twok, b + k, r, trec); + else if (r > r2) mpn_mul(t, a + twok, r, b + k, r2); + else mpn_mul(t, b + k, r2, a + twok, r); + + /* Add vinf into place */ + + cy = mpn_add_n(c3, c3, t, n2 - threek); + if (rr2 + threek > n2) + mpn_add_1(c + n2, t + n2 - threek, rr2 + threek - n2, cy); + + /* v1 <-- v1 - vinf */ + + cy = mpn_sub_n(c1, c1, t, rr2); + if (cy) mpn_sub_1(c1 + rr2, c1 + rr2, twok, cy); + + /* compute v0 := a0*b0 in {t, 2k} */ + + TOOM3_MUL_REC (t, a, b, k, trec); + + /* Add v0 into place */ + + MPN_COPY(c, t, k); + cy = mpn_add_n(c + k, c + k, t + k, k); + if (cy) mpn_add_1(c + twok, c + twok, k + rr2, cy); + + /* vm1 <-- vm1 - v0 */ + + if (twok >= k + rr2) + mpn_sub_n(c2, c2, t, k + rr2); + else + { + cy = mpn_sub_n(c2, c2, t, twok); + mpn_sub_1(c4, c4, rr2 + k - twok, cy); + } +} +