Commit 629eda4f authored by Gerard Ryan's avatar Gerard Ryan
Browse files

be6 now uses interface.h

parent e02b723d
......@@ -36,7 +36,7 @@ Field2n::Field2n(const Poly & element)
} else {
// the value of element.at(i) is usually small - so a 64-bit integer is more than enough
// this approach is much faster than BigInteger::ConvertToDouble
BigInteger negativeThreshold(element.GetModulus()/ 2);
BigInteger negativeThreshold(element.GetModulus() / Poly::Integer(2));
for (size_t i = 0; i < element.GetLength(); i++) {
if (element.at(i) > negativeThreshold)
this->push_back((double)(int64_t)(-1 * (element.GetModulus() - element.at(i)).ConvertToInt()));
......
......@@ -237,9 +237,9 @@ const PolyImpl<ModType,IntType,VecType,ParmType>& PolyImpl<ModType,IntType,VecTy
for (usint j = 0; j < vectorLength; ++j) { // loops within a tower
if (j < len) {
this->at(j)= IntType(*(rhs.begin() + j));
this->operator[](j)= IntType(*(rhs.begin() + j));
} else {
this->at(j)= ZERO;
this->operator[](j)= ZERO;
}
}
......@@ -275,10 +275,10 @@ const PolyImpl<ModType, IntType, VecType, ParmType>& PolyImpl<ModType, IntType,
tempInteger = *(rhs.begin() + j);
tempBI = IntType(tempInteger);
}
at(j)= tempBI;
operator[](j)= tempBI;
}
else {
at(j)= ZERO;
operator[](j)= ZERO;
}
}
......@@ -302,10 +302,10 @@ const PolyImpl<ModType, IntType, VecType, ParmType>& PolyImpl<ModType, IntType,
tempInteger = *(rhs.begin() + j);
tempBI = IntType(tempInteger);
}
temp.at(j)= tempBI;
temp.operator[](j)= tempBI;
}
else {
temp.at(j)= ZERO;
temp.operator[](j)= ZERO;
}
}
this->SetValues(std::move(temp), m_format);
......@@ -336,10 +336,10 @@ const PolyImpl<ModType, IntType, VecType, ParmType>& PolyImpl<ModType, IntType,
tempInteger = *(rhs.begin() + j);
tempBI = IntType(tempInteger);
}
at(j)= tempBI;
operator[](j)= tempBI;
}
else {
at(j)= ZERO;
operator[](j)= ZERO;
}
}
......@@ -363,10 +363,10 @@ const PolyImpl<ModType, IntType, VecType, ParmType>& PolyImpl<ModType, IntType,
tempInteger = *(rhs.begin() + j);
tempBI = IntType(tempInteger);
}
temp.at(j)= tempBI;
temp.operator[](j)= tempBI;
}
else {
temp.at(j)= ZERO;
temp.operator[](j)= ZERO;
}
}
this->SetValues(std::move(temp), m_format);
......@@ -385,9 +385,9 @@ const PolyImpl<ModType,IntType,VecType,ParmType>& PolyImpl<ModType,IntType,VecTy
for (usint j = 0; j < vectorLength; ++j) { // loops within a tower
if (j < len) {
m_values->at(j)= *(rhs.begin() + j);
m_values->operator[](j)= *(rhs.begin() + j);
} else {
m_values->at(j)= ZERO;
m_values->operator[](j)= ZERO;
}
}
......@@ -438,7 +438,7 @@ const PolyImpl<ModType,IntType,VecType,ParmType>& PolyImpl<ModType,IntType,VecTy
m_values = make_unique<VecType>(m_params->GetRingDimension(), m_params->GetModulus());
}
for (size_t i = 0; i < m_values->GetLength(); ++i) {
this->at(i)= IntType(val);
this->operator[](i)= IntType(val);
}
return *this;
}
......@@ -499,7 +499,7 @@ usint PolyImpl<ModType,IntType,VecType,ParmType>::GetLength() const
template<typename ModType, typename IntType, typename VecType, typename ParmType>
void PolyImpl<ModType,IntType,VecType,ParmType>::SetValues(const VecType& values, Format format)
{
if (m_params->GetRootOfUnity() == 0){
if (m_params->GetRootOfUnity() == IntType(0)){
PALISADE_THROW(type_error, "Polynomial has a 0 root of unity");
}
if (m_params->GetRingDimension() != values.GetLength() || m_params->GetModulus() != values.GetModulus()) {
......@@ -518,11 +518,11 @@ void PolyImpl<ModType,IntType,VecType,ParmType>::SetValuesToZero()
template<typename ModType, typename IntType, typename VecType, typename ParmType>
void PolyImpl<ModType,IntType,VecType,ParmType>::SetValuesToMax()
{
IntType max = m_params->GetModulus() - 1;
IntType max = m_params->GetModulus() - IntType(1);
usint size = m_params->GetRingDimension();
m_values = make_unique<VecType>(m_params->GetRingDimension(), m_params->GetModulus());
for (usint i = 0; i < size; i++) {
m_values->at(i)= IntType(max);
m_values->operator[](i)= IntType(max);
}
}
......@@ -574,7 +574,7 @@ PolyImpl<ModType,IntType,VecType,ParmType> PolyImpl<ModType,IntType,VecType,Parm
// throw std::logic_error("Negate for PolyImpl is supported only in EVALUATION format.\n");
PolyImpl<ModType,IntType,VecType,ParmType> tmp( *this );
*tmp.m_values = m_values->ModMul(this->m_params->GetModulus() - 1);
*tmp.m_values = m_values->ModMul(this->m_params->GetModulus() - IntType(1));
return std::move( tmp );
}
......@@ -667,9 +667,9 @@ void PolyImpl<ModType,IntType,VecType,ParmType>::AddILElementOne()
{
IntType tempValue;
for (usint i = 0; i < m_params->GetRingDimension(); i++) {
tempValue = GetValues().at(i) + 1;
tempValue = GetValues().operator[](i) + IntType(1);
tempValue = tempValue.Mod(m_params->GetModulus());
m_values->at(i)= tempValue;
m_values->operator[](i)= tempValue;
}
}
......@@ -697,15 +697,15 @@ PolyImpl<ModType,IntType,VecType,ParmType> PolyImpl<ModType,IntType,VecType,Parm
// based on the totient index (between 0 and m - 1)
VecType expanded(m, modulus);
for (usint i = 0; i < n; i++) {
expanded.at(totientList.at(i))= m_values->at(i);
expanded.operator[](totientList.operator[](i))= m_values->operator[](i);
}
for (usint i = 0; i < n; i++) {
//determines which power of primitive root unity we should switch to
usint idx = totientList.at(i)*k % m;
usint idx = totientList.operator[](i)*k % m;
result.m_values->at(i)= expanded.at(idx);
result.m_values->operator[](i)= expanded.operator[](idx);
}
} else {
......@@ -716,7 +716,7 @@ PolyImpl<ModType,IntType,VecType,ParmType> PolyImpl<ModType,IntType,VecType,Parm
//determines which power of primitive root unity we should switch to
usint idx = (j*k) % m;
result.m_values->at((j + 1) / 2 - 1)= GetValues().at((idx + 1) / 2 - 1);
result.m_values->operator[]((j + 1) / 2 - 1)= GetValues().operator[]((idx + 1) / 2 - 1);
}
......@@ -863,7 +863,7 @@ void PolyImpl<ModType,IntType,VecType,ParmType>::MakeSparse(const uint32_t &wFac
if (m_values != 0) {
for (usint i = 0; i < m_params->GetRingDimension();i++) {
if (i%wFactor != 0) {
m_values->at(i)= IntType(0);
m_values->operator[](i)= IntType(0);
}
}
}
......@@ -893,7 +893,7 @@ void PolyImpl<ModType,IntType,VecType,ParmType>::Decompose()
//Interleaving operation.
VecType decomposeValues(GetLength() / 2, GetModulus());
for (usint i = 0; i < GetLength(); i = i + 2) {
decomposeValues.at(i / 2)= GetValues().at(i);
decomposeValues.operator[](i / 2)= GetValues().operator[](i);
}
SetValues(decomposeValues, m_format);
......@@ -912,7 +912,7 @@ template<typename ModType, typename IntType, typename VecType, typename ParmType
bool PolyImpl<ModType,IntType,VecType,ParmType>::InverseExists() const
{
for (usint i = 0; i < GetValues().GetLength(); i++) {
if (m_values->at(i) == 0)
if (m_values->operator[](i) == IntType(0))
return false;
}
return true;
......@@ -927,10 +927,10 @@ double PolyImpl<ModType,IntType,VecType,ParmType>::Norm() const
const IntType &half = m_params->GetModulus() >> 1;
for (usint i = 0; i < GetValues().GetLength(); i++) {
if (m_values->at(i) > half)
if (m_values->operator[](i) > half)
locVal = q - (*m_values)[i];
else
locVal = m_values->at(i);
locVal = m_values->operator[](i);
if (locVal > retVal)
retVal = locVal;
......
......@@ -186,7 +186,7 @@ public:
* Sets/gets a value at an index.
*
* @param index is the index to set a value at.
*/
*/
IntegerType& at(size_t i) {
if(!this->IndexCheck(i)) {
......
......@@ -67,7 +67,7 @@ IntType DiscreteUniformGeneratorImpl<IntType,VecType>::GenerateInteger () const
//stores current random number generated by built-in C++ 11 uniform generator (used for 32-bit unsigned integers)
uint32_t value;
if( m_modulus == 0 ) {
if( m_modulus == IntType(0) ) {
throw std::logic_error("0 modulus?");
}
......
......@@ -133,20 +133,42 @@ public:
void SetIdentity(){*this=1;}
//palisade arithmetic methods
myZZ Plus(const myZZ& b) const {return *this+b;};
myZZ Plus(const myZZ& b) const {
return *this + b;
}
const myZZ& PlusEq(const myZZ& b) {
return *this += b;
}
//note that in Sub we return 0, if a<b
myZZ Minus(const myZZ& b) const {return((*this<b)? ZZ(0):( *this-b));}
myZZ Minus(const myZZ& b) const {
return (*this<b) ? ZZ(0): (*this-b);
}
const myZZ& MinusEq(const myZZ& b) {
if (*this<b)
*this = ZZ(0);
else
*this -= b;
return *this;
}
myZZ Times(const myZZ& b) const { return *this * b; }
const myZZ& TimesEq(const myZZ& b) { return *this *= b; }
myZZ DividedBy(const myZZ& b) const {return *this/b;}
myZZ DividedBy(const myZZ& b) const {return *this / b; }
const myZZ& DividedByEq(const myZZ& b) {return *this /= b; }
myZZ Exp(const usint p) const {return power(*this,p);}
//palisade modular arithmetic methods
myZZ Mod(const myZZ& modulus) const {return *this%modulus;}
const myZZ& ModEq(const myZZ& modulus) {
*this %= modulus;
return *this;
}
myZZ ModBarrett(const myZZ& modulus, const myZZ& mu) const {return *this%modulus;}
void ModBarrettInPlace(const myZZ& modulus, const myZZ& mu) { *this%=modulus;}
......@@ -250,6 +272,45 @@ public:
myZZ MultiplyAndRound(const myZZ &p, const myZZ &q) const;
myZZ DivideAndRound(const myZZ &q) const;
/**
* << operation
*
* @param shift # of bits
* @return result of the shift operation.
*/
myZZ LShift(usshort shift) const { return (*this) << shift; }
/**
* <<= operation
*
* @param shift # of bits
* @return result of the shift operation.
*/
const myZZ& LShiftEq(usshort shift) {
(*this) <<= shift;
return *this;
}
/**
* >> operation
*
* @param shift # of bits
* @return result of the shift operation.
*/
myZZ RShift(usshort shift) const { return (*this) >> shift; }
/**
* >>= operation
*
* @param shift # of bits
* @return result of the shift operation.
*/
const myZZ& RShiftEq(usshort shift) {
(*this) >>= shift;
return *this;
}
//big integer stream output
friend std::ostream& operator<<(std::ostream& os, const myZZ&ptr_obj);
......
......@@ -376,14 +376,6 @@ const myVecP<myT>& myVecP<myT>::operator=( myVecP<myT> &&rhs)
return *this;
}
//desctructor
template<class myT>
myVecP<myT>::~myVecP()
{
}
template<class myT>
void myVecP<myT>::clear(myVecP<myT>& x)
{
......
......@@ -110,7 +110,24 @@ public:
myVecP(const myVecP<myT> &a, const uint64_t q);
//destructor
~myVecP();
~myVecP() {}
/**
* ostream operator to output vector values to console
*
* @param os is the std ostream object.
* @param &ptr_obj is the BigVectorImpl object to be printed.
* @return std ostream object which captures the vector values.
*/
friend std::ostream& operator<<(std::ostream& os, const myVecP<myT> &ptr_obj) {
auto len = ptr_obj.GetLength();
os<<"[";
for(size_t i=0; i < len; i++) {
os<< ptr_obj.at(i);
os << ((i == (len-1))?"]":" ");
}
return os;
}
//adapters
myVecP(std::vector<std::string>& s); //without modulus
......@@ -202,72 +219,88 @@ public:
void SwitchModulus(const myT& newModulus);
inline myVecP Add(const myT& b) const {ModulusCheck("Warning: myVecP::Add"); return (*this)+b%m_modulus; };
inline myVecP ModAdd(const myT& b) const {ModulusCheck("Warning: myVecP::ModAdd"); return this->Add(b); };
myVecP ModAdd(const myT& b) const {
ModulusCheck("Warning: myVecP::ModAdd");
return (*this) + b % m_modulus;
}
const myVecP& ModAddEq(const myT& b) {
ModulusCheck("Warning: myVecP::ModAdd");
(*this) += b % m_modulus;
return *this;
}
void modadd_p(myVecP& x, const myVecP& a, const myVecP& b) const; //define procedural version
myVecP ModAddAtIndex(size_t i, const myT &b) const;
//vector add
myVecP Add(const myVecP& b) const {
ArgCheckVector(b, "myVecP Add()");
return (*this)+b;
};
myVecP ModAdd(const myVecP& b) const {
return (this->Add(b));
};
// //vector add
// myVecP Add(const myVecP& b) const {
// ArgCheckVector(b, "myVecP Add()");
// return (*this)+b;
// }
//Subtraction
//vector subtraction assignment note uses DIFFERNT modsub than standard math
//this is a SIGNED mod sub
inline myVecP& operator-=(const myVecP& a) {
ArgCheckVector(a, "myVecP -=");
modsub_p(*this, *this, a);
return *this;
};
myVecP ModAdd(const myVecP& b) const {
return (*this) + b % m_modulus;
}
//scalar subtraction assignment
inline myVecP& operator-=(const myT& a)
{
ModulusCheck("Warning: myVecP::op-=");
*this = *this-a;
const myVecP& ModAddEq(const myVecP& b) {
(*this) += b % m_modulus;
return *this;
};
}
//scalar
myVecP Sub(const myT& b) const {ModulusCheck("Warning: myVecP::Sub"); return (*this)-b%m_modulus;};
myVecP ModSub(const myT& b) const {ModulusCheck("Warning: myVecP::ModSub"); return (*this)-b%m_modulus;};
myVecP ModSub(const myT& b) const {
ModulusCheck("Warning: myVecP::ModSub");
return (*this) - b % m_modulus;
}
const myVecP& ModSubEq(const myT& b) {
ModulusCheck("Warning: myVecP::ModSub");
(*this) -= b % m_modulus;
return (*this);
}
//vector
myVecP Sub(const myVecP& b) const {
bool dbg_flag = false;
DEBUG("in myVecP::Sub");
DEBUG(*this);
DEBUG(this->GetModulus());
DEBUG(b);
DEBUG(b.GetModulus());
ArgCheckVector(b, "myVecP Sub()");
return (*this)-b;
};
myVecP ModSub(const myVecP& b) const {ArgCheckVector(b, "myVecP ModSub()"); return (this->Sub(b));};
myVecP ModSub(const myVecP& b) const {
ArgCheckVector(b, "myVecP ModSub()");
return (*this) - b % m_modulus;
}
//deprecated vector
inline myVecP Minus(const myVecP& b) const {ArgCheckVector(b, "myVecP Minus()"); return (this->Sub(b));};
const myVecP& ModSubEq(const myVecP& b) {
ArgCheckVector(b, "myVecP ModSub()");
(*this) -= b % m_modulus;
return (*this);
}
//procecural
void modsub_p(myVecP& x, const myVecP& a, const myVecP& b) const; //define procedural
//scalar
inline myVecP Mul(const myT& b) const {ModulusCheck("Warning: myVecP::Mul"); return (*this)*b%m_modulus;};
inline myVecP ModMul(const myT& b) const {ModulusCheck("Warning: myVecP::ModMul"); return (*this)*b%m_modulus;};
myVecP ModMul(const myT& b) const {
ModulusCheck("Warning: myVecP::ModMul");
return (*this) * b % m_modulus;
}
const myVecP& ModMulEq(const myT& b) {
ModulusCheck("Warning: myVecP::ModMul");
(*this) *= b % m_modulus;
return (*this);
}
//vector
inline myVecP Mul(const myVecP& b) const {ArgCheckVector(b, "myVecP Mul()"); return (*this)*b;};
inline myVecP ModMul(const myVecP& b) const {ArgCheckVector(b, "myVecP Mul()");return (this->Mul(b));};
myVecP ModMul(const myVecP& b) const {
ArgCheckVector(b, "myVecP Mul()");
return (*this) * b % m_modulus;
}
void modmul_p(myVecP& x, const myVecP& a, const myVecP& b) const; //define procedural
const myVecP& ModMulEq(const myVecP& b) {
ArgCheckVector(b, "myVecP Mul()");
(*this) *= b % m_modulus;
return (*this);
}
void modmul_p(myVecP& x, const myVecP& a, const myVecP& b) const; //define procedural
/**
* Scalar exponentiation.
......
......@@ -518,7 +518,7 @@ MatrixStrassen<double> Cholesky(const MatrixStrassen<int32_t> &input) {
MatrixStrassen<int32_t> ConvertToInt32(const MatrixStrassen<BigInteger> &input, const BigInteger& modulus) {
size_t rows = input.GetRows();
size_t cols = input.GetCols();
BigInteger negativeThreshold(modulus / 2);
BigInteger negativeThreshold(modulus / BigInteger(2));
MatrixStrassen<int32_t> result([](){ return make_unique<int32_t>(); }, rows, cols);
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
......@@ -535,7 +535,7 @@ MatrixStrassen<int32_t> ConvertToInt32(const MatrixStrassen<BigInteger> &input,
MatrixStrassen<int32_t> ConvertToInt32(const MatrixStrassen<BigVector> &input, const BigInteger& modulus) {
size_t rows = input.GetRows();
size_t cols = input.GetCols();
BigInteger negativeThreshold(modulus / 2);
BigInteger negativeThreshold(modulus / BigInteger(2));
MatrixStrassen<int32_t> result([](){ return make_unique<int32_t>(); }, rows, cols);
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
......
......@@ -186,14 +186,14 @@ namespace lbcrypto {
bool prevMod = false;
for (usint i = 1; i < s + 1; i++) {
DEBUG("wf " << i);
if (mod != 1 && mod != p - 1)
if (mod != IntType(1) && mod != p - IntType(1))
prevMod = true;
else
prevMod = false;
mod = mod.ModMul(mod, p);
if (mod == 1 && prevMod) return true;
if (mod == IntType(1) && prevMod) return true;
}
return (mod != 1);
return (mod != IntType(1));
}
/*
......@@ -208,8 +208,8 @@ namespace lbcrypto {
std::set<IntType> primeFactors;
DEBUG("FindGenerator(" << q << "),calling PrimeFactorize");
IntType qm1 = q - 1;
IntType qm2 = q - 2;
IntType qm1 = q - IntType(1);
IntType qm2 = q - IntType(2);
PrimeFactorize<IntType>(qm1, primeFactors);
DEBUG("prime factors of " << qm1);
for( auto& v : primeFactors ) DEBUG(v << " ");
......@@ -220,14 +220,14 @@ namespace lbcrypto {
usint count = 0;
//gen = RNG(qm2).ModAdd(IntType::ONE, q); //modadd note needed
gen = RNG(qm2) + 1;
gen = RNG(qm2) + IntType(1);
DEBUG("generator " << gen);
DEBUG("cycling thru prime factors");
for (auto it = primeFactors.begin(); it != primeFactors.end(); ++it) {
DEBUG(qm1 << " / " << *it << " " << gen.ModExp(qm1 / (*it), q));
if (gen.ModExp(qm1 / (*it), q) == 1) break;
if (gen.ModExp(qm1 / (*it), q) == IntType(1)) break;
else count++;
}
if (count == primeFactors.size()) generatorFound = true;
......@@ -258,8 +258,8 @@ namespace lbcrypto {
usint count = 0;
DEBUG("count " << count);
gen = RNG(phi_q_m1) + 1; // gen is random in [1, phi(q)]
if (GreatestCommonDivisor<IntType>(gen, q) != 1) {
gen = RNG(phi_q_m1) + IntType(1); // gen is random in [1, phi(q)]
if (GreatestCommonDivisor<IntType>(gen, q) != IntType(1)) {
// Generator must lie in the group!
continue;
}
......@@ -269,7 +269,7 @@ namespace lbcrypto {
DEBUG("in set");
DEBUG("divide " << phi_q << " by " << *it);
if (gen.ModExp(phi_q / (*it), q) == 1) break;
if (gen.ModExp(phi_q / (*it), q) == IntType(1)) break;
else count++;
}
......@@ -301,7 +301,7 @@ namespace lbcrypto {
DEBUG("in set");
DEBUG("divide " << qm1 << " by " << *it);
if (g.ModExp(qm1 / (*it), q) == 1) break;
if (g.ModExp(qm1 / (*it), q) == IntType(1)) break;
else count++;
}