Basic arithmetic on polynomials
Covers addition, subtraction, multiplication and division.
Addition and subtraction are straightforward: add or subtract the coefficients one by one in \(O(n)\).
Polynomial multiplication
FFT: converts a polynomial between its coefficient representation and its point-value representation in \(O(n\log n)\); works over the real/complex numbers.
NTT: performs the same conversion in \(O(n\log n)\), modulo a prime that has a primitive root (e.g. 998244353).
MTT: polynomial multiplication modulo an arbitrary integer (see below).
FFT over the complex numbers:
Define the n-th roots of unity \(\omega_n^k = \cos\frac{2\pi k}{n} + \mathrm{i}\sin\frac{2\pi k}{n}\), for \(k = 0, 1, \dots, n-1\).
Substitute these values into the polynomial and examine the resulting expression.
Next, consider how to convert the point-value representation back to the coefficient representation.
Recall the useful properties of the roots of unity:
Consider the point values obtained by evaluating at the inverses of the n-th roots of unity, \(\omega_n^{-k}\).
Back to FFT
In the recursion above, the coefficients are split by the parity of their index at every level, which makes the constant factor very large.
Consider a simpler approach:
Look at where each coefficient ends up once the recursion reaches the leaves.
The coefficient at leaf \(i\) is the one whose index is the bit-reversal of \(i\), for \(i = 0, 1, \dots, n - 1\).
So this permutation can be computed in \(O(n)\), and the recursion can be replaced by iteration.
The FFT then simply iterates over block lengths 2, 4, 8, … from the bottom up.
The code is implemented as follows
#include <bits/stdc++.h>
// Capacity of every array below. It must be at least the padded transform
// length `len`, not just n + m + 1: for n = m = 1e6 the sizing loop in main()
// picks len = 2^21 = 2097152, which overflowed the former bound of 2000005.
#define MAXN ((1 << 21) + 5)
typedef double ll;  // floating-point type for all complex arithmetic (not an integer!)
const ll PI = acos(-1.0);
using namespace std;

int n, m;
int lim, len;   // len = 2^lim is the transform length
int res[MAXN];  // res[i] = bit-reversal of i over lim bits

// Complex number: dx = real part, dy = imaginary part.
struct node { ll dx, dy; } a[MAXN], b[MAXN], c[MAXN];

node operator + (node A, node B) { return (node){A.dx + B.dx, A.dy + B.dy}; }
node operator - (node A, node B) { return (node){A.dx - B.dx, A.dy - B.dy}; }
node operator * (node A, node B) {
    return (node){A.dx * B.dx - A.dy * B.dy, A.dx * B.dy + A.dy * B.dx};
}

// Principal sz-th root of unity e^{2*pi*i/sz}, or its conjugate when tp == -1.
node wn(double sz, int tp) {
    ll zz = 2.0 * PI / sz;
    if (tp == (-1)) return (node){cos(zz), -sin(zz)};
    return (node){cos(zz), sin(zz)};
}

// In-place iterative FFT of f[0..len).
// tp == 1: forward transform; tp == -1: inverse transform, after which the real
// parts are divided by len (imaginary parts are left unscaled; they are unused).
void fft(node f[], int tp) {
    for (int i = 0; i < len; i++)
        if (i < res[i]) swap(f[i], f[res[i]]);
    for (int k = 1; k < len; k = k + k) {  // k = half of the current block size
        node w1 = wn(k << 1, tp), w, A, B;
        for (int i = 0; i < len; i = i + k + k) {
            w = (node){1, 0};
            for (int j = 0; j < k; j++, w = w * w1) {
                A = f[i + j], B = f[i + j + k] * w;
                f[i + j] = A + B;
                f[i + j + k] = A - B;
            }
        }
    }
    if (tp == (-1))
        for (int i = 0; i < len; i++) f[i].dx /= (len * 1.0);
}

// Reads degrees n, m and the coefficients of two polynomials, prints the
// n + m + 1 coefficients of their product.
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i <= n; i++) scanf("%lf", &a[i].dx);
    for (int i = 0; i <= m; i++) scanf("%lf", &b[i].dx);
    len = 1, lim = 0;
    while (len <= (n + m + 2)) len = len + len, lim++;
    for (int i = 1; i < len; i++)
        res[i] = (res[i >> 1] >> 1) | ((i & 1) << (lim - 1));
    fft(a, 1);
    fft(b, 1);
    for (int i = 0; i < len; i++) c[i] = a[i] * b[i];
    fft(c, -1);
    // lround rounds to nearest (half away from zero), so negative coefficients
    // are rounded correctly as well; (int)(x + 0.5) truncated them toward zero.
    for (int i = 0; i < n + m + 1; i++) printf("%ld ", lround(c[i].dx));
}
Now consider how to do the same thing modulo a prime.
Consider the primitive roots of 998244353; 3 is one of them.
Some derivation shows that, modulo \(p\), powers of a primitive root have the same properties as the roots of unity.
Namely, \(\omega_n = g^{\frac{p-1}{n}} \bmod p\), where \(g\) is a primitive root modulo \(p\).
Everything else works exactly as with the roots of unity. (I don't understand why the constant factor of my earlier implementation was so large.)
The code is as follows:
#include <bits/stdc++.h>
#define MAXN 5000005
typedef long long ll;
const ll mod = 998244353;  // NTT-friendly prime
const ll g = 3;            // a primitive root of mod
using namespace std;

int n, m, limit, len;       // len = 2^limit = transform length
ll a[MAXN], b[MAXN], c[MAXN];
int res[MAXN];              // bit-reversal permutation of [0, len)

// Modular exponentiation by squaring: returns base^exp mod `mod`.
int poww(ll base, int exp) {
    ll acc = 1;
    for (; exp > 0; exp >>= 1) {
        if (exp & 1) acc = 1ll * base * acc % mod;
        base = (1ll * base * base) % mod;
    }
    return acc;
}

// In-place number-theoretic transform of f[0..len) over Z_mod.
// tp == 1 -> forward transform; tp == -1 -> inverse transform, including
// multiplication by len^{-1}.
void NTT(ll f[], int tp) {
    for (int i = 0; i < len; i++)
        if (i < res[i]) swap(f[i], f[res[i]]);
    for (int half = 1; half < len; half <<= 1) {
        // Primitive (2*half)-th root of unity, inverted for the inverse transform.
        ll step = poww(g, (mod - 1) / (half << 1));
        if (tp == (-1)) step = poww(step, mod - 2);
        for (int base = 0; base < len; base += half << 1) {
            ll w = 1;
            for (int j = 0; j < half; j++) {
                ll u = f[base + j];
                ll v = f[base + j + half] * w % mod;
                f[base + j] = ((u + v) % mod + mod) % mod;
                f[base + j + half] = ((u - v) % mod + mod) % mod;
                w = w * step % mod;
            }
        }
    }
    if (tp != 1) {
        ll invLen = poww(len, mod - 2);
        for (int i = 0; i < len; i++) f[i] = (1ll * f[i] * invLen) % mod;
    }
}

// Reads degrees n, m and both coefficient lists; prints the n + m + 1
// coefficients of the product modulo 998244353.
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i <= n; i++) scanf("%lld", &a[i]);
    for (int i = 0; i <= m; i++) scanf("%lld", &b[i]);
    // Smallest power of two strictly greater than n + m + 1.
    for (len = 1, limit = 0; len <= n + m + 1; len <<= 1) ++limit;
    for (int i = 1; i < len; i++)
        res[i] = (res[i >> 1] >> 1) | ((i & 1) << (limit - 1));
    NTT(a, 1);
    NTT(b, 1);
    for (int i = 0; i < len; i++) c[i] = 1ll * a[i] * b[i] % mod;
    NTT(c, -1);
    for (int i = 0; i <= n + m; i++) printf("%lld ", c[i]);
}
P1919 [template] upgraded version of A*B Problem (FFT fast Fourier transform)
A template problem: multiply two big integers by treating their decimal digits as polynomial coefficients.
#include <bits/stdc++.h>
#define MAXN 5000005
typedef long long ll;
const ll mod = 998244353;
const ll g = 3;  // primitive root of mod (the transform below uses the literal 3)
using namespace std;

int n, m, limit, len, res[MAXN];  // n, m: highest digit index of each factor; len = 2^limit
ll a[MAXN], b[MAXN], c[MAXN];     // digit polynomials (least significant digit first) and product
int sz;
char s[MAXN];

// x^y mod `mod` by binary exponentiation.
ll poww(ll x, int y) {
    ll zz = 1;
    while (y) {
        if (y & 1) zz = zz * x % mod;
        x = x * x % mod;
        y = y >> 1;
    }
    return zz;
}

// In-place NTT of f[0..len): tp == 1 forward, tp == -1 inverse (including 1/len scaling).
void NTT(ll f[], int tp) {
    for (int i = 0; i < len; i++)
        if (i < res[i]) swap(f[i], f[res[i]]);
    for (int k = 1; k < len; k = k + k) {
        ll w1 = poww(3, (mod - 1) / (k << 1)), w, A, B;
        if (tp == (-1)) w1 = poww(w1, mod - 2);
        for (int i = 0; i < len; i = i + k + k) {
            w = 1;
            for (int j = 0; j < k; j++, w = w * w1 % mod) {
                A = f[i + j], B = f[i + j + k] * w % mod;
                f[i + j] = ((A + B) % mod + mod) % mod;
                f[i + j + k] = ((A - B) % mod + mod) % mod;
            }
        }
    }
    if (tp == 1) return;
    ll zz = poww(len, mod - 2);
    for (int i = 0; i < len; i++) f[i] = (1ll * f[i] * zz) % mod;
}

// Reads two non-negative decimal integers and prints their product.
int main() {
    scanf("%s", s), sz = strlen(s), n = sz - 1;
    for (int i = 0; i <= n; i++) a[i] = s[n - i] - '0';
    scanf("%s", s), sz = strlen(s), m = sz - 1;
    for (int i = 0; i <= m; i++) b[i] = s[m - i] - '0';
    len = 1, limit = 0;
    while (len <= (n + m + 4)) len = len + len, limit++;
    for (int i = 1; i < len; i++)
        res[i] = (res[i >> 1] >> 1) | ((i & 1) << (limit - 1));
    NTT(a, 1);
    NTT(b, 1);
    for (int i = 0; i < len; i++) c[i] = (a[i] * b[i] % mod + mod) % mod;
    NTT(c, -1);
    // Each c[i] is an exact sum of digit products as long as it stays below mod
    // (81 * min(digit counts) < mod), so ordinary base-10 carrying recovers the product.
    for (int i = 0; i < n + m + 1; i++) c[i + 1] = c[i + 1] + c[i] / 10, c[i] %= 10;
    int top = n + m + 1;  // index of the most significant digit candidate
    while (c[top] >= 10) c[top + 1] = c[top + 1] + c[top] / 10, c[top] %= 10, top++;
    // Strip leading zeros but always keep c[0]: when the product is 0 the old
    // loop walked past c[0], read c[-1] and printed nothing.
    while (top > 0 && c[top] == 0) top--;
    for (int i = top; i >= 0; i--) putchar('0' + (int)c[i]);
}
Divide and conquer FFT
Example: lgP4721 [template] divide and conquer FFT
Consider CDQ divide and conquer.
Recurse on the interval \([l, r]\), with \(mid = \lfloor (l + r) / 2 \rfloor\).
Consider the contribution of \([l, mid]\) to the interval \([mid + 1, r]\).
This contribution turns out to be a simple convolution.
So the problem fits the CDQ divide-and-conquer pattern exactly.
The complexity is \(O(n\log^2 n)\).
The code is as follows: lgP4721
#include <bits/stdc++.h>
#define MAXN 500005
typedef long long ll;
const ll mod = 998244353;
using namespace std;

int n;
ll g[MAXN], ans[MAXN];  // g: given kernel (g[0] stays 0); ans: the sequence f being computed

// x^y mod `mod` by binary exponentiation.
ll poww(ll x, int y) {
    ll zz = 1;
    while (y) {
        if (y & 1) zz = zz * x % mod;
        x = x * x % mod;
        y = y >> 1;
    }
    return zz;
}

int limit, len;                  // len = 2^limit for the current convolution
ll a[MAXN], b[MAXN], res[MAXN];  // scratch buffers and bit-reversal table

// In-place NTT of f[0..len); rebuilds the bit-reversal table for the current
// len/limit on every call. tp == 1 forward, tp == -1 inverse (with 1/len scaling).
void NTT(ll f[], int tp) {
    for (int i = 1; i < len; i++) res[i] = ((res[i >> 1] >> 1) | ((i & 1) << (limit - 1)));
    for (int i = 0; i < len; i++)
        if (i < res[i]) swap(f[i], f[res[i]]);
    for (int k = 1; k < len; k = k + k) {
        ll w1 = poww(3, (mod - 1) / (k << 1)), w, A, B;
        if (tp == (-1)) w1 = poww(w1, mod - 2);
        for (int i = 0; i < len; i = i + k + k) {
            w = 1;
            for (int j = 0; j < k; j++, w = w * w1 % mod) {
                A = f[i + j], B = f[i + j + k] * w % mod;
                f[i + j] = ((A + B) % mod + mod) % mod;
                f[i + j + k] = ((A - B) % mod + mod) % mod;
            }
        }
    }
    if (tp == 1) return;
    ll zz = poww(len, mod - 2);
    for (int i = 0; i < len; i++) f[i] = (1ll * f[i] * zz) % mod;
}

// CDQ on [l, r]: once ans[l..mid] is final, add
// sum_{j=l}^{mid} ans[j] * g[i - j] to every ans[i] with mid < i <= r,
// then recurse into the right half.
void cdq(int l, int r) {
    if (l == r) return;
    int mid = (l + r) >> 1;
    cdq(l, mid);
    for (int i = l; i <= mid; i++) a[i - l] = ans[i], b[i - l] = g[i - l];
    for (int i = mid + 1; i <= r; i++) a[i - l] = 0, b[i - l] = g[i - l];
    len = 1, limit = 0;
    while (len <= (r - l + 1)) len = len + len, limit++;
    // Clear data left over from earlier (larger) calls beyond the copied range.
    for (int i = r - l + 1; i <= len; i++) a[i] = b[i] = 0;
    NTT(a, 1), NTT(b, 1);
    for (int i = 0; i < len; i++) a[i] = (a[i] * b[i] % mod + mod) % mod;
    NTT(a, -1);
    for (int i = mid + 1; i <= r; i++) ans[i] = (ans[i] + a[i - l]) % mod;
    cdq(mid + 1, r);
}

int main() {
    scanf("%d", &n), ans[0] = 1;
    // g is long long, so it must be read with %lld; %d into a long long* is undefined behavior.
    for (int i = 1; i < n; i++) scanf("%lld", &g[i]);
    // Pad the range to a power of two; g beyond n - 1 stays 0, so the extra terms are harmless.
    len = 1;
    while (len < n) len = len + len;
    cdq(0, len - 1);
    for (int i = 0; i < n; i++) cout << ans[i] << " ";
}
The problem can also be solved with polynomial inversion; that approach will be added later.
Arbitrary modulus polynomial multiplication (MTT)
This technique works for an arbitrary modulus, which is why it is often used to make problems harder.
1. Compute the product modulo three NTT-friendly primes (9 NTTs in total), then combine each coefficient with the CRT.
2. Split each coefficient as \(a_i = p \cdot M + q\) with \(M \approx \sqrt{mod}\) and \(0 \le q < M\) (the code below uses \(M = 32768\)), then multiply the parts with FFT.
3. MTT (the technique from the paper).
Method 2 implementation
#include <bits/stdc++.h>
#define MAXN 400005
typedef long double ll;  // floating-point type used by the FFT (not an integer!)
typedef long long LL;
// std::acos(-1.0L) selects the long double overload. The former acos(-1.0)
// computed pi in double precision, discarding the extra precision that the
// long double FFT exists to provide.
const ll PI = std::acos(-1.0L);
using namespace std;

int n, m, limit, len;            // len = 2^limit = transform length
int res[MAXN];                   // bit-reversal permutation of [0, len)
LL ans[MAXN], part = 32768, p;   // ans: product mod p; part: split base (2^15); p: modulus
// a1/b1: high parts (value / part); a2/b2: low parts (value % part); X: scratch.
struct node { ll dx, dy; } a1[MAXN], b1[MAXN], a2[MAXN], b2[MAXN], X[MAXN];

node operator + (node A, node B) { return (node){A.dx + B.dx, A.dy + B.dy}; }
node operator - (node A, node B) { return (node){A.dx - B.dx, A.dy - B.dy}; }
node operator * (node A, node B) {
    return (node){A.dx * B.dx - A.dy * B.dy, A.dx * B.dy + A.dy * B.dx};
}

// Principal sz-th root of unity e^{2*pi*i/sz}, or its conjugate when tp == -1.
node wn(ll sz, int tp) {
    ll zz = 2.0 * PI / sz;
    if (tp == (-1)) return (node){std::cos(zz), -std::sin(zz)};
    return (node){std::cos(zz), std::sin(zz)};
}

// In-place iterative FFT of f[0..len). tp == 1 forward, tp == -1 inverse.
// No 1/len scaling is applied here; solve() divides by len itself.
void fft(node f[], int tp) {
    for (int i = 0; i < len; i++)
        if (i < res[i]) swap(f[i], f[res[i]]);
    for (int k = 1; k < len; k = k + k) {
        node w1 = wn(k << 1, tp), w, A, B;
        for (int i = 0; i < len; i = i + k + k) {
            w = (node){1, 0};
            for (int j = 0; j < k; j++, w = w * w1) {
                A = f[i + j], B = f[i + j + k] * w;
                f[i + j] = A + B;
                f[i + j + k] = A - B;
            }
        }
    }
}

// Multiplies two already-transformed part polynomials pointwise, inverts the
// transform, rounds each coefficient to the nearest integer and adds
// (coefficient mod p) * weight to ans[i] modulo p.
void solve(node A[], node B[], LL weight) {
    for (int i = 0; i < len; i++) X[i] = A[i] * B[i];
    fft(X, -1);
    for (int i = 0; i <= n + m; ++i)
        ans[i] = (ans[i] + (LL)(X[i].dx / len + 0.5) % p * weight % p) % p;
}

// (f1*M + f2)(g1*M + g2) = f1g1*M^2 + (f1g2 + f2g1)*M + f2g2, with M = part.
// Transforms the four parts in place, accumulates the result, prints it.
void MTT(node f1[], node f2[], node g1[], node g2[]) {
    fft(f1, 1), fft(f2, 1), fft(g1, 1), fft(g2, 1);
    solve(f1, g1, part * part);
    solve(f1, g2, part);
    solve(f2, g1, part);
    solve(f2, g2, 1);
    for (int i = 0; i <= n + m; i++) {
        cout << (ans[i] % p + p) % p << " ";
    }
}

// Reads n, m, p and both coefficient lists; prints the product modulo p.
int main() {
    scanf("%d%d%lld", &n, &m, &p);
    LL zz;
    for (int i = 0; i <= n; i++) {
        scanf("%lld", &zz);
        zz %= p;  // reducing first keeps the split parts (and FFT magnitudes) as small as possible
        a1[i].dx = zz / part;
        a2[i].dx = zz % part;
    }
    for (int i = 0; i <= m; i++) {
        scanf("%lld", &zz);
        zz %= p;
        b1[i].dx = zz / part;
        b2[i].dx = zz % part;
    }
    len = 1, limit = 0;
    while (len <= (n + m + 2)) len = len + len, limit++;
    for (int i = 1; i < len; i++)
        res[i] = (res[i >> 1] >> 1) | ((i & 1) << (limit - 1));
    MTT(a1, a2, b1, b2);
}
Polynomial multiplicative inverse
Given a polynomial \(F(x)\) with \(n\) coefficients, find a polynomial \(G(x)\) such that \(F(x)G(x) \equiv 1 \pmod{x^n}\).
In other words, find the multiplicative inverse of \(F(x)\) modulo \(x^n\).
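It can be built recursively: if \(H(x)\) is the inverse modulo \(x^{\lceil n/2 \rceil}\), then \(G(x) \equiv H(x)\,(2 - F(x)H(x)) \pmod{x^n}\).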
#include<bits/stdc++.h>
#define MAXN 400005
// Polynomial multiplicative inverse modulo 998244353.
// Reads n and coefficients a[0..n-1]; prints G[0..n-1] with a(x) * G(x) == 1 (mod x^n).
typedef long long ll;
const ll mod = 998244353;
using namespace std;
// len = 2^limit is the current transform length; res is the bit-reversal table.
int n,len,limit,res[MAXN];
// a: input polynomial; c: scratch copy of a that gets transformed.
ll a[MAXN],c[MAXN];
// G: inverse being built at the current level; H: inverse from the previous (half-length) level.
ll H[MAXN],G[MAXN];
// Returns x^y mod `mod` by binary exponentiation.
ll poww(ll x , int y){
    ll zz = 1;
    while(y){
        if(y & 1)zz = zz * x % mod;
        x = x * x % mod;
        y = y >> 1;
    }
    return zz;
}
// In-place NTT of f[0..len) with primitive root 3; rebuilds res from len/limit on every call.
// tp == 1: forward transform; tp == -1: inverse transform (inverse roots, then multiply by len^{-1}).
void NTT(ll f[] , int tp){
    for(int i = 1 ; i < len ; i++)res[i] = (res[i >> 1] >> 1) | ((i & 1) << (limit - 1));
    for(int i = 0 ; i < len ; i++)if(i < res[i])swap(f[i] , f[res[i]]);
    for(int k = 1 ; k < len ; k = k + k){
        // w1 = primitive (2k)-th root of unity mod p, inverted for the inverse transform.
        ll w1 = poww(3 , (mod - 1) / (k << 1)) , w , A , B;
        if(tp == (-1))w1 = poww(w1 , mod - 2);
        for(int i = 0 ; i < len ; i = i + k + k){
            w = 1;
            for(int j = 0 ; j < k ; j++ , w = w * w1 % mod){
                A = f[i + j] , B = f[i + j + k] * w % mod;
                f[i + j] = ((A + B) % mod + mod) % mod;
                f[i + j + k] = ((A - B) % mod + mod) % mod;
            }
        }
    }
    if(tp == 1)return;
    ll zz = poww(len , mod - 2);
    for(int i = 0 ; i < len ; i++)f[i] = (1ll * f[i] * zz) % mod;
}
// Newton doubling: leaves the inverse of a modulo x^LEN in G[0..LEN) and zeroes G[LEN..len).
// The parameter L is only passed down to the recursive call; its value is never otherwise used.
void solve(int LEN , int L){
    // Base case: inverse of the constant term via Fermat; assumes a[0] is nonzero mod p (not checked).
    if(LEN == 1)return (void)(G[0] = poww(a[0] , mod - 2));
    solve((LEN + 1) >> 1 , L - 1);
    // Move the half-length inverse into H (std::swap of the whole arrays, O(MAXN) per level).
    // NOTE(review): correctness relies on H being zero from index (LEN+1)/2 up to len. The previous
    // level zeroed its G from its LEN to its len, and entries past the largest len used so far
    // appear never to be written. Confirm before reusing H/G elsewhere.
    swap(H , G);
    // Smallest power of two >= 2*LEN: the product H*H*a has degree <= 2*LEN-2, so it does not wrap around.
    len = 1 , limit = 0;
    while(len < (LEN << 1))len = len << 1 , limit++;
    for(int i = 0 ; i < LEN ; i++)G[i] = c[i] = 0;
    for(int i = 0 ; i < LEN ; i++)c[i] = a[i];
    for(int i = LEN ; i < len ; i++)c[i] = 0;
    NTT(c , 1) , NTT(H , 1);
    // Pointwise G = H * (2 - H * a), i.e. G == H(2 - aH) (mod x^LEN) after the inverse transform.
    for(int i = 0 ; i < len ; i++)G[i] = ((H[i] * ((2ll - H[i] * c[i]) % mod + mod) % mod) % mod + mod) % mod;
    NTT(G , -1);
    // Truncate to mod x^LEN.
    for(int i = LEN ; i < len ; i++)G[i] = 0;
}
int main(){
    scanf("%d" , &n);for(int i = 0 ; i < n ; i++)scanf("%lld" , &a[i]);
    // limit is still 0 here; solve() never uses its second argument's value.
    solve(n , limit);
    for(int i = 0 ; i < n ; i++)cout<<G[i]<<" ";
}
Polynomial ln and exp
Both can be derived using derivatives of polynomials.
Polynomial ln: \(\ln F(x) = \int \frac{F'(x)}{F(x)}\,dx\)
#include<bits/stdc++.h>
#define MAXN 400005
// Polynomial ln modulo 998244353.
// Reads n and a[0..n-1]; prints the first n coefficients of the integral of a'(x) / a(x).
// Presumably the input guarantees a[0] == 1 so that this equals ln a(x); this is not checked here.
typedef long long ll;
const ll mod = 998244353;
using namespace std;
// len = 2^limit is the current transform length; res is the bit-reversal table.
int n,m,len,limit,res[MAXN];
// a: input polynomial; c: scratch copy of a used inside solve().
ll a[MAXN],c[MAXN];
// Work arrays: H/G hold the inverse during solve(), then the derivative and the result in main().
ll H[MAXN],G[MAXN];
// Returns x^y mod `mod` by binary exponentiation.
ll poww(ll x , int y){
    ll zz = 1;
    while(y){
        if(y & 1)zz = zz * x % mod;
        x = x * x % mod;
        y = y >> 1;
    }
    return zz;
}
// In-place NTT of f[0..len) with primitive root 3; rebuilds res from len/limit on every call.
// tp == 1: forward transform; tp == -1: inverse transform (inverse roots, then multiply by len^{-1}).
void NTT(ll f[] , int tp){
    for(int i = 1 ; i < len ; i++)res[i] = (res[i >> 1] >> 1) | ((i & 1) << (limit - 1));
    for(int i = 0 ; i < len ; i++)if(i < res[i])swap(f[i] , f[res[i]]);
    for(int k = 1 ; k < len ; k = k + k){
        // w1 = primitive (2k)-th root of unity mod p, inverted for the inverse transform.
        ll w1 = poww(3 , (mod - 1) / (k << 1)) , w , A , B;
        if(tp == (-1))w1 = poww(w1 , mod - 2);
        for(int i = 0 ; i < len ; i = i + k + k){
            w = 1;
            for(int j = 0 ; j < k ; j++ , w = w * w1 % mod){
                A = f[i + j] , B = f[i + j + k] * w % mod;
                f[i + j] = ((A + B) % mod + mod) % mod;
                f[i + j + k] = ((A - B) % mod + mod) % mod;
            }
        }
    }
    if(tp == 1)return;
    ll zz = poww(len , mod - 2);
    for(int i = 0 ; i < len ; i++)f[i] = (1ll * f[i] * zz) % mod;
}
// Newton doubling: leaves the inverse of a modulo x^LEN in G[0..LEN) and zeroes G[LEN..len).
void solve(int LEN){
    // Base case: inverse of the constant term via Fermat (assumes a[0] nonzero mod p).
    if(LEN == 1)return (void)(G[0] = poww(a[0] , mod - 2));
    solve((LEN + 1) >> 1);
    // Move the half-length inverse into H. Relies on H being zero past its valid prefix up to len;
    // NOTE(review): the previous level zeroed only G[its LEN .. its len), so confirm nothing
    // else writes H/G beyond that.
    swap(H , G);
    len = 1 , limit = 0;
    while(len < (LEN << 1))len = len << 1 , limit++;
    for(int i = 0 ; i < LEN ; i++)G[i] = c[i] = 0;
    for(int i = 0 ; i < LEN ; i++)c[i] = a[i];
    for(int i = LEN ; i < len ; i++)c[i] = 0;
    NTT(c , 1) , NTT(H , 1);
    // Pointwise G = H * (2 - H * a).
    for(int i = 0 ; i < len ; i++)G[i] = ((H[i] * ((2ll - H[i] * c[i]) % mod + mod) % mod) % mod + mod) % mod;
    NTT(G , -1);
    for(int i = LEN ; i < len ; i++)G[i] = 0;
}
int main(){
    scanf("%d" , &n) , m = n - 1;for(int i = 0 ; i < n ; i++)scanf("%lld" , &a[i]);
    // G = a^{-1} mod x^n; then clear H for reuse.
    solve(n);memset(H , 0 , sizeof(H));
    // H = a'(x): H[i-1] = i * a[i].
    for(int i = 1 ; i < n ; i++)H[i - 1] = ((a[i] * (1ll * i)) % mod + mod) % mod;
    // len > n + m + 1 = 2n, so the product (degree <= 2n - 3) does not wrap around.
    len = 1 , limit = 0;
    while(len <= (n + m + 1))len = len + len , limit++;
    // NOTE(review): assumes G is zero from index n up to this len (solve() zeroes G[n .. its len)).
    NTT(H , 1) , NTT(G , 1);
    for(int i = 0 ; i < len ; i++)H[i] = (H[i] * G[i] % mod + mod) % mod;
    NTT(H , -1);memset(G , 0 , sizeof(G));
    // Integrate: G[i+1] = H[i] / (i+1); G[0] stays 0 from the memset.
    for(int i = 0 ; i < len ; i++){
        G[i + 1] = (H[i] * poww(i + 1 , mod - 2) % mod + mod) % mod;
    }
    for(int i = 0 ; i < n ; i++)cout<<G[i]<<" ";
}
Newton iteration
Newton iteration over the real numbers
It is commonly used to approximate the zeros of a function.
Repeat the above process
Newton iteration over polynomials: although it is not directly related to the real-number version above, the idea is similar.
Note that f here is a function, but we treat it as a constant
Polynomial inversion is a good example
Polynomial exp
Polynomial square root
Falling-factorial polynomial multiplication
(To be continued.)