Added toom32 for unbalanced multiplications.

This commit is contained in:
wbhart 2009-05-12 18:28:20 +00:00
parent c6881fa3a9
commit 21f51a706c
3 changed files with 232 additions and 23 deletions

View File

@ -1064,11 +1064,14 @@ void mpn_toom3_mul _PROTO ((mp_ptr, mp_srcptr, mp_size_t, mp_srcptr,
mp_size_t,mp_ptr));
#define mpn_toom3_interpolate __MPN(toom3_interpolate)
void
mpn_toom3_interpolate _PROTO ((mp_ptr c, mp_ptr v1, mp_ptr v2, mp_ptr vm1,
void mpn_toom3_interpolate _PROTO ((mp_ptr c, mp_ptr v1, mp_ptr v2, mp_ptr vm1,
mp_ptr vinf, mp_size_t k, mp_size_t rr2, int sa,
mp_limb_t vinf0, mp_ptr ws));
#define mpn_toom32_mul __MPN(toom32_mul)
void mpn_toom32_mul _PROTO ((mp_ptr c, mp_srcptr a, mp_size_t an, mp_srcptr b,
mp_size_t bn, mp_ptr t));
#define mpn_toom42_mul __MPN(toom42_mul)
void mpn_toom42_mul _PROTO ((mp_ptr, mp_srcptr, mp_size_t, mp_srcptr, mp_size_t,mp_ptr));

View File

@ -144,31 +144,33 @@ mpn_mul (mp_ptr prodp,
}
k = (un + 3)/4; // ceil(un/3)
if ((un + vn >= 2*MUL_TOOM3_THRESHOLD) && (vn > k))
{
mp_ptr ws;
if (vn < 2*k) // un/2 >= vn > un/4
{
TMP_DECL;
TMP_MARK;
ws = TMP_ALLOC_LIMBS (MPN_TOOM3_MUL_TSIZE(un));
mpn_toom42_mul(prodp, up, un, vp, vn, ws);
TMP_FREE;
return prodp[un + vn - 1];
}
k = (un+2)/3; //ceil(u/3)
if (vn > 2*k) // un >= vn > 2un/3
{
TMP_DECL;
TMP_MARK;
if ((un + vn >= 2*MUL_TOOM3_THRESHOLD) && (vn > k))
{
mp_ptr ws;
TMP_DECL;
TMP_MARK;
if (vn < 2*k) // un/2 >= vn > un/4
{
ws = TMP_ALLOC_LIMBS (MPN_TOOM3_MUL_TSIZE(un));
mpn_toom3_mul(prodp, up, un, vp, vn, ws);
mpn_toom42_mul(prodp, up, un, vp, vn, ws);
TMP_FREE;
return prodp[un + vn - 1];
}
}
k = (un+2)/3; //ceil(u/3)
if (vn > 2*k) // un >= vn > 2un/3
{
ws = TMP_ALLOC_LIMBS (MPN_TOOM3_MUL_TSIZE(un));
mpn_toom3_mul(prodp, up, un, vp, vn, ws);
} else
{
ws = TMP_ALLOC_LIMBS (MPN_TOOM3_MUL_TSIZE(un));
mpn_toom32_mul(prodp, up, un, vp, vn, ws);
}
TMP_FREE;
return prodp[un + vn - 1];
}
mpn_mul_n (prodp, up, vp, vn);

View File

@ -604,3 +604,207 @@ mpn_toom42_mul (mp_ptr c, mp_srcptr a, mp_size_t an, mp_srcptr b, mp_size_t bn,
#undef v2
#undef vinf
}
/*
We have a 3x2 blocked multiplication and therefore the output is of length
4 blocks. Therefore we evaluate at the 4 points 0, inf, -1, 1, i.e. we need
(a0*b0), (a2*b1), (a0-a1+a2)*(b0-b1), (a0+a1+a2)*(b0+b1).
The multiplication will be (2k+r) x (k + r2) and therefore the output has
space for 3k + rr2 limbs.
*/
void
mpn_toom32_mul (mp_ptr c, mp_srcptr a, mp_size_t an, mp_srcptr b, mp_size_t bn, mp_ptr t)
{
mp_size_t k, k1, kk1, r, r2, twok, threek, rr2, n1, n2;
mp_limb_t cy, cc, saved;
mp_ptr trec;
int sa, sb;
mp_ptr c1, c2, c3, c4, c5, t1, t2, t3, t4;
ASSERT(GMP_NUMB_BITS >= 6);
k = (an + 2) / 3; /* ceil(an/3) */
ASSERT(bn > k);
ASSERT(an >= 20);
twok = 2 * k;
threek = 3 * k;
k1 = k + 1;
kk1 = k + k1;
r = an - twok; /* last chunk */
r2 = bn - k; /* last chunk */
rr2 = r + r2;
c1 = c + k;
c2 = c1 + k;
c3 = c2 + k;
c4 = c3 + k;
c5 = c4 + k;
t1 = t + k;
t2 = t1 + k;
t3 = t2 + k;
t4 = t3 + k;
trec = t + 3 * k + 3;
/* put a0+a2 in {t, k+1}, and b0+b1 in {t2 + 2, k+1};
put a0+a1+a2 in {t1 + 1, k+1}
*/
cy = mpn_add_n (t, a, a + twok, r);
t3[2] = mpn_add_n (t2 + 2, b, b + k, r2);
if (r < k)
{
cy = mpn_add_1 (t + r, a + r, k - r, cy);
}
if (r2 < k)
{
t3[2] = mpn_add_1 (t2 + 2 + r2, b + r2, k - r2, t3[2]);
}
t2[1] = (t1[0] = cy) + mpn_add_n (t1 + 1, t, a + k, k);
/* compute v1 := (a0+a1+a2)*(b0+b1) in {c1, 2k+1};
since v1 < 6*B^(2k), v1 uses only 2k+1 words if GMP_NUMB_BITS >= 3 */
TOOM3_MUL_REC (c1, t1 + 1, t2 + 2, k1, trec);
saved = c1[0];
/* {c,2k} {c+2k,2k+1} {c+4k+1,r+r2-1}
v1
*/
/* put |a0-a1+a2| in {c0, k+1} and |b0-b1| in {t2 + 2,k+1} */
/* sa = sign(a0-a1+a2) */
/* sb = sign(b0-b1) */
sa = (t[k] != 0) ? 1 : mpn_cmp (t, a + k, k);
if (sa >= 0) c[k] = t[k] - mpn_sub_n (c, t, a + k, k);
else c[k] = -mpn_sub_n (c, a + k, t, k);
n1 = k;
n2 = r2;
MPN_NORMALIZE(b, n1);
MPN_NORMALIZE(b+k, n2);
if (n1 != n2) sb = (n1 > n2) ? 1 : -1;
else sb = mpn_cmp (b, b + k, n2);
if (sb >= 0)
{
t3[2] = mpn_sub_n (t2 + 2, b, b + k, r2);
if (k > r2) t3[2] = -mpn_sub_1(t2 + 2 + r2, b + r2, k - r2, t3[2]);
} else
{
mpn_sub_n (t2 + 2, b + k, b, r2);
MPN_ZERO(t2 + r2 + 2, k1 - r2);
}
sa *= sb; /* sign of vm1 */
/* compute vm1 := (a0-a1+a2)*(b0-b1) in {t, 2k+1};
since |vm1| < 2*B^(2k), vm1 uses only 2k+1 limbs */
TOOM3_MUL_REC (t, t2 + 2, c, k1, trec);
/* {c,2k} {c+2k,2k+1} {c+4k+1,r+r2-1}
v1
{t, 2k+1} {t+2k+1, 2k + 1}
vm1
*/
c1[0] = saved;
/* {c,k} {c+k,2k+1} {c+3k+1,r+r2-1}
v1
{t, 2k+1} {t+2k+1, 2k + 2}
vm1
*/
/* Compute vm1 <-- (vm1 + v1)/2 (note vm1 + v1 is positive) */
if (sa > 0)
{
#if HAVE_NATIVE_mpn_rsh1add_n
mpn_rsh1add_n(t, t, c1, kk1);
#else
mpn_add_n(t, t, c1, kk1);
mpn_rshift1(t, t, kk1);
#endif
} else
{
#if HAVE_NATIVE_mpn_rsh1sub_n
mpn_rsh1sub_n(t, c1, t, kk1);
#else
mpn_sub_n(t, c1, t, kk1);
mpn_rshift1(t, t, kk1);
#endif
}
/* Compute v1 <-- v1 - vm1 */
mpn_sub_n(c1, c1, t, kk1);
/* Note we could technically overflow
the end of the output if we add
everything in place without subtracting
the right things first. We get around
this by throwing away any high limbs
and carries, which must of necessity
cancel.
First we add vm1 in its place...
*/
n1 = kk1;
MPN_NORMALIZE(t, n1);
if (n1 >= k + rr2) /* if > here, high limb of vm1 and carry may be discarded */
{
cy = mpn_add_n(c2, c2, t, k1);
mpn_add_1(c3 + 1, t + k1, rr2 - 1, cy);
n2 = threek + rr2;
} else
{
c2[k1] = mpn_add_n(c2, c2, t, k1);
if (n1 > k1) c2[n1] = mpn_add_1(c3 + 1, t + k1, n1 - k1, c2[k1]);
n2 = twok + MAX(n1, k1) + 1;
}
/* Compute vinf := a2*b1 in {t, rr2} */
if (r == r2) TOOM3_MUL_REC (t, a + twok, b + k, r, trec);
else if (r > r2) mpn_mul(t, a + twok, r, b + k, r2);
else mpn_mul(t, b + k, r2, a + twok, r);
/* Add vinf into place */
cy = mpn_add_n(c3, c3, t, n2 - threek);
if (rr2 + threek > n2)
mpn_add_1(c + n2, t + n2 - threek, rr2 + threek - n2, cy);
/* v1 <-- v1 - vinf */
cy = mpn_sub_n(c1, c1, t, rr2);
if (cy) mpn_sub_1(c1 + rr2, c1 + rr2, twok, cy);
/* compute v0 := a0*b0 in {t, 2k} */
TOOM3_MUL_REC (t, a, b, k, trec);
/* Add v0 into place */
MPN_COPY(c, t, k);
cy = mpn_add_n(c + k, c + k, t + k, k);
if (cy) mpn_add_1(c + twok, c + twok, k + rr2, cy);
/* vm1 <-- vm1 - v0 */
if (twok >= k + rr2)
mpn_sub_n(c2, c2, t, k + rr2);
else
{
cy = mpn_sub_n(c2, c2, t, twok);
mpn_sub_1(c4, c4, rr2 + k - twok, cy);
}
}