首先涂区间,那么区间最多有 $2n$ 个相邻位置不同的情况,并且连续相同的颜色可以合并起来
那么这样操作完以后,区间长度最多为 $2n$
发现涂完一段区间以后其他的操作都不能出现一边在区间内而另一边在区间外的情况
又因为区间长度 $n<=1000$ ,时间 $6$ 秒,考虑一下不满的 $n^3$ 的区间 $dp$
设 $f[i][j]$ 表示把区间 $[i,j]$ 涂成最终状态的方案数
设 $mi[i][j]$ 表示区间 $[i,j]$ 内最小的颜色编号,$L[i]$ 表示颜色 $i$ 最左边的位置,$R[i]$ 表示颜色 $i$ 最右边的位置
设 $p=mi[i][j]$
那么枚举最小的颜色涂的区间为 $[l,r]$,显然 $l<=L[p],r>=R[p]$,发现 $l,r$ 把区间 $i,j$ 分成了 $4$ 个部分:$[i,l-1],[l,L[p]-1],[R[p]+1,r],[r,j]$
哦,对了,还有 $[L[p],R[p]]$ 中间的几个部分,中间这一段被颜色 $p$ 分成了很多块,每一块内部也是独立的,设中间这些块的方案数为 $sum$
有 $f[i][j]=sum_{l=i}^{L[p]}sum_{r=R[p]}^{j}f[i][l-1]f[l][L[p]-1]f[R[p]+1][r]f[r+1][j] cdot sum$
然后因为 $l,r$ 是独立的,所以分别计算即可做到 $n^3$ ,求 $sum$ 也只要预处理一下 $nxt[i]$ 表示位置 $i$ 下一个同颜色的位置即可
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<cmath> using namespace std; typedef long long ll; inline int read() { int x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*f; } const int N=1007,M=2e6+7,mo=998244353; inline int fk(int x) { return x>=mo ? x-mo : x; } int n,m,a[M],b[M],tot; int L[N],R[N],mi[N][N],f[N][N]; int pre[N],nxt[N]; int main() { n=read(),m=read(); for(int i=1;i<=m;i++) a[i]=read(); for(int i=1;i<=m;i++) if(a[i]!=a[i-1]) b[++tot]=a[i]; if(tot>n*2) { printf("0 "); return 0; } m=tot; for(int i=1;i<=m;i++) a[i]=b[i]; for(int i=1;i<=m;i++) { if(pre[a[i]]) nxt[ pre[a[i]] ]=i; pre[a[i]]=i; } for(int i=1;i<=n;i++) nxt[ pre[a[i]] ]=m+1; memset(L,0x3f,sizeof(L)); memset(mi,0x3f,sizeof(mi)); for(int i=1;i<=m;i++) L[a[i]]=min(L[a[i]],i),R[a[i]]=max(R[a[i]],i); for(int i=1;i<=m;i++) for(int j=i;j<=m;j++) for(int k=i;k<=j;k++) mi[i][j]=min(mi[i][j],a[k]); for(int i=0;i<=m+1;i++) { if( i>=1&&i<=m && L[mi[i][i]]==i && R[mi[i][i]]==i ) f[i][i]=1; for(int j=i+1;j<=m+1;j++) f[j][i]=1; for(int j=0;j<=i-1;j++) f[i][j]=1; } for(int k=1;k<m;k++) for(int i=1;i+k<=m;i++) { int p=mi[i][i+k]; if(p>N) continue; if(L[p]<i||R[p]>i+k) continue; int cntl=0,cntr=0,t=1; for(int j=i;j<=L[p];j++) cntl=fk(cntl+1ll*f[i][j-1]*f[j][L[p]-1]%mo); for(int j=R[p];j<=i+k;j++) cntr=fk(cntr+1ll*f[R[p]+1][j]*f[j+1][i+k]%mo); for(int j=L[p];j<R[p];j=nxt[j]) t=1ll*t*f[j+1][nxt[j]-1]%mo; f[i][i+k]=1ll*cntl*cntr%mo*t%mo; // cout<<i<<" "<<i+k<<" "<<f[i][i+k]<<endl; } printf("%d ",f[1][m]); return 0; }