Commit 5a1c50a4 authored by Gerard Ryan's avatar Gerard Ryan
Browse files

Merge branch 'issue-783' into 'v1.1.1-work-branch'

Issue 783

Fixes #783

See merge request !362
parents e2e94eba 2448e308
......@@ -1217,7 +1217,7 @@ template<typename uint_type,usint BITLENGTH>
BigInteger<uint_type,BITLENGTH> BigInteger<uint_type,BITLENGTH>::Mod(const BigInteger& modulus) const{
//return the same value if value is less than modulus
if(*this<modulus){
return BigInteger(*this);
return *this;
}
//masking operation if modulus is 2
if(modulus.m_MSB==2 && modulus.m_value[m_nSize-1]==2){
......@@ -1442,6 +1442,40 @@ BigInteger<uint_type,BITLENGTH> BigInteger<uint_type,BITLENGTH>::ModBarrett(cons
}
/*
* in place version.
*/
template<typename uint_type,usint BITLENGTH>
void BigInteger<uint_type,BITLENGTH>::ModBarrettInPlace(const BigInteger& modulus, const BigInteger mu_arr[BARRETT_LEVELS+1]) {
if(*this<modulus){
return;
}
BigInteger q(*this);
usint n = modulus.m_MSB;
//level is set to the index between 0 and BARRET_LEVELS - 1
usint level = (this->m_MSB-1-n)*BARRETT_LEVELS/(n+1)+1;
usint gamma = (n*level)/BARRETT_LEVELS;
usint alpha = gamma + 3;
int beta = -2;
const BigInteger& mu = mu_arr[level];
q>>=n + beta;
q=q*mu;
q>>=alpha-beta;
*this -= q*modulus;
if(*this >= modulus)
*this -= modulus;
return;
}
//Extended Euclid algorithm used to find the multiplicative inverse
template<typename uint_type,usint BITLENGTH>
BigInteger<uint_type,BITLENGTH> BigInteger<uint_type,BITLENGTH>::ModInverse(const BigInteger& modulus) const{
......@@ -1538,49 +1572,49 @@ BigInteger<uint_type,BITLENGTH> BigInteger<uint_type,BITLENGTH>::ModBarrettAdd(c
template<typename uint_type,usint BITLENGTH>
BigInteger<uint_type,BITLENGTH> BigInteger<uint_type,BITLENGTH>::ModSub(const BigInteger& b, const BigInteger& modulus) const{
BigInteger* a = const_cast<BigInteger*>(this);
BigInteger* b_op = const_cast<BigInteger*>(&b);
BigInteger a(*this);
BigInteger b_op(b);
//reduce this to a value lower than modulus
if(*this>modulus){
*a = this->Mod(modulus);
if(a > modulus){
a.ModEq(modulus);
}
//reduce b to a value lower than modulus
if(b>modulus){
*b_op = b.Mod(modulus);
if(b > modulus){
b_op.ModEq(modulus);
}
if(*a>=*b_op){
return ((*a-*b_op).Mod(modulus));
if(a >= b_op){
a.MinusEq(b_op);
a.ModEq(modulus);
}
else{
return ((*a + modulus) - *b_op);
a.PlusEq(modulus);
a.MinusEq(b_op);
}
return a;
}
template<typename uint_type,usint BITLENGTH>
const BigInteger<uint_type,BITLENGTH>& BigInteger<uint_type,BITLENGTH>::ModSubEq(const BigInteger& b, const BigInteger& modulus) {
BigInteger* b_op = const_cast<BigInteger*>(&b);
BigInteger b_op(b);
//reduce this to a value lower than modulus
if(*this>modulus){
*this = this->Mod(modulus);
if(*this > modulus){
this->ModEq(modulus);
}
//reduce b to a value lower than modulus
if(b>modulus){
*b_op = b.Mod(modulus);
if(b > modulus){
b_op.ModEq(modulus);
}
if(*this >= *b_op){
// FIXME use modeq
*this = this->Mod(*b_op).Mod(modulus);
if(*this >= b_op){
this->ModEq(b_op);
this->ModEq(modulus);
}
else{
// ugh, the *Eq ops return const so we cannot chain them
this->PlusEq(modulus);
this->MinusEq(*b_op);
this->MinusEq(b_op);
}
return *this;
......@@ -1590,55 +1624,54 @@ const BigInteger<uint_type,BITLENGTH>& BigInteger<uint_type,BITLENGTH>::ModSubEq
template<typename uint_type,usint BITLENGTH>
BigInteger<uint_type,BITLENGTH> BigInteger<uint_type,BITLENGTH>::ModBarrettSub(const BigInteger& b, const BigInteger& modulus,const BigInteger& mu) const{
BigInteger* a = const_cast<BigInteger*>(this);
BigInteger* b_op = const_cast<BigInteger*>(&b);
BigInteger a(*this);
BigInteger b_op(b);
if(*this>modulus){
*a = this->ModBarrett(modulus,mu);
if(*this > modulus){
a.ModBarrettInPlace(modulus,mu);
}
if(b>modulus){
*b_op = b.ModBarrett(modulus,mu);
b_op.ModBarrettInPlace(modulus,mu);
}
if(*a >= *b_op){
return ((*a-*b_op).ModBarrett(modulus,mu));
if(a >= b_op){
a.MinusEq(b_op);
a.ModBarrettInPlace(modulus,mu);
}
else{
return ((*a + modulus) - *b_op);
a.PlusEq(modulus);
a.MinusEq(b_op);
}
return a;
}
template<typename uint_type,usint BITLENGTH>
BigInteger<uint_type,BITLENGTH> BigInteger<uint_type,BITLENGTH>::ModBarrettSub(const BigInteger& b, const BigInteger& modulus,const BigInteger mu_arr[BARRETT_LEVELS]) const{
BigInteger* a = NULL;
BigInteger* b_op = NULL;
BigInteger a(*this);
BigInteger b_op(b);
if(*this>modulus){
*a = this->ModBarrett(modulus,mu_arr);
}
else{
a = const_cast<BigInteger*>(this);
if(*this > modulus){
a.ModBarrettInPlace(modulus,mu_arr);
}
if(b>modulus){
*b_op = b.ModBarrett(modulus,mu_arr);
}
else{
b_op = const_cast<BigInteger*>(&b);
b_op.ModBarrettInPlace(modulus,mu_arr);
}
if(!(*a<*b_op)){
return ((*a-*b_op).ModBarrett(modulus,mu_arr));
if(a >= b_op){
a.MinusEq(b_op);
a.ModBarrettInPlace(modulus,mu_arr);
}
else{
return ((*a + modulus) - *b_op);
a.PlusEq(modulus);
a.MinusEq(b_op);
}
return a;
}
template<typename uint_type,usint BITLENGTH>
......@@ -1656,9 +1689,8 @@ BigInteger<uint_type,BITLENGTH> BigInteger<uint_type,BITLENGTH>::ModMul(const Bi
bb = bb.Mod(modulus);
}
//return a*b%q
return (a*bb).Mod(modulus);
a.TimesEq(bb);
return a.ModEq(modulus);
}
template<typename uint_type,usint BITLENGTH>
......@@ -1675,8 +1707,8 @@ const BigInteger<uint_type,BITLENGTH>& BigInteger<uint_type,BITLENGTH>::ModMulEq
bb.ModEq(modulus);
}
*this *= bb;
*this %= modulus;
this->TimesEq(bb);
this->ModEq(modulus);
return *this;
}
......@@ -1707,19 +1739,19 @@ This algorithm would most like give the biggest improvement but it sets constrai
template<typename uint_type,usint BITLENGTH>
BigInteger<uint_type,BITLENGTH> BigInteger<uint_type,BITLENGTH>::ModBarrettMul(const BigInteger& b, const BigInteger& modulus,const BigInteger& mu) const{
BigInteger* a = const_cast<BigInteger*>(this);
BigInteger* bb = const_cast<BigInteger*>(&b);
BigInteger a(*this);
BigInteger bb(b);
//if a is greater than q reduce a to its mod value
if(*this>modulus)
*a = this->ModBarrett(modulus,mu);
a.ModBarrettInPlace(modulus,mu);
//if b is greater than q reduce b to its mod value
if(b>modulus)
*bb = b.ModBarrett(modulus,mu);
bb.ModBarrettInPlace(modulus,mu);
return (*a**bb).ModBarrett(modulus,mu);
a.TimesEq(bb);
return a.ModBarrett(modulus,mu);
}
......@@ -1751,8 +1783,7 @@ This algorithm would most like give the biggest improvement but it sets constrai
template<typename uint_type, usint BITLENGTH>
void BigInteger<uint_type, BITLENGTH>::ModBarrettMulInPlace(const BigInteger& b, const BigInteger& modulus, const BigInteger& mu) {
//BigInteger* a = const_cast<BigInteger*>(this);
BigInteger* bb = const_cast<BigInteger*>(&b);
BigInteger bb(b);
//if a is greater than q reduce a to its mod value
if (*this>modulus)
......@@ -1761,10 +1792,9 @@ void BigInteger<uint_type, BITLENGTH>::ModBarrettMulInPlace(const BigInteger& b,
//if b is greater than q reduce b to its mod value
if (b>modulus)
*bb = b.ModBarrett(modulus, mu);
*this = *this**bb;
bb.ModBarrettInPlace(modulus, mu);
this->TimesEq(bb);
this->ModBarrettInPlace(modulus, mu);
return;
......@@ -1774,24 +1804,19 @@ void BigInteger<uint_type, BITLENGTH>::ModBarrettMulInPlace(const BigInteger& b,
template<typename uint_type,usint BITLENGTH>
BigInteger<uint_type,BITLENGTH> BigInteger<uint_type,BITLENGTH>::ModBarrettMul(const BigInteger& b, const BigInteger& modulus,const BigInteger mu_arr[BARRETT_LEVELS]) const{
BigInteger* a = NULL;
BigInteger* bb = NULL;
BigInteger a(*this);
BigInteger bb(b);
//if a is greater than q reduce a to its mod value
if(*this>modulus)
*a = this->ModBarrett(modulus,mu_arr);
else
a = const_cast<BigInteger*>(this);
a.ModBarrettInPlace(modulus,mu_arr);
//if b is greater than q reduce b to its mod value
if(b>modulus)
*bb = b.ModBarrett(modulus,mu_arr);
else
bb = const_cast<BigInteger*>(&b);
//return a*b%q
bb.ModBarrettInPlace(modulus,mu_arr);
return (*a**bb).ModBarrett(modulus,mu_arr);
a.TimesEq(bb);
return a.ModBarrett(modulus,mu_arr);
}
//Modular Multiplication using Square and Multiply Algorithm
......
......@@ -456,6 +456,17 @@ namespace cpu_int{
*/
BigInteger ModBarrett(const BigInteger& modulus, const BigInteger mu_arr[BARRETT_LEVELS+1]) const;
/**
* returns the modulus with respect to the input value - In place version.
* Implements generalized Barrett modular reduction algorithm. Uses an array of precomputed values \mu.
* See the cpp file for details of the implementation.
*
* @param modulus is the modulus to perform operations with.
* @param mu_arr is an array of the Barrett values of length BARRETT_LEVELS.
* @return result of the modulus operation.
*/
void ModBarrettInPlace(const BigInteger& modulus, const BigInteger mu_arr[BARRETT_LEVELS+1]);
/**
* returns the modulus inverse with respect to the input value.
*
......
......@@ -2060,7 +2060,7 @@ return result;
ubint<limb_t> ubint<limb_t>::ModBarrett(const ubint& modulus, const ubint& mu) const{
#ifdef NO_BARRETT
ubint ans(*this);
ans%=modulus;
ans.ModEq(modulus);
return(ans);
#else
if(*this<modulus){
......@@ -2212,7 +2212,6 @@ return result;
}
//Need to mimic signed modulus return of BE 2
template<typename limb_t>
ubint<limb_t> ubint<limb_t>::ModSub(const ubint& b, const ubint& modulus) const{
ubint a(*this);
......@@ -2228,11 +2227,15 @@ return result;
}
if(a>=b_op){
return ((a-b_op).Mod(modulus));
a.MinusEq(b_op);
a.ModEq(modulus);
}
else{
return ((a + modulus) - b_op);
a.PlusEq(modulus);
a.MinusEq(b_op);
}
return a;
}
template<typename limb_t>
......@@ -2260,7 +2263,6 @@ return result;
return *this;
}
template<typename limb_t>
ubint<limb_t> ubint<limb_t>::ModMul(const ubint& b, const ubint& modulus) const{
......@@ -2358,18 +2360,20 @@ return result;
return this->ModMul(b, modulus);
#else
ubint* a = const_cast<ubint*>(this);
ubint* bb = const_cast<ubint*>(&b);
ubint a(*this);
ubint bb(b);
//if a is greater than q reduce a to its mod value
if(*this>modulus)
*a = std::move(this->ModBarrett(modulus,mu));
a.ModBarrettInPlace(modulus,mu);
//if b is greater than q reduce b to its mod value
if(b>modulus)
*bb = std::move(b.ModBarrett(modulus,mu));
bb.ModBarrettInPlace(modulus,mu);
return (*a**bb).ModBarrett(modulus,mu);
a.TimesEq(bb);
a.ModBarrettInPlace(modulus,mu);
return a;
#endif
}
......@@ -2378,12 +2382,11 @@ return result;
template<typename limb_t>
void ubint<limb_t>::ModBarrettMulInPlace(const ubint& b, const ubint& modulus,const ubint& mu) {
#ifdef NO_BARRETT
*this = this->ModMul(b, modulus);
this->ModMulEq(b, modulus);
return ;
#else
ubint* bb = const_cast<ubint*>(&b);
ubint bb(b);
//if this is greater than q reduce a to its mod value
if(*this>modulus)
......@@ -2391,9 +2394,9 @@ return result;
//if b is greater than q reduce b to its mod value
if(b>modulus)
*bb = b.ModBarrett(modulus,mu);
*this = *this**bb;
bb.ModBarrettInPlace(modulus,mu);
this->TimesEq(bb);
this->ModBarrettInPlace(modulus, mu);
return;
......@@ -2409,23 +2412,20 @@ return result;
ubint ans(*this);
return ans.ModMul(b, modulus);
#else
ubint* a = NULL;
ubint* bb = NULL;
ubint a(*this);
ubint bb(b);
//if a is greater than q reduce a to its mod value
if(*this>modulus)
*a = std::move(this->ModBarrett(modulus,mu_arr));
else
a = const_cast<ubint*>(this);
a.ModBarrettInPlace(modulus,mu_arr);
//if b is greater than q reduce b to its mod value
if(b>modulus)
*bb = std::move(b.ModBarrett(modulus,mu_arr));
else
bb = const_cast<ubint*>(&b);
bb.ModBarrettInPlace(modulus,mu_arr);
//return a*b%q
return (*a**bb).ModBarrett(modulus,mu_arr);
a.TimesEq(bb);
a.ModBarrettInPlace(modulus,mu_arr);
return a;
#endif
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment