以下只讨论这种一般的"我卷我自己"形式:
如果式子是(f_n=sumlimits_{i=1}^{n-1}f_ig_{n-i})的话,就是普通的分治FFT了。
普通分治FFT做法是:solve(l,r)时,先行递归计算左半区间solve(l,mid),得到(f_{l,,,mid})的准确值之后,卷上(g_{1,,,r-l}),这样把[l,mid]的权值贡献给右半区间[mid+1,r],然后递归计算右半区间solve(mid+1,r)。
我们把这个贡献方式记为:[l,mid]*[1,r-l]贡献给右区间。
在这里,(g_{n-i})都是预先给出的。但是如果是我卷我自己形式的话,(f_{n-i})也是要求在计算过程中求出的。
那么就有可能会碰到一个问题,当递归完左半区间打算计算其对右半区间的贡献时,我们可能并没有完全得到(f_{1,,,r-l})的准确值。
这里就需要根据(l)的取值分成两种情况讨论:
-
(l eq 1)
此时的(l)一定是形如([(mid=(1+r)/2)+1,r])的区间,即([1,r])的右半区间,或者从这类区间递归而来。递归过程中要么(l)增大,要么(r)减小,所以一定会有(lge r-l)成立。
这意味着solve(l,mid)之后(f_{l,,,mid})和(f_{1,,,r-l})的准确值都是已经求好了的。
那么这种情况下贡献直接照普通分治FFT的方式计算即可。
这里还有一个性质是[l,mid]和[1,r-l]无重叠部分,端点忽略不计(???)
-
(l=1)
这时(1ge r-1)是不一定成立的。
把(f_{1,,,r-1})拆成两部分看:
-
(f_{1,,,mid})部分。这一部分是在solve(1,mid)的时候已经算出来了的,所以同样可以直接按照普通分治FFT做法,(f_{1,mid})和(f_{1,mid})卷一下贡献给右半区间。
-
(f_{mid+1,,,r-1})部分(这里就是与普通FFT不同之处)。因为刚递归计算完[1,mid],所以这一部分当然是还没计算出准确值的,值都是0。此时[1,mid]*[mid+1,r-1]这部分就无法贡献给右半区间。
但是注意到[mid+1,r-1]区间,其左端点是不等于1的。那么在计算完[mid+1,r-1]的准确值时,我们肯定已经知道[1,mid]的准确值,所以可以在那时再把贡献算上。
即本来是要在solve(1,mid)时[1,mid]*[mid+1,r-1]贡献给右半区间,现在变成solve(mid+1,r-1)时[1,mid]*[mid+1,r-1]贡献给右区间。这里[1,mid]*[mid+1,r-1]恰好就是solve(mid+1,r-1)时本来要计算的(case1中)[mid+1,r-1]*[1,mid] swap一下的样子。
-
总结一下,具体流程就是:
首先递归solve(l,mid)计算完[l,mid]的准确值之后,
若(l=1),则只做[1,mid]*[1,mid]的贡献。
若(l eq 1),先做[l,mid]*[1,r-l]的贡献,然后再做[1,r-l]*[l,mid]的贡献以补偿在(l=1)时的未计算量。
最后递归solve(mid+1,r)。
一道例题:
LOJ2554 「CTSC2018」青蕈领主
Code
#include<bits/stdc++.h>
using namespace std;
#define REP(i,a,b) for(int i=(a),_ed=(b);i<=_ed;++i)
#define DREP(i,a,b) for(int i=(a),_ed=(b);i>=_ed;--i)
#define mp(x,y) make_pair((x),(y))
#define sz(x) (int)(x).size()
#define pb push_back
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
inline int read(){
register int x=0,f=1;register char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
while(isdigit(ch)){x=x*10+(ch^'0');ch=getchar();}
return f?x:-x;
}
const int P=998244353;
int n,L[50005],f[140005];
inline int power(int b,int n){int ans=1;for(;n;n>>=1,b=1ll*b*b%P)if(n&1)ans=1ll*ans*b%P;return ans;}
inline void inc(int& x,int y){x=x+y<P?x+y:x+y-P;}
inline void dec(int& x,int y){x=x-y>=0?x-y:x-y+P;}
int trans[140005],w[140005];
void NTT(int* a,int n){
static ull f[140005];
REP(i,0,n-1)f[i]=a[i];
REP(i,0,n-1)if(i<trans[i])swap(f[i],f[trans[i]]);
for(int len=2,d=1;len<=n;d=len,len<<=1)
for(int p=0;p<n;p+=len)
for(int i=p;i<p+d;++i){
int t=f[i+d]*w[d+i-p]%P;
f[i+d]=f[i]+P-t,f[i]+=t;
}
REP(i,0,n-1)a[i]=f[i]%P;
}
void times(int* f,int* a,int m1,int m2,int lim){
static int g[140005];
int n=1;for(;n<m1+m2-1;n<<=1);
REP(i,0,n-1)trans[i]=(trans[i>>1]>>1)|(i&1?(n>>1):0);
for(int len=2,d=1;len<=n;d=len,len<<=1){
int e=power(3,(P-1)/len);
REP(i,w[d]=1,d-1)w[d+i]=1ll*w[d+i-1]*e%P;
}
REP(i,m1,n-1)f[i]=0;REP(i,0,m2-1)g[i]=a[i];REP(i,m2,n-1)g[i]=0;
NTT(f,n),NTT(g,n);
REP(i,0,n-1)f[i]=1ll*f[i]*g[i]%P;
NTT(f,n);int inv=power(n,P-2);
reverse(f+1,f+n);
REP(i,0,lim-1)f[i]=1ll*f[i]*inv%P;
REP(i,lim,n-1)f[i]=0;
}
void solve(int l,int r){
static int A[140005],B[140005];
if(l==r)return inc(f[l],1ll*(l-1)*f[l-1]%P);
int mid=(l+r)>>1;
solve(l,mid);
if(l==1){
A[0]=A[1]=B[0]=B[1]=0;
REP(i,2,mid)A[i]=1ll*f[i]*(i-1)%P,B[i]=f[i];
times(A,B,mid+1,mid+1,r+1);
REP(i,mid+1,r)inc(f[i],A[i]);
}
else{
REP(i,l,mid)A[i-l]=1ll*(i-1)*f[i]%P;
B[0]=B[1]=0;
REP(i,2,r-l)B[i]=f[i];
times(A,B,mid-l+1,r-l+1,r-l+1);
REP(i,mid+1,r)inc(f[i],A[i-l]);
A[0]=A[1]=0;
REP(i,2,r-l)A[i]=1ll*(i-1)*f[i]%P;
REP(i,l,mid)B[i-l]=f[i];
times(A,B,r-l+1,mid-l+1,r-l+1);
REP(i,mid+1,r)inc(f[i],A[i-l]);
}
solve(mid+1,r);
}
int cal(int l,int r){
if(L[r]!=r-l+1)return 0;
int p=r-1,res=1,num=1;
while(p>=l)res=1ll*res*cal(p-L[p]+1,p)%P,p-=L[p],++num;
res=1ll*res*f[num-1]%P;
if(p<l-1)res=0;
return res;
}
int main(){
// freopen("in.in","r",stdin);
int T=read();n=read();
f[0]=1,f[1]=2;
solve(1,n-1);
REP(t,1,T){
REP(i,1,n)L[i]=read();
printf("%d
",cal(1,n));
}
return 0;
}