...
 
Commits (2)
/**
* Author: Simon Lindholm
* Date: 2019-04-19
* Date: 2020-05-30
* License: CC0
* Source: https://en.wikipedia.org/wiki/Barrett_reduction
* Description: Compute $a \% b$ about 4 times faster than usual, where $a < b^2$ and $b$ is constant but not known at compile time.
* Fails for $b = 1$.
* \vspace{0.5mm}
* Description: Compute $a \% b$ about 5 times faster than usual, where $b$ is constant but not known at compile time.
* Returns a value congruent to $a \pmod b$ in the range $[0, 2b)$.
* Status: proven correct, stress-tested
* Measured as having 3 times lower latency, and 8 times higher throughput.
* Measured as having 4 times lower latency, and 8 times higher throughput, see stress-test.
* Details:
* More precisely, it can be proven that the result equals 0 only if $a = 0$,
* and otherwise lies in $[1, (1 + a/2^64) * b)$.
*/
#pragma once
typedef unsigned long long ull;
typedef __uint128_t L;
struct FastMod {
ull b, m;
FastMod(ull b) : b(b), m(ull((L(1) << 64) / b)) {}
ull reduce(ull a) {
ull q = (ull)((L(m) * a) >> 64), r = a - q * b;
return r - (r >= b) * b;
FastMod(ull b) : b(b), m(-1ULL / b) {}
ull reduce(ull a) { // a % b + (0 or b)
return a - (ull)((__uint128_t(m) * a) >> 64) * b;
}
};
......@@ -2,6 +2,18 @@
#include "../../content/various/FastMod.h"
typedef unsigned long long ull;
struct OldBarrett {
ull b, m;
OldBarrett(ull b) : b(b), m(-1ULL / b) {}
ull reduce(ull a) {
ull q = (ull)((__uint128_t(m) * a) >> 64), r = a - q * b;
return r >= b ? r - b : r;
}
};
// If EIGHT is defined, we compute eight simultaneous factorials, thus measuring
// throughput instead of latency.
// #define EIGHT
#ifdef EIGHT
......@@ -67,16 +79,21 @@ void perf_const() {
#undef FINISH
}
void perf_barrett(int mod) {
FastMod bar(mod);
#define INIT(x) ll ret##x = (x + 1);
#define UPDATE(x) ret##x = bar.reduce(ret##x * i);
#define FINISH(x) cout << ret##x << endl;
void perf_old_barrett(int mod) {
OldBarrett bar(mod);
TEST()
}
void perf_barrett(int mod) {
FastMod bar(mod);
TEST()
}
#undef INIT
#undef UPDATE
#undef FINISH
}
ull rand_u64() {
ull ret = rand();
......@@ -87,36 +104,42 @@ ull rand_u64() {
return ret;
}
#define main1 main
// Correctness
int main() {
int main1() {
const int bflim = 3000;
rep(a,0,bflim) rep(b,2,bflim) {
FastMod bar(b);
ull ret = bar.reduce(a);
assert((ret == 0) == (a == 0));
if (ret >= (ull)b) ret -= b;
assert(ret == (ull)(a % b));
}
rep(it,0,10'000'000) {
ull a = rand_u64();
ull b = rand_u64();
if (b <= 1) continue;
if (b == 0) continue;
FastMod bar(b);
ull ret = bar.reduce(a);
if (ret >= b) ret -= b;
assert(ret == a % b);
}
cout<<"Tests passed!"<<endl;
return 0;
}
// Performance
int main2(int argc, char** argv) {
const int which = atoi(argv[1]);
#ifndef MOD
#define MOD 90217093 //202171241
#endif
int mod = MOD;
if (which < 0) mod = which;
int mod = MOD;
// Performance
int main2(int argc, char** argv) {
int which = atoi(argv[1]);
if (which == 0) perf_plain(mod); // 7.529 for 8, 1.714 for 1
if (which == 1) perf_const<MOD>(); // 0.971 for 8, 0.499 for 1
if (which == 2) perf_barrett(mod); // 1.094 for 8, 0.564 for 1
if (which == 2) perf_old_barrett(mod); // 1.094 for 8, 0.564 for 1
if (which == 3) perf_barrett(mod); // 0.870 for 8, 0.405 for 1
return 0;
}