Unverified Commit 61bc6405 authored by Horace He's avatar Horace He Committed by GitHub

Updated NTT to share interface with FFT (#167)

parent 0897eb95
......@@ -4,13 +4,14 @@
* License: CC0
* Source: http://neerc.ifmo.ru/trains/toulouse/2017/fft2.pdf (do read, it's excellent)
Accuracy bound from http://www.daemonology.net/papers/fft.pdf
* Description: fft(a) computes $\hat f(k) = \sum_x a[x] \exp(2\pi i \cdot k x / N)$ for all $k$. Useful for convolution:
* Description: fft(a) computes $\hat f(k) = \sum_x a[x] \exp(2\pi i \cdot k x / N)$ for all $k$. N must be a power of 2.
Useful for convolution:
\texttt{conv(a, b) = c}, where $c[x] = \sum a[i]b[x-i]$.
For convolution of complex numbers or more than two vectors: FFT, multiply
pointwise, divide by n, reverse(start+1, end), FFT back.
Rounding is safe if $(\sum a_i^2 + \sum b_i^2)\log_2{N} < 9\cdot10^{14}$
(in practice $10^{16}$; higher for random inputs).
Otherwise, use long doubles/NTT/FFTMod.
Otherwise, use NTT/FFTMod.
* Time: O(N \log N) with $N = |A|+|B|$ ($\tilde 1s$ for $N=2^{22}$)
* Status: somewhat tested
* Details: An in-depth examination of precision for both FFT and FFTMod can be found
......
......@@ -3,8 +3,13 @@
* Date: 2019-04-16
* License: CC0
* Source: based on KACTL's FFT
* Description: Can be used for convolutions modulo specific nice primes
* of the form $2^a b+1$, where the convolution result has size at most $2^a$.
* Description: ntt(a) computes $\hat f(k) = \sum_x a[x] g^{xk}$ for all $k$, where $g=\text{root}^{(mod-1)/N}$.
* N must be a power of 2.
* Useful for convolution modulo specific nice primes of the form $2^a b+1$,
* where the convolution result has size at most $2^a$. For arbitrary modulo, see FFTMod.
\texttt{conv(a, b) = c}, where $c[x] = \sum a[i]b[x-i]$.
For manual convolution: NTT the inputs, multiply
pointwise, divide by n, reverse(start+1, end), NTT back.
* Inputs must be in [0, mod).
* Time: O(N \log N)
* Status: stress-tested
......@@ -16,32 +21,33 @@
const ll mod = (119 << 23) + 1, root = 62; // = 998244353
// For p < 2^30 there is also e.g. 5 << 25, 7 << 26, 479 << 21
// and 483 << 21 (same root). The last two are > 10^9.
typedef vector<ll> vl;
void ntt(vl& a, vl& rt, vl& rev, int n) {
void ntt(vl &a) {
int n = sz(a), L = 31 - __builtin_clz(n);
static vl rt(2, 1);
for (static int k = 2, s = 2; k < n; k *= 2, s++) {
rt.resize(n);
ll z[] = {1, modpow(root, mod >> s)};
rep(i,k,2*k) rt[i] = rt[i / 2] * z[i & 1] % mod;
}
vi rev(n);
rep(i,0,n) rev[i] = (rev[i / 2] | (i & 1) << L) / 2;
rep(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int k = 1; k < n; k *= 2)
for (int i = 0; i < n; i += 2 * k) rep(j,0,k) {
ll z = rt[j + k] * a[i + j + k] % mod, &ai = a[i + j];
a[i + j + k] = (z > ai ? ai - z + mod : ai - z);
ai += (ai + z >= mod ? z - mod : z);
}
ll z = rt[j + k] * a[i + j + k] % mod, &ai = a[i + j];
a[i + j + k] = ai - z + (z > ai ? mod : 0);
ai += (ai + z >= mod ? z - mod : z);
}
}
vl conv(const vl& a, const vl& b) {
if (a.empty() || b.empty())
return {};
int s = sz(a)+sz(b)-1, B = 32 - __builtin_clz(s), n = 1 << B;
vl L(a), R(b), out(n), rt(n, 1), rev(n);
vl conv(const vl &a, const vl &b) {
if (a.empty() || b.empty()) return {};
int s = sz(a) + sz(b) - 1, B = 32 - __builtin_clz(s), n = 1 << B;
int inv = modpow(n, mod - 2);
vl L(a), R(b), out(n);
L.resize(n), R.resize(n);
rep(i,0,n) rev[i] = (rev[i / 2] | (i & 1) << B) / 2;
ll curL = mod / 2, inv = modpow(n, mod - 2);
for (int k = 2; k < n; k *= 2) {
ll z[] = {1, modpow(root, curL /= 2)};
rep(i,k,2*k) rt[i] = rt[i / 2] * z[i & 1] % mod;
}
ntt(L, rt, rev, n); ntt(R, rt, rev, n);
rep(i,0,n) out[-i & (n-1)] = L[i] * R[i] % mod * inv % mod;
ntt(out, rt, rev, n);
ntt(L), ntt(R);
rep(i,0,n) out[-i & (n - 1)] = (ll)L[i] * R[i] % mod * inv % mod;
ntt(out);
return {out.begin(), out.begin() + s};
}
......@@ -7,12 +7,12 @@ namespace ignore {
ll modpow(ll a, ll e);
#include "../../content/numerical/NumberTheoreticTransform.h"
ll modpow(ll a, ll e) {
if (e == 0) return 1;
ll x = modpow(a * a % mod, e >> 1);
return e & 1 ? x * a % mod : x;
if (e == 0)
return 1;
ll x = modpow(a * a % mod, e >> 1);
return e & 1 ? x * a % mod : x;
}
vl simpleConv(vl a, vl b) {
int s = sz(a) + sz(b) - 1;
if (a.empty() || b.empty()) return {};
......@@ -24,11 +24,11 @@ vl simpleConv(vl a, vl b) {
}
int ra() {
static unsigned X;
X *= 123671231;
X += 1238713;
X ^= 1237618;
return (X >> 1);
static unsigned X;
X *= 123671231;
X += 1238713;
X ^= 1237618;
return (X >> 1);
}
int main() {
......@@ -42,6 +42,14 @@ int main() {
for(auto &x: b) x = (ra() % 100 - 50+mod)%mod;
for(auto &x: simpleConv(a, b)) res += (ll)x * ind++ % mod;
for(auto &x: conv(a, b)) res2 += (ll)x * ind2++ % mod;
a.resize(16);
vl a2 = a;
ntt(a2);
rep(k, 0, sz(a2)) {
ll sum = 0;
rep(x, 0, sz(a2)) { sum = (sum + a[x] * modpow(root, k * x * (mod - 1) / sz(a))) % mod; }
assert(sum == a2[k]);
}
}
assert(res==res2);
cout<<"Tests passed!"<<endl;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment