多项式与点值式
正常( ext{DFT/IDFT})是构造一个特殊的点值式,即(x_i=omega_{n}^i)
如果能通过题目条件构造出来这样的点值,就可以直接( ext{DFT/IDFT})
那如果不能的话。。。。。
多项式多点求值
一个多项式(F(x))我们求它在(x_0,x_0,cdots x_{m-1})上的点值
核心是分治+多项式取模,因此常数很大
对于当前分治区间([l,r]in[0,m-1])
需要快速构造一个长度为(frac{r-l+1}{2})的等价多项式进入分治区间
令(G_{l,r}(x)=prod_l^r(1-x_i))
由于(G_{l,r(x_l)}=cdots=G_{l,r}(x_r)=0)
所以可以将(F(x))对于(G_{l,mid}(x))和(G_{mid+1,r}(x))分别取模之后得到两个等价式
递归到([l=r])时,(F(x))只剩下常数项
需要被访问的(G(x))可以预先跑一遍分治NTT求出
那么复杂度就是(O(nlog ^2n))
这种做法代码实现困难,而且常数非常大
多项式快速插值
对于点对((x_i,y_i))
多项式拉格朗日插值的式子是
那么需要快速求出(prod frac{1}{x_i-x_j})
构造多项式(G(x)=prod (x-x_i))
那么(prod (x_i-x_j)=frac{G}{x-x_i}(x_i))
由于(G(x),x-x_i)在(x_i)上的点值均为(0)
我们要求的多项式就是(egin{aligned} prod_{i e j} (x_i-x_j) end{aligned}=frac{G(x)}{x-x_i})
即求出(frac{G}{x-x_i}(x_i))
分母分子均为(0),所以带入洛必达法则(egin{aligned}frac{G}{x-x_i}(x_i)=frac{G'}{(x-x_i)'}(x_i)=G'(x_i)end{aligned})
那么求出(G'(x)),然后多项式多点求值即可
剩下那一部分的答案,可以简单地分治合并上来,([l=r])时,多项式是一个常数
合并上来时
([l,mid])的答案补上(prod_{mid+1}^r (x-x_i))
([mid+1,r])的答案补上(prod_{l}^{mid} (x-x_i))
即复杂度为(O(nlog ^2n))
垃圾模板题卡常
应用转置原理对于多点求值的优化
由于这个东西实在是太新了,所以没有什么文献可以看
关于转置原理的前置定义
矩阵的转置:
对于(ncdot m)的矩阵(M),它的转置(M^{T})为交换行列坐标后得到的(mcdot n)的矩阵
其满足运算性质:
1.逆: ({(A^T)}^T=A)
2.和:((A+B)^T=A^T+B^T)
3.反积:((AB)^T=B^TA^T)
初等矩阵:
初等矩阵是指单位矩阵通过初等变换(交换行列,某一行(列)乘上(k)加到另一行(列)上,类似高斯消元)得到的矩阵
对于计算(b=Acdot a),其中(A)为矩阵,(a,b)为列向量
考虑先计算(b'=A^Tcdot a)
出计算(b')的过程,这可以分解成若干步操作(或说是初等矩阵)(E_1,E_2,cdots E_k)
即(b'=E_1cdot E_2cdot E_3cdots E_kcdot a)
将(E_i)倒序执行,并且每一步都换成原先操作的转置(E_i^T),就能得到(Acdot a)
即(b=E^T_kcdot E^T_{k-1}cdots E^T_1cdot a)
应用转置原理的优化核心
如果把多项式系数视为列向量(F),则可以把多项式多点求值的过程视为一个矩阵运算(M)
为了便于描述,设要求的点值和多项式项数均为(n)
设要求的点横坐标为(x_i),则(M)是范德蒙德矩阵
(1) | (x_0^1) | (x_0^2) | ... |
---|---|---|---|
1 | (x_1^1) | (x_1^2) | ... |
1 | (x_2^1) | (x_2^2) | ... |
... |
分析会发现我们要求的实际上是(b=Mcdot F)(到底是谁对矩阵乘法有误解?)
现在来将问题转置,先假装求(b'=M^Tcdot F)
(1) | 1 | 1 | ... |
---|---|---|---|
(x_0^1) | (x_1^1) | (x_2^1) | ... |
(x_0^2) | (x_1^2) | (x_2^2) | ... |
... |
实际(M^Tcdot F)得到的结果用形式幂级数表示是
(displaystylesum F_isum_{j=0}^{n-1}x_i^jequiv sum frac{F_i}{1-x_ix}pmod {x^n})
求(displaystyle M^Tcdot F= sum frac{F_i}{1-x_ix}pmod {x^n})
可以用两次分治 ( ext{NTT}) 解决,大致过程可以描述为
1.将问题转化为求$egin{aligned} frac{sum F_iprod _{i e j}{(1-x_jx)}}{prod (1-x_ix)}end{aligned} $
2.对于分治节点([L,R]),求得(T(L,R)=prod_{i=L}^R{(1-x_i)})
3.从下往上合并,每次合并答案为(A(L,R)=A(L,mid)cdot T(mid+1,R)+A(mid+1,R)cdot T(L,mid))
4.最后将答案(A(0,n-1))除以(prod(1-x_ix))
然后我们考虑把所有的操作都反过来并且换成转置,求得(Mcdot F)
因为过程中涉及到多项式卷积,设其转置运算为(oplus)
我们知道普通的多项式卷积为(F(x)cdot G(x)=sum_isum_j [x^i]F(x)[x^j]G(x)x^{i+j})
则其转置为(mul^T(F(x),G(x))=F(x)oplus G(x)=sum_isum_{jleq i} [x^i]F(x)[x^j]G(x)x^{i-j})
可以看到这个操作会导致多项式项数降低,若原先(F(x),G(x))最高项为(n,m),则转置卷积后最高项为(n-m)
那么给出整个转置后的过程为
1.在(F(x))后面加上若干个(0),求出(displaystyle A(0,n-1)=F(x) oplus frac{1}{prod(1-x_ix)})的前(n)项
2.对于每个分治节点,依然预处理(displaystyle T(L,R)=prod_{i=L}^R{(1-x_ix)})
3.从顶向下递归,向子节点下传
(A(L,mid)= A(L,R)oplus T(mid+1,R))
(A(mid+1,R)= A(L,R)oplus T(L,mid))
递归到子节点时,只剩一项,即是每一个点值
关于这个优化的效果:
1.不需要写多项式除法和取模了!
2.第二次分治的过程中调用的(mul^T)长度短一倍
下面这份代码是优化过的版本,能快一倍左右,但关键还是代码短听说可以被卡常卡到1s跑1e6
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
#define pb push_back
#define rep(i,a,b) for(int i=a,i##end=b;i<=i##end;++i)
#define drep(i,a,b) for(int i=a,i##end=b;i>=i##end;--i)
template <class T> inline void cmin(T &a,T b){ ((a>b)&&(a=b)); }
template <class T> inline void cmax(T &a,T b){ ((a<b)&&(a=b)); }
char IO;
template <class T=int> T rd(){
T s=0; int f=0;
while(!isdigit(IO=getchar())) if(IO=='-') f=1;
do s=(s<<1)+(s<<3)+(IO^'0');
while(isdigit(IO=getchar()));
return f?-s:s;
}
const int N=1<<17,P=998244353;
typedef vector <int> V;
int n,m;
ll qpow(ll x,ll k=P-2) {
ll res=1;
for(;k;k>>=1,x=x*x%P) if(k&1) res=res*x%P;
return res;
}
int w[N],Inv[N+1],rev[N];
void Init(){
w[N/2]=1;
for(int t=qpow(3,(P-1)/N),i=N/2+1;i<N;++i) w[i]=1ll*w[i-1]*t%P;
drep(i,N/2-1,1) w[i]=w[i<<1];
Inv[0]=Inv[1]=1;
rep(i,2,N) Inv[i]=1ll*(P-P/i)*Inv[P%i]%P;
}
int Init(int n) {
int R=1,cc=-1;
while(R<n) R<<=1,cc++;
rep(i,1,R-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<cc);
return R;
}
void NTT(int n,V &A,int f){
ull a[N];
if((int)A.size()<n) A.resize(n);
rep(i,0,n-1) a[i]=A[rev[i]];
for(int i=1;i<n;i<<=1) {
int *e=w+i;
for(int l=0;l<n;l+=i*2) {
for(int j=l;j<l+i;++j) {
int t=a[j+i]*e[j-l]%P;
a[j+i]=a[j]+P-t;
a[j]+=t;
}
}
}
rep(i,0,n-1) A[i]=a[i]%P;
if(f==-1) {
reverse(A.begin()+1,A.end());
rep(i,0,n-1) A[i]=1ll*A[i]*Inv[n]%P;
}
}
V operator ~ (V F) {
int n=F.size();
if(n==1) return V{(int)qpow(F[0])};
V G=F; G.resize((n+1)/2),G=~G;
int R=Init(n*2);
NTT(R,F,1),NTT(R,G,1);
rep(i,0,R-1) F[i]=(2-1ll*F[i]*G[i]%P+P)*G[i]%P;
NTT(R,F,-1),F.resize(n);
return F;
}
V operator * (V A,V B) {
int n=A.size()+B.size()-1,R=Init(n);
NTT(R,A,1),NTT(R,B,1);
rep(i,0,R-1) A[i]=1ll*A[i]*B[i]%P;
NTT(R,A,-1),A.resize(n);
return A;
}
V Evaluate(V F,V X){
static int ls[N<<1],rs[N<<1],cnt;
static V T[N<<1];
static auto TMul=[&](V F,V G){
reverse(G.begin(),G.end());
int n=F.size(),m=G.size(),R=Init(n);
NTT(R,F,1),NTT(R,G,1);
rep(i,0,R-1) F[i]=1ll*F[i]*G[i]%P;
NTT(R,F,-1); V T(n-m+1);
rep(i,0,n-m) T[i]=F[i+m-1];
return T;
};
static function <int(int,int)> Build=[&](int l,int r) {
int u=++cnt; ls[u]=rs[u]=0;
if(l==r) {
T[u]=V{1,P-X[l]};
return u;
}
int mid=(l+r)>>1;
ls[u]=Build(l,mid),rs[u]=Build(mid+1,r);
T[u]=T[ls[u]]*T[rs[u]];
return u;
};
int n=F.size(),m=X.size();
cmax(n,m),F.resize(n),X.resize(n);
cnt=0,Build(0,n-1);
F.resize(n*2+1),T[1]=TMul(F,~T[1]);
int p=0;
rep(i,1,cnt) if(ls[i]) {
swap(T[ls[i]],T[rs[i]]);
int R=Init(T[i].size()),n=T[i].size(),m1=T[ls[i]].size(),m2=T[rs[i]].size();
NTT(R,T[i],1);
reverse(T[ls[i]].begin(),T[ls[i]].end()); reverse(T[rs[i]].begin(),T[rs[i]].end());
NTT(R,T[ls[i]],1); NTT(R,T[rs[i]],1);
rep(j,0,R-1) {
T[ls[i]][j]=1ll*T[ls[i]][j]*T[i][j]%P;
T[rs[i]][j]=1ll*T[rs[i]][j]*T[i][j]%P;
}
NTT(R,T[ls[i]],-1); NTT(R,T[rs[i]],-1);
rep(j,0,n-m1) T[ls[i]][j]=T[ls[i]][j+m1-1];
T[ls[i]].resize(n-m1+1);
rep(j,0,n-m2) T[rs[i]][j]=T[rs[i]][j+m2-1];
T[rs[i]].resize(n-m2+1);
//T[ls[i]]=TMul(T[i],T[ls[i]]); T[rs[i]]=TMul(T[i],T[rs[i]]);
} else X[p++]=T[i][0];
X.resize(m);
return X;
}
int main(){
Init(),n=rd(),m=rd();
V F(n+1),X(m);
rep(i,0,n) F[i]=rd();
rep(i,0,m-1) X[i]=rd();
V Res=Evaluate(F,X);
for(int i:Res) printf("%d
",i);
}