Commit bc9e170f authored by Davide Galassi's avatar Davide Galassi

MPI Toom-3 multiplication improvement

Using Bodrato's suggested sequence
parent 1fc7806a
Pipeline #29397792 passed with stages
in 1 minute and 52 seconds
......@@ -3,7 +3,7 @@
#include <stdio.h>
#define KARATSUBA_CUTOFF 64
#define TOOM3_CUTOFF 356
#define TOOM3_CUTOFF 300
int cry_mpi_mul_abs(cry_mpi *r, const cry_mpi *a, const cry_mpi *b)
{
......
......@@ -32,20 +32,6 @@ static int mod_2e(cry_mpi *r, const cry_mpi *a, unsigned int e)
return 0;
}
/*
 * Multiply `a` by 3 in place, computed as a + (a << 1).
 *
 * Returns 0 when every step reported 0; otherwise the first non-zero
 * result of cry_mpi_init, cry_mpi_shl or cry_mpi_add is returned.
 */
static int mul3(cry_mpi *a)
{
    int res;
    cry_mpi twice;

    /* Nothing to release if the temporary could not be initialised. */
    if ((res = cry_mpi_init(&twice)) != 0)
        return res;

    /* twice = a << 1 */
    res = cry_mpi_shl(&twice, a, 1);
    if (res == 0) {
        /* a = a + twice */
        res = cry_mpi_add(a, a, &twice);
    }

    /* The temporary is released on both the success and the failure path. */
    cry_mpi_clear(&twice);
    return res;
}
/*
* Based on Tudor Jebelean "exact division" algorithm.
......@@ -93,20 +79,31 @@ static int div3(cry_mpi *a)
return res;
}
/*
* For node evaluation and interpolation the implementation uses the
* sequence of operations given in Bodrato's paper:
* "Towards Optimal Toom-Cook Multiplication for Univariate and
* Multivariate Polynomials in Characteristic 2 and 0"
*/
int cry_mpi_mul_toom3(cry_mpi *r, const cry_mpi *a, const cry_mpi *b)
{
int res, B;
cry_mpi w0, w1, w2, w3, w4, tmp1, tmp2, a0, a1, a2, b0, b1, b2;
cry_mpi w0, w1, w2, w3, w4, a0, a1, a2, b0, b1, b2, t1, t2;
/* init temps */
if ((res = cry_mpi_init_list(&w0, &w1, &w2, &w3, &w4,
&a0, &a1, &a2, &b0, &b1,
&b2, &tmp1, &tmp2, NULL)) != 0) {
&b2, &t1, &t2, NULL)) != 0) {
return res;
}
/* B */
B = CRY_MIN(a->used, b->used) / 3;
/*
* Splitting
*/
/* a = a2 * B**2 + a1 * B + a0 */
if ((res = mod_2e(&a0, a, CRY_MPI_DIGIT_BITS * B)) != 0)
goto e;
......@@ -131,133 +128,111 @@ int cry_mpi_mul_toom3(cry_mpi *r, const cry_mpi *a, const cry_mpi *b)
goto e;
cry_mpi_shrd(&b2, B*2);
/* w0 = a0*b0 */
if ((res = cry_mpi_mul(&w0, &a0, &b0)) != 0)
goto e;
/* w4 = a2 * b2 */
if ((res = cry_mpi_mul(&w4, &a2, &b2)) != 0)
goto e;
/*
* Evaluation (using Bodrato's steps)
*/
/* w1 = (a2 + 2(a1 + 2a0))(b2 + 2(b1 + 2b0)) */
if ((res = cry_mpi_shl(&tmp1, &a0, 1)) != 0)
goto e;
if ((res = cry_mpi_add(&tmp1, &tmp1, &a1)) != 0)
/* t1 = a0 + a2 */
if ((res = cry_mpi_add(&t1, &a0, &a2)) != 0)
goto e;
if ((res = cry_mpi_shl(&tmp1, &tmp1, 1)) != 0)
/* t2 = b0 + b2 */
if ((res = cry_mpi_add(&t2, &b0, &b2)) != 0)
goto e;
if ((res = cry_mpi_add(&tmp1, &tmp1, &a2)) != 0)
/* w1 = ((a0 + a2) + a1)((b0 + b2) + b1) */
if ((res = cry_mpi_add(&w1, &t1, &a1)) != 0)
goto e;
if ((res = cry_mpi_shl(&tmp2, &b0, 1)) != 0)
if ((res = cry_mpi_add(&w2, &t2, &b1)) != 0)
goto e;
if ((res = cry_mpi_add(&tmp2, &tmp2, &b1)) != 0)
if ((res = cry_mpi_mul(&w1, &w1, &w2)) != 0)
goto e;
if ((res = cry_mpi_shl(&tmp2, &tmp2, 1)) != 0)
/* w2 = ((a0 + a2) - a1)((b0 + b2) - b1) */
if ((res = cry_mpi_sub(&t1, &t1, &a1)) != 0)
goto e;
if ((res = cry_mpi_add(&tmp2, &b2, &tmp2)) != 0)
if ((res = cry_mpi_sub(&t2, &t2, &b1)) != 0)
goto e;
if ((res = cry_mpi_mul(&w1, &tmp1, &tmp2)) != 0)
if ((res = cry_mpi_mul(&w2, &t1, &t2)) != 0)
goto e;
/* w3 = (a0 + 2(a1 + 2a2))(b0 + 2(b1 + 2b2)) */
if ((res = cry_mpi_shl(&tmp1, &a2, 1)) != 0)
goto e;
if ((res = cry_mpi_add(&tmp1, &tmp1, &a1)) != 0)
/* t1 = 2*((a0+a2-a1) + a2) - a0 = a0 - 2a1 + 4a2 */
if ((res = cry_mpi_add(&t1, &t1, &a2)) != 0)
goto e;
if ((res = cry_mpi_shl(&tmp1, &tmp1, 1)) != 0)
if ((res = cry_mpi_shl(&t1, &t1, 1)) != 0)
goto e;
if ((res = cry_mpi_add(&tmp1, &tmp1, &a0)) != 0)
if ((res = cry_mpi_sub(&t1, &t1, &a0)) != 0)
goto e;
if ((res = cry_mpi_shl(&tmp2, &b2, 1)) != 0)
/* t2 = 2*((b0+b2-b1) + b2) - b0 = b0 - 2b1 + 4b2 */
if ((res = cry_mpi_add(&t2, &t2, &b2)) != 0)
goto e;
if ((res = cry_mpi_add(&tmp2, &tmp2, &b1)) != 0)
if ((res = cry_mpi_shl(&t2, &t2, 1)) != 0)
goto e;
if ((res = cry_mpi_shl(&tmp2, &tmp2, 1)) != 0)
if ((res = cry_mpi_sub(&t2, &t2, &b0)) != 0)
goto e;
if ((res = cry_mpi_add(&tmp2, &tmp2, &b0)) != 0)
goto e;
if ((res = cry_mpi_mul(&w3, &tmp1, &tmp2)) != 0)
/* w3 = t1 * t2 */
if ((res = cry_mpi_mul(&w3, &t1, &t2)) != 0)
goto e;
/* w2 = (a2 + a1 + a0)(b2 + b1 + b0) */
if ((res = cry_mpi_add(&tmp1, &a2, &a1)) != 0)
goto e;
if ((res = cry_mpi_add(&tmp1, &tmp1, &a0)) != 0)
goto e;
if ((res = cry_mpi_add(&tmp2, &b2, &b1)) != 0)
goto e;
if ((res = cry_mpi_add(&tmp2, &tmp2, &b0)) != 0)
/* w0 = a0 * b0 */
if ((res = cry_mpi_mul(&w0, &a0, &b0)) != 0)
goto e;
if ((res = cry_mpi_mul(&w2, &tmp1, &tmp2)) != 0)
/* w4 = a2 * b2 */
if ((res = cry_mpi_mul(&w4, &a2, &b2)) != 0)
goto e;
/*
* Solve the matrix
* Interpolation
*
* 0 0 0 0 1
* 1 2 4 8 16
* 1 1 1 1 1
* 16 8 4 2 1
* 1 0 0 0 0
* w0 1 0 0 0 0 r0
* w1 1 1 1 1 1 r1
* w2 = 1 -1 1 -1 1 x r2
* w3 1 -2 4 -8 16 r3
* w4 0 0 0 0 1 r4
*
* Using 12 subtractions, 4 shifts,
* 2 small divisions and 1 small multiplication
*/
/* r1 - r4 */
if ((res = cry_mpi_sub(&w1, &w1, &w4)) != 0)
goto e;
/* r3 - r0 */
if ((res = cry_mpi_sub(&w3, &w3, &w0)) != 0)
goto e;
/* r1/2 */
if ((res = cry_mpi_shr(&w1, &w1, 1)) != 0)
goto e;
/* r3/2 */
if ((res = cry_mpi_shr(&w3, &w3, 1)) != 0)
goto e;
/* r2 - r0 - r4 */
if ((res = cry_mpi_sub(&w2, &w2, &w0)) != 0)
/* w3 = (w3 - w1)/3 */
if ((res = cry_mpi_sub(&w3, &w3, &w1)) != 0)
goto e;
if ((res = cry_mpi_sub(&w2, &w2, &w4)) != 0)
if ((res = div3(&w3)) != 0)
goto e;
/* r1 - r2 */
/* w1 = (w1 - w2)/2 */
if ((res = cry_mpi_sub(&w1, &w1, &w2)) != 0)
goto e;
/* r3 - r2 */
if ((res = cry_mpi_sub(&w3, &w3, &w2)) != 0)
goto e;
/* r1 - 8r0 */
if ((res = cry_mpi_shl(&tmp1, &w0, 3)) != 0)
goto e;
if ((res = cry_mpi_sub(&w1, &w1, &tmp1)) != 0)
goto e;
/* r3 - 8r4 */
if ((res = cry_mpi_shl(&tmp1, &w4, 3)) != 0)
if ((res = cry_mpi_shr(&w1, &w1, 1)) != 0)
goto e;
if ((res = cry_mpi_sub(&w3, &w3, &tmp1)) != 0)
/* w2 = w2 - w0 */
if ((res = cry_mpi_sub(&w2, &w2, &w0)) != 0)
goto e;
/* 3r2 - r1 - r3 */
if ((res = mul3(&w2)) != 0)
/* w3 = (w2 - w3)/2 + 2*w4 */
if ((res = cry_mpi_sub(&w3, &w2, &w3)) != 0)
goto e;
if ((res = cry_mpi_sub(&w2, &w2, &w1)) != 0)
if ((res = cry_mpi_shr(&w3, &w3, 1)) != 0)
goto e;
if ((res = cry_mpi_sub(&w2, &w2, &w3)) != 0)
if ((res = cry_mpi_shl(&t1, &w4, 1)) != 0)
goto e;
/* r1 - r2 */
if ((res = cry_mpi_sub(&w1, &w1, &w2)) != 0)
if ((res = cry_mpi_add(&w3, &w3, &t1)) != 0)
goto e;
/* r3 - r2 */
if ((res = cry_mpi_sub(&w3, &w3, &w2)) != 0)
/* w2 = w2 + w1 - w4 */
if ((res = cry_mpi_add(&w2, &w2, &w1)) != 0)
goto e;
/* r1/3 */
if ((res = div3(&w1)) != 0)
if ((res = cry_mpi_sub(&w2, &w2, &w4)) != 0)
goto e;
/* r3/3 */
if ((res = div3(&w3)) != 0)
/* w1 = w1 - w3 */
if ((res = cry_mpi_sub(&w1, &w1, &w3)) != 0)
goto e;
/* Reconstruct by shifting wn by B*n */
/*
* Reconstruction
*/
/* Shift wn by B*n */
if ((res = cry_mpi_shld(&w1, 1*B)) != 0)
goto e;
if ((res = cry_mpi_shld(&w2, 2*B)) != 0)
......@@ -267,18 +242,18 @@ int cry_mpi_mul_toom3(cry_mpi *r, const cry_mpi *a, const cry_mpi *b)
if ((res = cry_mpi_shld(&w4, 4*B)) != 0)
goto e;
/* Add the parts */
if ((res = cry_mpi_add(r, &w0, &w1)) != 0)
goto e;
if ((res = cry_mpi_add(&tmp1, &w2, &w3)) != 0)
if ((res = cry_mpi_add(&t1, &w2, &w3)) != 0)
goto e;
if ((res = cry_mpi_add(&tmp1, &w4, &tmp1)) != 0)
if ((res = cry_mpi_add(&t1, &w4, &t1)) != 0)
goto e;
if ((res = cry_mpi_add(r, &tmp1, r)) != 0)
if ((res = cry_mpi_add(r, &t1, r)) != 0)
goto e;
e: cry_mpi_clear_list(&w0, &w1, &w2, &w3, &w4,
&a0, &a1, &a2, &b0, &b1,
&b2, &tmp1, &tmp2, NULL);
&b2, &t1, &t2, NULL);
return res;
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment