Commit dae30d66 authored by Gerard Ryan's avatar Gerard Ryan
Browse files

be6 passes all unit tests, yay

parent 629eda4f
......@@ -134,46 +134,53 @@ public:
//palisade arithmetic methods
myZZ Plus(const myZZ& b) const {
return *this + b;
return *static_cast<const ZZ*>(this) + static_cast<const ZZ&>(b);
}
const myZZ& PlusEq(const myZZ& b) {
return *this += b;
*static_cast<ZZ*>(this) += static_cast<const ZZ&>(b);
return *this;
}
//note that in Sub we return 0, if a<b
myZZ Minus(const myZZ& b) const {
return (*this<b) ? ZZ(0): (*this-b);
return (*this<b) ? ZZ(0): (*static_cast<const ZZ*>(this) - static_cast<const ZZ&>(b));
}
const myZZ& MinusEq(const myZZ& b) {
if (*this<b)
*this = ZZ(0);
else
*this -= b;
*static_cast<ZZ*>(this) -= static_cast<const ZZ&>(b);
return *this;
}
myZZ Times(const myZZ& b) const { return *this * b; }
const myZZ& TimesEq(const myZZ& b) { return *this *= b; }
myZZ Times(const myZZ& b) const { return *static_cast<const ZZ*>(this) * static_cast<const ZZ&>(b); }
const myZZ& TimesEq(const myZZ& b) {
*static_cast<ZZ*>(this) *= static_cast<const ZZ&>(b);
return *this;
}
myZZ DividedBy(const myZZ& b) const {return *this / b; }
const myZZ& DividedByEq(const myZZ& b) {return *this /= b; }
myZZ DividedBy(const myZZ& b) const {return *static_cast<const ZZ*>(this) / static_cast<const ZZ&>(b); }
const myZZ& DividedByEq(const myZZ& b) {
*static_cast<ZZ*>(this) /= static_cast<const ZZ&>(b);
return *this;
}
myZZ Exp(const usint p) const {return power(*this,p);}
//palisade modular arithmetic methods
myZZ Mod(const myZZ& modulus) const {return *this%modulus;}
myZZ Mod(const myZZ& modulus) const {return *static_cast<const ZZ*>(this) % static_cast<const ZZ&>(modulus);}
const myZZ& ModEq(const myZZ& modulus) {
*this %= modulus;
*static_cast<ZZ*>(this) %= static_cast<const ZZ&>(modulus);
return *this;
}
myZZ ModBarrett(const myZZ& modulus, const myZZ& mu) const {return *this%modulus;}
void ModBarrettInPlace(const myZZ& modulus, const myZZ& mu) { *this%=modulus;}
myZZ ModBarrett(const myZZ& modulus, const myZZ& mu) const {return *static_cast<const ZZ*>(this) % static_cast<const ZZ&>(modulus);}
void ModBarrettInPlace(const myZZ& modulus, const myZZ& mu) { *static_cast<ZZ*>(this) %= static_cast<const ZZ&>(modulus);}
myZZ ModBarrett(const myZZ& modulus, const myZZ mu_arr[BARRETT_LEVELS+1]) const {return *this%modulus;}
myZZ ModBarrett(const myZZ& modulus, const myZZ mu_arr[BARRETT_LEVELS+1]) const {return *static_cast<const ZZ*>(this) % static_cast<const ZZ&>(modulus);}
myZZ ModInverse(const myZZ& modulus) const {
bool dbg_flag = false;
......@@ -197,7 +204,15 @@ public:
return tmp;
}
myZZ ModAdd(const myZZ& b, const myZZ& modulus) const {return myZZ(AddMod(*this%modulus, b%modulus, modulus));}
myZZ ModAdd(const myZZ& b, const myZZ& modulus) const {
return AddMod(this->Mod(modulus), b.Mod(modulus), modulus);
}
const myZZ& ModAddEq(const myZZ& b, const myZZ& modulus) {
AddMod(*this, this->Mod(modulus), b.Mod(modulus), modulus);
return *this;
}
//Fast version does not check for modulus bounds.
myZZ ModAddFast(const myZZ& b, const myZZ& modulus) const {return AddMod(*this, b, modulus);}
......@@ -228,8 +243,29 @@ public:
}
}
const myZZ& ModSubEq(const myZZ& b, const myZZ& modulus)
{
bool dbg_flag = false;
this->ModEq(modulus);
myZZ newb(b%modulus);
if (*this>=newb) {
SubMod(*this, *this, newb, modulus); //normal mod sub
DEBUG("in modsub submod tmp "<< *this);
return *this;
} else {
this->PlusEq(modulus);
this->MinusEq(newb); //signed mod
DEBUG("in modsub alt tmp "<< *this);
return *this;
}
}
//Fast version does not check for modulus bounds.
inline myZZ ModSubFast(const myZZ& b, const myZZ& modulus) const
myZZ ModSubFast(const myZZ& b, const myZZ& modulus) const
{
if (*this>=b) {
return SubMod(*this, b, modulus); //normal mod sub
......@@ -237,14 +273,21 @@ public:
return (*this+modulus -b) ; //signed mod
}
};
}
inline myZZ ModBarrettSub(const myZZ& b, const myZZ& modulus,const myZZ& mu) const {
myZZ ModBarrettSub(const myZZ& b, const myZZ& modulus,const myZZ& mu) const {
return this->ModSub(b, modulus);
};
}
myZZ ModMul(const myZZ& b, const myZZ& modulus) const {
return MulMod(this->Mod(modulus), b.Mod(modulus), modulus);
}
const myZZ& ModMulEq(const myZZ& b, const myZZ& modulus) {
MulMod(*this, this->Mod(modulus), b.Mod(modulus), modulus);
return *this;
}
inline myZZ ModMul(const myZZ& b, const myZZ& modulus) const {return myZZ(MulMod(*this%modulus, b%modulus, modulus));};
//Fast version does not check for modulus bounds.
inline myZZ ModMulFast(const myZZ& b, const myZZ& modulus) const {return MulMod(*this, b, modulus);};
......@@ -278,7 +321,7 @@ public:
* @param shift # of bits
* @return result of the shift operation.
*/
myZZ LShift(usshort shift) const { return (*this) << shift; }
myZZ LShift(usshort shift) const { return *static_cast<const ZZ*>(this) << shift; }
/**
* <<= operation
......@@ -287,7 +330,7 @@ public:
* @return result of the shift operation.
*/
const myZZ& LShiftEq(usshort shift) {
(*this) <<= shift;
*static_cast<ZZ*>(this) <<= shift;
return *this;
}
......@@ -297,7 +340,7 @@ public:
* @param shift # of bits
* @return result of the shift operation.
*/
myZZ RShift(usshort shift) const { return (*this) >> shift; }
myZZ RShift(usshort shift) const { return *static_cast<const ZZ*>(this) >> shift; }
/**
* >>= operation
......@@ -306,12 +349,10 @@ public:
* @return result of the shift operation.
*/
const myZZ& RShiftEq(usshort shift) {
(*this) >>= shift;
*static_cast<ZZ*>(this) >>= shift;
return *this;
}
//big integer stream output
friend std::ostream& operator<<(std::ostream& os, const myZZ&ptr_obj);
......
......@@ -221,12 +221,16 @@ public:
myVecP ModAdd(const myT& b) const {
ModulusCheck("Warning: myVecP::ModAdd");
return (*this) + b % m_modulus;
myVecP ans(*this);
ans.ModAddEq(b);
return ans;
}
const myVecP& ModAddEq(const myT& b) {
ModulusCheck("Warning: myVecP::ModAdd");
(*this) += b % m_modulus;
for(usint i=0;i<this->GetLength();i++){
this->operator[](i).ModAddEq(b, this->m_modulus);
}
return *this;
}
......@@ -234,42 +238,50 @@ public:
myVecP ModAddAtIndex(size_t i, const myT &b) const;
// //vector add
// myVecP Add(const myVecP& b) const {
// ArgCheckVector(b, "myVecP Add()");
// return (*this)+b;
// }
myVecP ModAdd(const myVecP& b) const {
return (*this) + b % m_modulus;
ArgCheckVector(b, "myVecP ModAdd()");
myVecP ans(*this);
ans.ModAddEq(b);
return ans;
}
const myVecP& ModAddEq(const myVecP& b) {
(*this) += b % m_modulus;
ArgCheckVector(b, "myVecP ModAddEq()");
for(usint i=0;i<this->GetLength();i++){
this->operator[](i).ModAddEq(b[i], this->m_modulus);
}
return *this;
}
//scalar
myVecP ModSub(const myT& b) const {
ModulusCheck("Warning: myVecP::ModSub");
return (*this) - b % m_modulus;
myVecP ans(*this);
ans.ModSubEq(b);
return ans;
}
const myVecP& ModSubEq(const myT& b) {
ModulusCheck("Warning: myVecP::ModSub");
(*this) -= b % m_modulus;
ModulusCheck("Warning: myVecP::ModSubEq");
for(usint i=0;i<this->GetLength();i++){
this->operator[](i).ModSubEq(b, this->m_modulus);
}
return (*this);
}
//vector
myVecP ModSub(const myVecP& b) const {
ArgCheckVector(b, "myVecP ModSub()");
return (*this) - b % m_modulus;
myVecP ans(*this);
ans.ModSubEq(b);
return ans;
}
const myVecP& ModSubEq(const myVecP& b) {
ArgCheckVector(b, "myVecP ModSub()");
(*this) -= b % m_modulus;
ArgCheckVector(b, "myVecP ModSubEq()");
for(usint i=0;i<this->GetLength();i++){
this->operator[](i).ModSubEq(b[i], this->m_modulus);
}
return (*this);
}
......@@ -279,24 +291,32 @@ public:
//scalar
myVecP ModMul(const myT& b) const {
ModulusCheck("Warning: myVecP::ModMul");
return (*this) * b % m_modulus;
myVecP ans(*this);
ans.ModMulEq(b);
return ans;
}
const myVecP& ModMulEq(const myT& b) {
ModulusCheck("Warning: myVecP::ModMul");
(*this) *= b % m_modulus;
for(usint i=0;i<this->GetLength();i++){
this->operator[](i).ModMulEq(b, this->m_modulus);
}
return (*this);
}
//vector
myVecP ModMul(const myVecP& b) const {
ArgCheckVector(b, "myVecP Mul()");
return (*this) * b % m_modulus;
myVecP ans(*this);
ans.ModMulEq(b);
return ans;
}
const myVecP& ModMulEq(const myVecP& b) {
ArgCheckVector(b, "myVecP Mul()");
(*this) *= b % m_modulus;
for(usint i=0;i<this->GetLength();i++){
this->operator[](i).ModMulEq(b[i], this->m_modulus);
}
return (*this);
}
......
......@@ -258,7 +258,7 @@ TEST(UTBinVect, CTOR_Test){
// TEST CASE WHEN NUMBERS AFTER ADDITION ARE SMALLER THAN MODULUS
TEST(UTBinVect,ModAddBBITestBigModulus){
TEST(UTBinVect,ModAddBigModulus){
BigInteger q("3435435"); // constructor calling to set mod value
BigVector m(5,q); // calling constructor to create a vector of length 5 and passing value of q
......@@ -284,7 +284,7 @@ TEST(UTBinVect,ModAddBBITestBigModulus){
// TEST CASE WHEN NUMBERS AFTER ADDITION ARE GREATER THAN MODULUS
TEST(UTBinVect,ModAddBBITestSmallerModulus){
TEST(UTBinVect,ModAddSmallerModulus){
bool dbg_flag = false;
BigInteger q("3534"); // constructor calling to set mod value
......@@ -370,11 +370,11 @@ TEST(UTBinVect,modsub_first_number_greater_than_second_number){
/*--------------TESTING METHOD MODUMUL FOR ALL CONDITIONS---------------------------*/
/* The method "Mod Mod" operates on Big Vector m, BigIntegers n,q
/* The method "Mod Mul" operates on Big Vector m, BigIntegers n,q
Returns: (m*n)mod q
and the result is stored in Big Vector calculatedResult.
*/
TEST(UTBinVect,test_modmul_BBI){
TEST(UTBinVect,ModMulTest){
BigInteger q("3534"); // constructor calling to set mod value
BigVector m(5,q); // calling constructor to create a vector of length 5 and passing value of q
......@@ -403,7 +403,7 @@ TEST(UTBinVect,test_modmul_BBI){
Returns: (m^n)mod q
and the result is stored in Big Vector calculatedResult.
*/
TEST(UTBinVect,test_modexp){
TEST(UTBinVect,ModExpTest){
bool dbg_flag = false;
BigInteger q("3534"); // constructor calling to set mod value
BigVector m(5,q); // calling constructor to create a vector of length 5 and passing value of q
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment