多项式的一堆乱七八糟的操作学了一部分了……(多点求值和快速插值还没有)
打算写下来整理一下。不过因为还有一些没学的以及没完全理解的……只好先持续更新了。
不扯淡了,直接开始。
1.NTT
FFT咱就不说了,有兴趣可以看兔哥博客.
NTT和FFT很相似。但是因为FFT涉及到复数运算所以会有一些精度误差,然后有的时候也会遇到需要取模的情况……于是快速数论变换NTT应运而生。因为单位根和原根有相似的性质,所以NTT使用原根取代了单位根进行运算。模数998244353的原根是3,每次取原根的(frac{mod-1}{2i})次方代替单位根即可。逆变换的时候就用原根的逆元。
注意最后我们要像FFT一样乘以数组长度的逆元。其实还有另一种办法,就是直接进行一次reverse。据pinkrabbit大佬说,NTT数列是共轭对称的(然鹅我不知道啥是共轭对称),不过结果是正确的。
写法啥的和FFT都一样。看一下代码。有兴趣还可以看Miskcoo的博客
#include<bits/stdc++.h>
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
#define enter putchar('
')
using namespace std;
typedef long long ll;
const int M = 400005;
const int mod = 998244353;
const int G = 3;
const int invG = 332748118;
int read()
{
int ans = 0,op = 1;char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
return ans * op;
}
int n,m,a[M],b[M],c[M],rev[M],len = 1,L;
int add(int a,int b) {return a + b > mod ? a + b - mod : a + b;}
int mul(int a,int b) {return 1ll * a * b % mod;}
int qpow(int a,int b)
{
int p = 1;
while(b)
{
if(b & 1) p = mul(p,a);
a = mul(a,a),b >>= 1;
}
return p;
}
void NTT(int *a,int n,int f)
{
rep(i,0,n) if(i < rev[i]) swap(a[i],a[rev[i]]);
for(int i = 1;i < n;i <<= 1)
{
int w1 = qpow(f ? G : invG,(mod-1) / (i<<1));
for(int j = 0;j < n;j += (i<<1))
{
int w = 1;
rep(k,0,i-1)
{
int kx = a[k+j],ky = mul(a[k+j+i],w);
a[k+j] = add(kx,ky),a[k+j+i] = add(kx,mod-ky),w = mul(w,w1);
}
}
}
if(!f)
{
int inv = qpow(n,mod-2);
rep(i,0,n) a[i] = mul(a[i],inv);
}
}
int main()
{
n = read(),m = read();
rep(i,0,n) a[i] = read();
rep(i,0,m) b[i] = read();
while(len <= n+m+2) len <<= 1,L++;
rep(i,0,len) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));
NTT(a,len,1),NTT(b,len,1);
rep(i,0,len) c[i] = mul(a[i],b[i]);
NTT(c,len,0);
rep(i,0,n+m) printf("%d ",c[i]);enter;
return 0;
}
2.多项式求逆
这玩意在很多奇怪的多项式操作中都要用。
具体就是给你一个多项式(F(x)),让你求出它在(mod x^n)意义下的逆元,也就是求出多项式(G(x)),使得(F(x)G(x) equiv 1 (mod x^n)).系数对于998244353取模。
个人认为其核心思想是递归。对于只有一项的,那么显然(G(x))的常数项就是(F(x))的常数项的逆元,否则对于n项的,我们可以递归求解。
首先假设我们已经知道了(F(x)H(x)equiv 1(mod x^frac{n}{2}))
那么显然也有(F(x)G(x) equiv 1(mod x^frac{n}{2})) 这个是根据(G(x))的定义来的。
之后我们把两个式子相减,就可以得到:(F(x)(G(x)-H(x)) equiv 0(mod x^frac{n}{2}))
自然就有((G(x)-H(x)) equiv 0(mod x^frac{n}{2}))
把这个式子进行平方,我们就可以得到:((G(x)-H(x))^2 equiv 0 (mod x^n)).这里解释一下,因为一个在(mod x^frac{n}{2})情况下为0的多项式,指数小于(frac{n}{2})的项都为0,因为卷积的性质,其自乘的前n项也必然都为0,所以其在(mod x^n)的意义下也是0.
式子展开就可以得到:(G(x)^2 + H(x)^2 - 2G(x)H(x) equiv 0(mod x^n))
移项,同时乘以(F(x)),由(F(x)G(x) equiv 1(mod x^n))就可以得到(G(x) equiv 2H(x)-F(x)H(x)^2(mod x^n))
我们就可以用NTT来解决啦。时间复杂度(O(nlogn))
#include<bits/stdc++.h>
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
#define enter putchar('
')
#define lowbit(x) (x & (-x))
using namespace std;
typedef long long ll;
const int M = 400005;
const ll mod = 998244353;
const ll G = 3;
const ll invG = 332748118;
ll read()
{
ll ans = 0,op = 1;char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
return ans * op;
}
ll n,g[M],f[M],r[M],c[M],b[M],rev[M];
ll inc(ll a,ll b) {return (a + b) % mod;}
ll mul(ll a,ll b) {return a * b % mod;}
ll qpow(ll a,ll b)
{
ll p = 1;
while(b)
{
if(b & 1) p = mul(p,a);
a = mul(a,a),b >>= 1;
}
return p;
}
void NTT(ll *a,ll l,ll f)
{
rep(i,0,l-1) if(i < rev[i]) swap(a[i],a[rev[i]]);
for(int i = 1;i < l;i <<= 1)
{
ll w1 = qpow(f ? G : invG,(mod-1) / (i<<1));
for(int j = 0;j < l;j += (i<<1))
{
ll w = 1;
rep(k,0,i-1)
{
ll kx = a[k+j],ky = mul(a[k+j+i],w);
a[k+j] = inc(kx,ky),a[k+j+i] = inc(kx,mod-ky);
w = mul(w,w1);
}
}
}
if(!f)
{
ll inv = qpow(l,mod-2);
rep(i,0,l-1) a[i] = mul(a[i],inv);
}
}
void solve(int len)
{
if(len == 1) {g[0] = qpow(f[0],mod-2);return;}
solve((len+1)>>1);
ll l = 1,L = 0;
while(l < (len<<1)) l <<= 1,L++;
rep(i,0,l-1) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));
rep(i,0,len-1) c[i] = f[i];
rep(i,len,l-1) c[i] = 0;
NTT(c,l,1),NTT(g,l,1);
rep(i,0,l-1) g[i] = mul(inc(2,mod-mul(c[i],g[i])),g[i]);
NTT(g,l,0);
rep(i,len,l-1) g[i] = 0;
}
int main()
{
n = read();
rep(i,0,n-1) f[i] = read();
solve(n);
rep(i,0,n-1) printf("%lld ",g[i]);enter;
return 0;
}
3.多项式对数函数
多项式的对数函数是啥?(Misckoo)大佬说,你可以将其理解为多项式和麦克劳林级数的复合。(然鹅我不会高数啊2333)或许就是把(ln(1-x))进行一下泰勒展开?
不说这些我不大会的了,反正你计算的时候其实不用。我们要求的就是多项式(B(x)) ,使得(B(x)equiv ln A(x) (mod x^n))
这个直接算非常难。考虑同时对两边求导。就有(B'(x) equiv frac{A'(x)}{A(x)} (mod x^n))
直接对(A(x))求导求逆,之后对(B(x))积分一下就行。
求导和积分都是(O(n))的,所以复杂度就是求逆的复杂度。
#include<bits/stdc++.h>
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
#define enter putchar('
')
#define space putchar(' ')
#define lowbit(x) (x & (-x))
using namespace std;
typedef long long ll;
const int M = 400005;
const ll mod = 998244353;
const ll G = 3;
const ll invG = 332748118;
ll read()
{
ll ans = 0,op = 1;char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
return ans * op;
}
void write(ll x)
{
if(x < 10) {putchar(x+'0');return;}
char k = x % 10 + '0';
write(x / 10),putchar(k);
}
ll n,g[M],f[M],r[M],c[M],rev[M],h[M];
ll inc(ll a,ll b) {return (a + b) % mod;}
ll mul(ll a,ll b) {return 1ll * (a) * (b) % mod;}
ll qpow(ll a,ll b)
{
ll p = 1;
while(b)
{
if(b & 1) p = mul(p,a);
a = mul(a,a),b >>= 1;
}
return p;
}
void NTT(ll *a,ll l,ll f)
{
rep(i,0,l-1) if(i < rev[i]) swap(a[i],a[rev[i]]);
for(int i = 1;i < l;i <<= 1)
{
ll w1 = qpow(f ? G : invG,(mod - 1) / (i << 1));
for(int j = 0;j < l;j += (i<<1))
{
ll w = 1;
rep(k,0,i-1)
{
ll kx = a[k+j],ky = mul(w,a[k+j+i]);
a[k+j] = inc(kx,ky),a[k+j+i] = inc(kx,mod-ky);
w = mul(w,w1);
}
}
}
if(!f)
{
ll inv = qpow(l,mod-2);
rep(i,0,l-1) a[i] = mul(a[i],inv);
}
}
void derit(ll *a,ll *b,ll len) {rep(i,1,len-1) b[i-1] = mul(a[i],i);b[len-1] = 0;}
void inter(ll *a,ll *b,ll len) {rep(i,1,len-1) b[i] = mul(a[i-1],qpow(i,mod-2));b[0] = 0;}
void getinv(ll *a,ll *b,ll len)
{
if(len == 1) {b[0] = qpow(a[0],mod-2);return;}
getinv(a,b,(len+1)>>1);
ll l = 1,L = 0;
while(l < (len << 1)) l <<= 1,L++;
rep(i,0,l-1) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));
rep(i,0,len-1) g[i] = a[i];
rep(i,len,l-1) g[i] = 0;
NTT(g,l,1),NTT(b,l,1);
rep(i,0,l-1) b[i] = mul(inc(2,mod-mul(g[i],b[i])),b[i]);
NTT(b,l,0);
rep(i,len,l-1) b[i] = 0;
}
void getln(ll *a,ll *b,ll len)
{
derit(a,c,len),getinv(a,r,len);
ll l = 1,L = 0;
while(l < (len<<1)) l <<= 1,L++;
rep(i,0,l-1) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));
NTT(c,l,1),NTT(r,l,1);
rep(i,0,l-1) c[i] = mul(c[i],r[i]);
NTT(c,l,0);
inter(c,b,len);
}
int main()
{
n = read();
rep(i,0,n-1) f[i] = read();
ll l = 1;
while(l < n) l <<= 1;
getln(f,h,l);
rep(i,0,n-1) write(h[i]),space;enter;
return 0;
}
4.多项式指数函数
理解方法的话……就照着对数函数理解就行。其实就是让你求(B(x)),满足(B(x) equiv e^{A(x)} (mod x^n))
这个咋做?继续求导?(e^x)的导数还是(e^x)……
于是我们需要一个前置知识:多项式牛顿迭代。
具体可以参考Misckoo的博客 ,这里只给出式子了。
假设我们知道(G(x)),我们想求一个多项式(F(x)),满足(G(F(x)) equiv 0 (mod x^n))
首先只有一项的时候,(G(F(x)) equiv 0(mod x))是要单独计算的。
还是倍增的思想……假设我们已经求出(G(H(x)) equiv 0 (mod x^frac{n}{2}))
如何拓展到(mod x^n)下呢? 我们把(G(F(x)))在(H(x))处进行泰勒展开。就有:
(G(F(x)) = G(H(x)) + frac{G'(H(x))}{1!}(F(x)-H(x)) + frac{G''(H(x))}{2!}(F(x)-H(x))^2 + …)
因为(F(x))和(H(x))的后面(frac{n}{2})项相同,故((F(x)-H(x))^2)及以上次方项在(mod x^n)意义下均为0,所以就有:
(G(F(x)) equiv G(H(x)) + G'(H(x))(F(x)-H(x)) (mod x^n))
由(G(F(x)) equiv 0 (mod x^n)),所以(F(x) equiv H(x) - frac{G(H(x))}{G'(H(x)} (mod x^n))
我们就可以开始使用这个式子解决问题了。
回到刚才的问题。我们把式子变一下形再移项,就是(ln B(x) - A(x) equiv 0 (mod x^n))
我们相当于求函数零点。令(G(B(x)) = ln B(x)-A(x)),对函数求导得到:(G'(B(x)) = frac{1}{B(x)})
因为这里(F(x))是变量,自然(A(x))可以看作常数。
带回上面的多项式牛顿迭代的式子,就有:(B(x) equiv B_0(x)(1-ln B_0(x) + A(x)) (mod x^n))
用多项式对数函数+递归即可解决。
#include<bits/stdc++.h>
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
#define enter putchar('
')
#define I inline
using namespace std;
typedef long long ll;
const int M = 400005;
const int INF = 200000000;
const int mod = 998244353;
const int T = 3;
const int invT = 332748118;
int read()
{
int ans = 0,op = 1;char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
return ans * op;
}
int n,F[M],G[M],Fi[M],A[M],H[M],rev[M],c[M],d[M],Gi[M];
int add(int a,int b){return (a+b) % mod;}
int mul(int a,int b){return 1ll * a * b % mod;}
int qpow(int a,int b)
{
int p = 1;
while(b)
{
if(b&1) p = mul(p,a);
a = mul(a,a),b >>= 1;
}
return p;
}
void NTT(int *a,int n,int f)
{
rep(i,0,n-1) if(i < rev[i]) swap(a[i],a[rev[i]]);
for(int i = 1;i < n;i <<= 1)
{
int w1 = qpow(f? T : invT,(mod-1) / (i<<1));
for(int j = 0;j < n;j += (i<<1))
{
int w = 1;
rep(k,0,i-1)
{
int kx = a[k+j],ky = mul(w,a[k+j+i]);
a[k+j] = add(kx,ky),a[k+j+i] = add(kx,mod-ky),w = mul(w,w1);
}
}
}
if(!f)
{
int inv = qpow(n,mod-2);
rep(i,0,n-1) a[i] = mul(a[i],inv);
}
}
void derit(int *a,int *b,int len) {rep(i,1,len-1) b[i-1] = mul(a[i],i);b[len-1] = 0;}
void inter(int *a,int *b,int len) {rep(i,1,len-1) b[i] = mul(a[i-1],qpow(i,mod-2));b[0] = 0;}
void getrev(int l,int L){rep(i,0,l-1) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));}
void getinv(int *a,int *b,int len)
{
if(len == 1) {b[0] = qpow(a[0],mod-2);return;}
getinv(a,b,(len+1)>>1);
int l = 1,L = 0;
while(l < (len<<1)) l <<= 1,L++;
getrev(l,L);
rep(i,0,len-1) Gi[i] = a[i];
rep(i,len,l-1) Gi[i] = 0;
NTT(Gi,l,1),NTT(b,l,1);
rep(i,0,l-1) b[i] = mul(add(2,mod-mul(Gi[i],b[i])),b[i]);
NTT(b,l,0);
rep(i,len,l-1) b[i] = 0;
}
void getln(int *a,int *b,int len)
{
derit(a,c,len),getinv(a,G,len);
int l = 1,L = 0;
while(l < (len<<1)) l <<= 1,L++;
getrev(l,L);
NTT(c,l,1),NTT(G,l,1);
rep(i,0,l-1) c[i] = mul(G[i],c[i]);
NTT(c,l,0);
inter(c,b,len);
rep(i,0,l-1) c[i] = G[i] = 0;
}
void getexp(int *a,int *b,int len)
{
if(len == 1) {b[0] = 1;return;}
getexp(a,b,(len+1)>>1),getln(b,F,len);
F[0] = add(a[0]+1,mod-F[0]);
rep(i,1,len-1) F[i] = add(a[i],mod-F[i]);
int l = 1,L = 0;
while(l < (len<<1)) l <<= 1,L++;
getrev(l,L);
NTT(F,l,1),NTT(b,l,1);
rep(i,0,l-1) b[i] = mul(b[i],F[i]);
NTT(b,l,0);
rep(i,len,l-1) b[i] = F[i] = 0;
}
int main()
{
n = read();
rep(i,0,n-1) A[i] = read();
int l = 1;
while(l <= n) l <<= 1;
getexp(A,H,l);
rep(i,0,n-1) printf("%d ",H[i]);enter;
return 0;
}
5.多项式开根
给定(F(x)),求(G(x)^2 equiv F(x) (mod x^n))
这个有几种做法……首先说纯代数推导吧。
同样用倍增的想法,假设我们知道(H(x)^2 equiv F(x) (mod x^frac{n}{2}))
将两个式子相减得到(G(x)^2 - H(x)^2 equiv 0 (mod x^frac{n}{2}))
则有(G(x) - H(x) equiv 0 (mod x^frac{n}{2}))
将这个式子平方之后展开,用(F(x))来替换(G(x)^2),于是有了:(F(x) - 2G(x)H(x) + H(x)^2 equiv 0 (mod x^n))
移项得到(G(x) equiv frac{H(x)^2 + F(x)}{2H(x)})
于是我们就可以用求逆来解决这个问题了。注意你可以选择先约分再计算,或者直接计算,结果都是一样的,不过要特别注意的是,在做加法的时候不要加多了,加到len即可。
这种做法还有另一种推导方式,就是多项式牛顿迭代。
我们要求的式子都一样,那我们可以构造函数:(H(G(x)) = G(x)^2 - F(x)),求这个函数的零点。
对之求导,(H'(G(x)) = 2G(x)),之后带入上面的多项式牛顿迭代的方程,立即得到:(G(x) equiv frac{G_0(x)^2 + F(x)}{2G_0(x)} (mod x^n))
用与上面相同的方法解决即可。
#include<bits/stdc++.h>
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
#define enter putchar('
')
#define I inline
using namespace std;
typedef long long ll;
const int M = 800005;
const int mod = 998244353;
const int G = 3;
const int invG = 332748118;
int read()
{
int ans = 0,op = 1;char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
return ans * op;
}
int rev[M],A[M],B[M],C[M],D[M],F[M],n,inv2;
int inc(int a,int b){return (a+b) % mod;}
int mul(int a,int b){return 1ll * a * b % mod;}
int qpow(int a,int b)
{
int p = 1;
while(b)
{
if(b & 1) p = mul(p,a);
a = mul(a,a),b >>= 1;
}
return p;
}
void getrev(int l,int L) {rep(i,0,l-1) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));}
void NTT(int *a,int n,int f)
{
rep(i,0,n-1) if(i < rev[i]) swap(a[i],a[rev[i]]);
for(int i = 1;i < n;i <<= 1)
{
int w1 = qpow(f ? G : invG,(mod-1) / (i<<1));
for(int j = 0;j < n;j += (i<<1))
{
int w = 1;
rep(k,0,i-1)
{
int kx = a[k+j],ky = mul(a[k+j+i],w);
a[k+j] = inc(kx,ky),a[k+j+i] = inc(kx,mod-ky),w = mul(w,w1);
}
}
}
if(!f)
{
int inv = qpow(n,mod-2);
rep(i,0,n-1) a[i] = mul(a[i],inv);
}
}
void getinv(int *a,int *b,int len)
{
if(len == 1) {b[0] = qpow(a[0],mod-2);return;}
getinv(a,b,(len+1) >> 1);
int l = 1,L = 0;
while(l < (len<<1)) l <<= 1,L++;
getrev(l,L);
rep(i,0,len-1) C[i] = a[i];
rep(i,len,l-1) C[i] = 0;
NTT(C,l,1),NTT(b,l,1);
rep(i,0,l-1) b[i] = mul(inc(2,mod-mul(C[i],b[i])),b[i]);
NTT(b,l,0);
rep(i,len,l-1) b[i] = 0;
}
/*
void getsqrt(int *a,int *b,int len)
{
if(len == 1) {b[0] = 1;return;}
getsqrt(a,b,(len+1)>>1);
rep(i,0,len<<1) F[i] = 0;
getinv(b,F,len);
int l = 1,L = 0;
while(l < len << 1) l <<= 1,L++;
getrev(l,L);
rep(i,0,len-1) D[i] = a[i];
rep(i,len,l-1) D[i] = 0;
NTT(D,l,1),NTT(b,l,1),NTT(F,l,1);
rep(i,0,l-1) b[i] = mul(inc(b[i],mul(D[i],F[i])),inv2);
NTT(b,l,0);
rep(i,len,l-1) b[i] = 0;
}
*/
void getsqrt(int *a,int *b,int len)
{
if(len == 1) {b[0] = 1;return;}
getsqrt(a,b,(len+1)>>1);
rep(i,0,len<<1) F[i] = 0;
getinv(b,F,len);
int l = 1,L = 0;
while(l < len<<1) l <<= 1,L++;
getrev(l,L);
NTT(b,l,1);
rep(i,0,l-1) b[i] = mul(b[i],b[i]);
NTT(b,l,0);
rep(i,0,len-1) b[i] = inc(b[i],a[i]);
NTT(b,l,1),NTT(F,l,1);
rep(i,0,l-1) b[i] = mul(mul(b[i],F[i]),inv2);
NTT(b,l,0);
rep(i,len,l-1) b[i] = 0;
}
int main()
{
n = read(),inv2 = qpow(2,mod-2);
rep(i,0,n-1) A[i] = read();
getsqrt(A,B,n);
rep(i,0,n-1) printf("%d ",B[i]);enter;
return 0;
}
另一种做法是用(ln)和(exp),它可以推广到任意次幂,不过常数会很大。
对于(G(x)^k equiv F(x)),我们把式子变形可以得到:(F(x) equiv e^{kln(G(x))})
于是我们就可以用(ln)和(exp)来解决这个问题了……
(代码暂时咕了,可能某些时候补上)
6.多项式除法与取模
这玩意的做法很神奇。
给定一个多项式(A(x))和一个多项式(B(x)),求多项式(D(x))和(R(x)),使得(A(x) = B(x)D(x) + R(x))
首先我们要想办法消除(R(x))的影响。如何消除呢……
我们假设A有n项,B有m项,且(m < n),那么显然D应该有n-m项,R有m-1项。
我们引入一个神奇的操作:将所有的(x)用(frac{1}{x})来替代.两边再同时乘以(x^n)
容易发现,我们相当于将多项式的系数进行了反转。原来的式子变成了这样:
(x^nA(frac{1}{x}) = x^mB(frac{1}{x})x^{n-m}D(frac{1}{x}) + x^{n-m+1}x^{m-1}R(frac{1}{x}))
我们定义(A^r(x) = x^nA(frac{1}{x})),注意这个n对于每个多项式是不同的,指的是多项式自身的项数。
于是就有(A^r(x) = B^r(x)D^r(x) + x^{n-m+1}R^r(x))
我们发现,这个式子在(mod (n-m))意义下,(R(x))的影响就会被消除。而(D(x))在反转后,次数仍然不高于(n-m),所以我们有(A^r(x) equiv B^r(x)D^r(x) (mod x^{n-m+1}))
然后求一下在模意义下B的逆元,倒着推回去就能求出(D(x))和(R(x))了。
#include<bits/stdc++.h>
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
#define enter putchar('
')
#define I inline
using namespace std;
typedef long long ll;
const int M = 800005;
const int INF = 200000000;
const int mod = 998244353;
const int T = 3;
const int invT = 332748118;
int read()
{
int ans = 0,op = 1;char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
return ans * op;
}
int n,m,F[M],G[M],rev[M],Q[M],R[M],Fr[M],Gr[M],Gi[M],c[M];
int add(int a,int b){return (a+b) % mod;}
int mul(int a,int b){return 1ll * a * b % mod;}
int qpow(int a,int b)
{
int p = 1;
while(b)
{
if(b&1) p = mul(p,a);
a = mul(a,a),b >>= 1;
}
return p;
}
void NTT(int *a,int n,int f)
{
rep(i,0,n-1) if(i < rev[i]) swap(a[i],a[rev[i]]);
for(int i = 1;i < n;i <<= 1)
{
int w1 = qpow(f ? T : invT,(mod-1) / (i<<1));
for(int j = 0;j < n;j += (i<<1))
{
int w = 1;
rep(k,0,i-1)
{
int kx = a[k+j],ky = mul(a[k+j+i],w);
a[k+j] = add(kx,ky),a[k+j+i] = add(kx,mod-ky),w = mul(w,w1);
}
}
}
if(!f)
{
int inv = qpow(n,mod-2);
rep(i,0,n-1) a[i] = mul(a[i],inv);
}
}
void getrev(int n,int L)
{
rep(i,0,n-1) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));
}
void getinv(int *a,int *b,int len)
{
if(len == 1) {b[0] = qpow(a[0],mod-2);return;}
getinv(a,b,(len+1)>>1);
int l = 1,L = 0;
while(l < (len << 1)) l <<= 1,L++;
getrev(l,L);
rep(i,0,len-1) c[i] = a[i];
rep(i,len,l-1) c[i] = 0;
NTT(c,l,1),NTT(b,l,1);
rep(i,0,l-1) b[i] = mul(add(2,mod-mul(c[i],b[i])),b[i]);
NTT(b,l,0);
rep(i,len,l-1) b[i] = 0;
}
int main()
{
n = read(),m = read();
rep(i,0,n) F[i] = read(),Fr[n-i] = F[i];
rep(i,0,m) G[i] = read(),Gr[m-i] = G[i];
rep(i,n-m+2,m) Gr[i] = 0;
getinv(Gr,Gi,n-m+1);
int l = 1,L = 0;
while(l <= (n<<1)) l <<= 1,L++;
getrev(l,L);
NTT(Fr,l,1),NTT(Gi,l,1);
rep(i,0,l-1) Q[i] = mul(Fr[i],Gi[i]);
NTT(Q,l,0);reverse(Q,Q+n-m+1);
rep(i,n-m+1,n) Q[i] = 0;
rep(i,0,n-m) printf("%d ",Q[i]);enter;
l = 1,L = 0;
while(l <= (n << 1)) l <<= 1,L++;
getrev(l,L);
NTT(G,l,1),NTT(Q,l,1);
rep(i,0,l-1) G[i] = mul(G[i],Q[i]);
NTT(G,l,0);
rep(i,0,l-1) R[i] = add(F[i],mod-G[i]);
rep(i,0,m-1) printf("%d ",R[i]);enter;
return 0;
}
7.任意模数NTT。
这个不知道会不会有毒瘤出题人这么出……就是他给你的模数不是NTT模数……
MTT……?我好像不大会。我只会三模NTT,这种做法好像被(Shadowice1984)大佬疯狂批评,不过我暂时还不会别的哎……只好先写这个了。
这其实是比较投机取巧的做法,因为每个数不超过(1e9),所以我们可以选三个大的NTT模数,在每个模数的模意义下算出答案,之后用CRT合并。
咋子合并……假设答案分别为(x_1,x_2,x_3),模数分别为(A,B,C)
就有如下方程:
先合并前两个。
求出(k_1)之后,令(x_4 = x_1 + k_1A),继续合并:
得到(k_4),就知道(x = x_4 + k_4AB (mod ABC)),因为(x < ABC),所以就有(x = x_4 + k_4AB (mod p))
直接算完答案合并就可以。注意超出longlong的时候要用那个神奇的合并方式。
// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
#define enter putchar('
')
#define I inline
using namespace std;
typedef long long ll;
const int M = 800005;
const int INF = 200000000;
const ll mod1 = 998244353,mod2 = 469762049,mod3 = 1004535809;
const ll T = 3;
ll read()
{
ll ans = 0,op = 1;char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
return ans * op;
}
ll n,m,p,F[M],G[M],rev[M],A[M],B[M],C[M],D[M];
ll add(ll a,ll b,ll mod){return (a+b) % mod;}
ll mul(ll a,ll b,ll mod){return ((a*b-(ll)((long double)a/mod*b+1e-8)*mod)+mod)%mod;}
ll qpow(ll a,ll b,ll mod)
{
a %= mod;
ll t = 1;
while(b)
{
if(b&1) t = mul(t,a,mod);
a = mul(a,a,mod),b >>= 1;
}
return t;
}
void getrev(int l,int L){rep(i,0,l-1) rev[i] = (rev[i>>1]>>1) | ((i&1) << (L-1));}
void NTT(ll *a,ll n,ll f,ll mod)
{
rep(i,0,n-1) if(i < rev[i]) swap(a[i],a[rev[i]]);
ll invT = qpow(T,mod-2,mod);
for(int i = 1;i < n;i <<= 1)
{
ll w1 = qpow(f? T : invT,(mod-1) / (i<<1),mod);
for(int j = 0;j < n;j += (i<<1))
{
ll w = 1;
rep(k,0,i-1)
{
ll kx = a[k+j]%mod,ky = mul(a[k+j+i],w,mod);
a[k+j] = add(kx,ky,mod),a[k+j+i] = add(kx,mod-ky,mod),w = mul(w,w1,mod);
}
}
}
if(!f)
{
ll inv = qpow(n,mod-2,mod);
//reverse(a+1,a+n);
rep(i,0,n-1) a[i] = mul(a[i],inv,mod);
}
}
void Polymul(ll *a,ll *b,ll l,ll mod)
{
NTT(a,l,1,mod),NTT(b,l,1,mod);
rep(i,0,l-1) a[i] = mul(a[i],b[i],mod);
NTT(a,l,0,mod);
}
void merge(ll *a,ll *b,ll *c,ll l,ll moda,ll modb,ll modc)
{
ll inv = qpow(moda,modb-2,modb);
rep(i,0,l-1)
{
ll k1 = add(b[i],modb-(a[i]%modb),modb);
k1 = mul(k1,inv,modb);
a[i] = add(a[i],k1*moda,moda*modb);
}
inv = qpow(moda*modb,modc-2,modc);
rep(i,0,l-1)
{
ll k2 = add(c[i],modc-(a[i]%modc),modc);
k2 = mul(k2,inv,modc);
k2 = mul(k2,moda*modb,p);
a[i] = add(a[i],k2,p);
}
}
int main()
{
n = read(),m = read(),p = read();
rep(i,0,n) A[i] = read(),A[i] %= p,C[i] = F[i] = A[i];
rep(i,0,m) B[i] = read(),B[i] %= p,D[i] = G[i] = B[i];
int l = 1,L = 0;
while(l <= n+m) l <<= 1,L++;
getrev(l,L);
Polymul(A,B,l,mod1);
Polymul(C,D,l,mod2);
Polymul(F,G,l,mod3);
merge(A,C,F,n+m+1,mod1,mod2,mod3);
rep(i,0,n+m) printf("%lld ",A[i]);enter;
return 0;
}
8.分治FFT
这玩意还是挺有用的。一般用于:已知(G(x)),求(F[i] = sum_{j=1}^{i}F[i-j]G[j]),F[0] = 1.
这个怎么求呢?直接做应该是不行的……会超时。我们考虑分治,考虑一下式子的左半边计算出结果之后对于右半边的贡献,在计算之前把这些贡献加上即可,也就是一个CDQ分治套FFT。
假设我们已经求出(l-mid)的答案,对于(mid-r)之中的一点x,其所获得的贡献为:(w_x = sum_{i=l}^{mid}f[i]g[x-i])
所以我们做一遍CDQ套FFT就可以解决了。注意边界问题。
#include<bits/stdc++.h>
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
#define enter putchar('
')
#define lowbit(x) (x & (-x))
using namespace std;
typedef long long ll;
const int M = 400005;
const ll mod = 998244353;
const ll G = 3;
const ll invG = 332748118;
ll read()
{
ll ans = 0,op = 1;char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
while(ch >= '0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
return ans * op;
}
ll n,g[M],f[M],r[M],a[M],b[M],rev[M];
ll inc(ll a,ll b) {return (a + b) % mod;}
ll mul(ll a,ll b) {return a * b % mod;}
ll qpow(ll a,ll b)
{
ll p = 1;
while(b)
{
if(b & 1) p = mul(p,a);
a = mul(a,a),b >>= 1;
}
return p;
}
void NTT(ll *a,ll l,ll f)
{
rep(i,0,l-1) if(i < rev[i]) swap(a[i],a[rev[i]]);
for(int i = 1;i < l;i <<= 1)
{
ll w1 = qpow(f ? G : invG,(mod - 1) / (i<<1));
for(int j = 0;j < l;j += (i<<1))
{
ll w = 1;
rep(k,0,i-1)
{
ll kx = a[k+j],ky = mul(a[k+j+i],w);
a[k+j] = inc(kx,ky),a[k+j+i] = inc(kx,mod-ky);
w = mul(w,w1);
}
}
}
if(!f)
{
ll inv = qpow(l,mod-2);
rep(i,0,l-1) a[i] = mul(a[i],inv);
}
}
void CDQ(int kl,int kr)
{
if(kl == kr) return;
int mid = (kl + kr) >> 1;
CDQ(kl,mid);
int l = 1,L = 0;
while(l < (kr - kl + 1) << 1) l <<= 1,L++;
rep(i,0,l-1) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));
rep(i,0,l-1) a[i] = b[i] = 0;
rep(i,kl,mid) a[i-kl] = f[i];
rep(i,0,kr-kl) b[i] = g[i];
NTT(a,l,1),NTT(b,l,1);
rep(i,0,l-1) a[i] = mul(a[i],b[i]);
NTT(a,l,0);
rep(i,mid+1,kr) f[i] = inc(f[i],a[i-kl]);
CDQ(mid+1,kr);
}
int main()
{
n = read(),f[0] = 1;
rep(i,1,n-1) g[i] = read();
CDQ(0,n-1);
rep(i,0,n-1) printf("%lld ",f[i]);enter;
return 0;
}