Loj #2554. 「CTSC2018」青蕈领主
题目描述
“也许,我的生命也已经如同风中残烛了吧。”小绿如是说。
小绿同学因为微积分这门课,对“连续”这一概念产生了浓厚的兴趣。小绿打算把连续的概念放到由整数构成的序列上,他定义一个长度为 (m) 的整数序列是连续的,当且仅当这个序列中的最大值与最小值的差,不超过(m-1)。例如 ({1,3,2}) 是连续的,而 ({1,3}) 不是连续的。
某天,小绿的顶头上司板老大,给了小绿 (T) 个长度为 (n) 的排列。小绿拿到之后十分欢喜,他求出了每个排列的每个区间是否是他所定义的“连续”的。然而,小绿觉得被别的“连续”区间包含住的“连续”区间不够优秀,于是对于每个排列的所有右端点相同的“连续”区间,他只记录下了长度最长的那个“连续”区间的长度。也就是说,对于板老大给他的每一个排列,他都只记录下了在这个排列中,对于每一个 (1 le i le n),右端点为 (i) 的最长“连续”区间的长度 (L_i)。显然这个长度最少为 (1),因为所有长度为 (1) 的整数序列都是连续的。
做完这一切后,小绿爬上绿色床,美美地做了一个绿色的梦。
可是第二天醒来之后,小绿惊讶的发现板老大给他的所有排列都不见了,只剩下他记录下来的 (T) 组信息。小绿知道自己在劫难逃,但是作为一个好奇的青年,他还是想知道:对于每一组信息,有多少个和信息符合的长度为 (n) 的排列。
由于小绿已经放弃治疗了,你只需要告诉他每一个答案对 (998244353) 取模的结果。
我们并不保证一定存在至少一个符合信息的排列,因为小绿也是人,他也有可能犯错。
输入格式
输入的第一行包含两个整数 (T,n),分别表示板老大给小绿的排列个数、以及每个排列的长度。
接下来 (T) 行,每行描述一组信息,包含 (n) 个正整数,第 (i) 组信息的从左往右第 (j) 个整数 (L_{i,j}) 表示第 (i) 个排列中右端点为第 (j) 个数的最长“连续”区间的长度。
对于每一行,如果行内包含多个数,则用单个空格将它们隔开。
输出格式
对于每组信息,输出一行一个整数表示可能的排列个数对 (998244353) 取模的结果。由于是计算机帮你算,所以我们不给你犯错的机会。
数据范围与提示
对于所有测试数据,(1 le T le 100),(1 le N le 50000), (1 le L_{i,j} le j)。
首先我们得到很多段区间,这些区间要么相离,要么包含,并且一定有一个([1,n])。就用这两个条件来判断无解。然后我们可以将这些区间建成一个树。
考虑区间([l,r]),他有(k)个儿子,于是我们要将一段长为(r-l+1)的连续区间分配个这(k)个儿子。首先这(k)个儿子每一个都是连续的一段,并且相邻的儿子不能组成连续的一段。但是考虑放在(r)位置上的那个点,他与所有儿子共同组成了连续的一段。如果说我们把每个儿子看做一个点而,再把(r)也看做一个点,那么合法条件就是:一个(k+1)的排列,不能存在不包含最后一个位置的长度(>1)的连续区间。
设该答案为(f_k),那么:
特别地,(f_0=1,f_1=2)。
我们设一个合法数列为(A),再令(b_{a_i}=i)。很容易发现(A)与(B)是唯一映射的。(A)的合法条件在(B)数组中等价为不能存在不包含最大那个元素的长度(geq 2)的连续区间。
考虑从大到小插入每一个数。已经插入了([2,n+1]),现在要插入(1)。如果原来的序列已经合法,那么(1)只要不与(2)相邻,这个数列依旧合法。这样就还有(n-1)个位置可以插入,所以有((n-1)f_{n-1})。如果原来的数列不合法,那么我们要插入(1)破坏那个长度(>1)的连续区间。很显然,不相交的连续区间最多有一个,不然插入一个(1)解决不了问题。于是我们枚举最大的那个非法区间的长度,设其为(j),则(2leq jleq n-2)(首先至少有(2)两个元素,并且不能让最大的那个元素单独存在,否则就不是最长的了)。假设非法区间的元素为([xldots x+j-1]),那么(x)有([2,n-j])这(n-j-1)种方案,所以乘上系数(n-1-j)。考虑将(1)插入其中,等价于将(j+1)插入([1,j])中,所以合法方案数为(f_j)。考虑这段连续区间(插入(1)之前)必须是极长,所以如果将这段连续区间视为一个点,那么就有还有(n-j)个元素,合法的排列方案数是(f_{n-j})。这部分贡献为
发现这个方程可以用分治(FFT)优化,不过这个写法有点巧妙。
考虑分治(FFT)的原理是递归区间([l,r])的时候将区间分为了两半,考虑计算左半边对右半边的贡献。于是我们发现,计算([l,mid])对([mid+1,r])的贡献时,我们要用到(f_{2ldots r-l}),但是可能(mid<r-l)。我们考虑(i,j(i<j)),显然只会在递归到某一个区间([l,r])的时候才会计算(i)对(j)的贡献。假设这个区间是([l,mid,r]),如果(lleq r-lleq mid),那么用(f_{lldots mid})自己卷自己就行了。如果(r-l<l),那么我们用(f_{lldots mid})卷(f_{2ldots min{r-l,l-1}})就好了。因为:
所以对于计算了(r-l<l)的部分后乘上(i-2)就可以一并计算(r-l>mid)的部分了。
代码:
#include<bits/stdc++.h>
#define ll long long
#define N 50005
using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
const ll mod=998244353;
ll ksm(ll t,ll x) {
ll ans=1;
for(;x;x>>=1,t=t*t%mod)
if(x&1) ans=ans*t%mod;
return ans;
}
int n;
int p[N];
void NTT(ll *a,int d,int flag) {
static int rev[N<<2];
static ll G=3;
int n=1<<d;
for(int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<d-1);
for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int s=1;s<=d;s++) {
int len=1<<s,mid=len>>1;
ll w=flag==1?ksm(G,(mod-1)/len):ksm(G,mod-1-(mod-1)/len);
for(int i=0;i<n;i+=len) {
ll t=1;
for(int j=0;j<mid;j++,t=t*w%mod) {
ll u=a[i+j],v=a[i+j+mid]*t%mod;
a[i+j]=(u+v)%mod;
a[i+j+mid]=(u-v+mod)%mod;
}
}
}
if(flag==-1) {
ll inv=ksm(n,mod-2);
for(int i=0;i<n;i++) a[i]=a[i]*inv%mod;
}
}
ll f[N];
ll A[N<<2],B[N<<2];
void solve(int l,int r) {
if(l==r) {
(f[l]+=f[l-1]*(l-1))%=mod;
return ;
}
int mid=l+r>>1;
int d=ceil(log2(r-l+1));
solve(l,mid);
for(int i=0;i<1<<d;i++) A[i]=B[i]=0;
for(int i=l;i<=mid;i++) {
A[i-l]=(i-1)*f[i]%mod;
B[i-l]=f[i];
}
NTT(A,d,1),NTT(B,d,1);
for(int i=0;i<1<<d;i++) A[i]=A[i]*B[i]%mod;
NTT(A,d,-1);
for(int i=0;i<1<<d;i++)
if(mid<i+2*l&&i+2*l<=r) (f[i+2*l]+=A[i])%=mod;
int len=min(r-l,l-1);
d=ceil(log2(len+mid-l+1));
for(int i=0;i<1<<d;i++) A[i]=B[i]=0;
for(int i=l;i<=mid;i++) A[i-l]=f[i];
for(int i=2;i<=len;i++) B[i-2]=f[i];
NTT(A,d,1),NTT(B,d,1);
for(int i=0;i<1<<d;i++) A[i]=A[i]*B[i]%mod;
NTT(A,d,-1);
for(int i=0;i<1<<d;i++)
if(mid<i+2+l&&i+2+l<=r) (f[i+2+l]+=(i+l)*A[i])%=mod;
solve(mid+1,r);
}
int st[N],top;
int sn[N];
int main() {
int T=Get();
n=Get();
f[0]=1,f[1]=2;
if(n-1>2) solve(2,n-1);
while(T--) {
for(int i=1;i<=n;i++) p[i]=i-Get()+1;
for(int i=1;i<=n;i++) sn[i]=0;
if(p[n]!=1) {
cout<<0<<"
";
continue ;
}
int flag=0;
st[top=1]=n;
for(int i=n-1;i>=1;i--) {
while(p[st[top]]>i) top--;
sn[st[top]]++;
if(p[st[top]]>p[i]) {
flag=1;
break;
}
st[++top]=i;
}
if(flag) {
cout<<0<<"
";
} else {
ll ans=1;
for(int i=1;i<=n;i++) ans=ans*f[sn[i]]%mod;
cout<<ans<<"
";
}
}
return 0;
}