Bandit Blues
题目描述
解法
前置知识:长度为 \(n\) 的序列,前缀最大值个数为 \(m\) 的方案数是 \({n\brack m}\),因为它完全等价于把 \(n\) 个数划分成 \(m\) 个圆排列,每个圆排列以其最大值为代表,按最大值从小到大的顺序在原序列上排列。
回到本题,我们可以枚举 \(n\) 的位置,这样就能把限制拆解为:前 \(i-1\) 个位置前缀最大值是 \(a-1\),后 \(n-i-1\) 个位置后缀最大值是 \(b-1\)(翻转一下就变成了前缀最大值),那么答案是:
考虑用组合意义化简这个式子,发现这等价于:在长度为 \(n-1\) 的序列中划分出 \(a+b-2\) 个圆排列,再从 \(a+b-2\) 个圆排列中选取 \(a-1\) 个放在左边,\(b-1\) 个放在右边,那么答案是:
现在的问题变成了求单个第一类斯特林数,回忆其递推式:\({n\brack m}={n-1\brack m-1}+(n-1){n-1\brack m}\),只需要考虑转移路径,就发现 \({n\brack m}\) 可以用这样的生成函数表示出来:
直接分治 \(\tt NTT\) 即可,时间复杂度 \(O(n\log^2 n)\)
彩蛋
在写完这道题之后,我自信地猜测这道题的评分是 \(2900\),结果打开一看真是 \(2900\),感觉 \(\tt CF\) 题做多了都可以自己评难度了~
#include <cstdio>
#include <iostream>
using namespace std;
const int M = 400005;
#define int long long
const int MOD = 998244353;
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,a,b,fac[M],inv[M],s[M],rev[M];
void init()
{
fac[0]=inv[0]=inv[1]=1;
for(int i=2;i<=n;i++) inv[i]=inv[MOD%i]*(MOD-MOD/i)%MOD;
for(int i=1;i<=n;i++) inv[i]=inv[i-1]*inv[i]%MOD;
for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%MOD;
}
int C(int n,int m)
{
if(n<m || m<0) return 0;
return fac[n]*inv[m]%MOD*inv[n-m]%MOD;
}
int qkpow(int a,int b)
{
int r=1;
while(b>0)
{
if(b&1) r=r*a%MOD;
a=a*a%MOD;
b>>=1;
}
return r;
}
void NTT(int *a,int len,int op)
{
for(int i=0;i<len;i++)
{
rev[i]=(rev[i>>1]>>1)|((len/2)*(i&1));
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int s=2;s<=len;s<<=1)
{
int w=(op==1)?qkpow(3,(MOD-1)/s)
:qkpow(3,MOD-1-(MOD-1)/s),t=s/2;
for(int i=0;i<len;i+=s)
for(int j=0,x=1;j<t;j++,x=x*w%MOD)
{
int fe=a[i+j],fo=a[i+j+t];
a[i+j]=(fe+x*fo)%MOD;
a[i+j+t]=(fe-x*fo%MOD+MOD)%MOD;
}
}
if(op==1) return ;
int inv=qkpow(len,MOD-2);
for(int i=0;i<len;i++) a[i]=a[i]*inv%MOD;
}
void solve(int *a,int l,int r)
{
if(l==r) {a[0]=l;a[1]=1;return ;}
int mid=(l+r)>>1,len=1,zz=r-l+1;
int A[zz<<2]={},B[zz<<2]={};
solve(A,l,mid);solve(B,mid+1,r);
while(len<=zz) len<<=1;
NTT(A,len,1);NTT(B,len,1);
for(int i=0;i<len;i++) A[i]=A[i]*B[i]%MOD;
NTT(A,len,-1);
for(int i=0;i<=zz;i++) a[i]=A[i];
}
signed main()
{
n=read();a=read();b=read();init();
if(!a || !b || a+b>n+1) {puts("0");return 0;}
if(n==1) {puts("1");return 0;}
solve(s,0,n-2);
printf("%lld\n",s[a+b-2]*C(a+b-2,a-1)%MOD);
}
Perpetual Subtraction
题目描述
解法
幸好神 \(\tt OUYE\) 给我讲了一个多小时,要不然我根本不可能理解这东西。我尽量写一篇通俗易懂的题解,来帮助像我一样在线代方面没有任何基础的人吧。
\(\tt Warning\):请确保你先了解特征值、特征向量、基向量等基础概念再来阅读本文。
首先重点介绍一下对角化,对角化的核心是基变换。基变换的定义是:把 \(n\times n\) 矩阵 \(A\) 看成 \(n\) 个向量,变换这些向量所对应的基向量,得到新的矩阵 \(A'\),这个过程称为基变换。
原来的基是 \((1,0...0),(0,1...0),(0,0...1)\),它们对应单位矩阵 \(I\);现在我们把基变成 \(E^{(1)},E^{(2)}...E^{(n)}\),对应某个矩阵 \(E\);那么原来的向量 \(v\),在改变了基向量之后,就变成了 \(E^{-1}v\)
设矩阵 \(A\) 在改变基向量之后会变成 \(A'\),有这样一个等式:\(Av=EA'v'\),\(A'v'\) 的意思就是变换之后再做乘法,右乘上 \(E\) 之后就可以得到原来的结果 \(Av\),根据这个等式进一步推导:
研究清楚基变换以后,对角化就是选取矩阵 \(A\) 的 \(n\) 个特征向量组成 \(E\),那么 \(AE\) 的结果势必是对于 \(E\) 每一行的拉伸(即把 \(E\) 的每一行分别乘上一个数),所以 \(E^{-1}AE\) 的结果是对于 \(I\) 每一行的拉伸,这说明 \(A'\) 是一个对角矩阵。
总结一句:对角化的原理就是,通过选取特征向量为基向量,把 \(A\) 基变换成一个对角矩阵 \(A'\)。
本题就是对角化的经典应用,首先写出用来矩阵加速的矩阵 \(A\):
设输入的列向量是 \(p\),我们要求 \(A^mp\),核心就是求解 \(A^m\),首先改一个形式:
其中 \(E^{-1}AE\) 就是 \(A'\),对角线上的值是 \(A\) 的特征值,因为 \(det(A-\lambda I)=0\),那么 \(A'\) 就是 \(A\) 对角线上的值:
因为 \((A-\lambda I)\vec x=0\),可以手算几个特征向量 \(\vec x\) 找规律,那么可以得到 \(E\):
可以通过矩阵求逆得到 \(E^{-1}\)(如果去解方程,本质是二项式反演):
形式化地:\(E(i,j)=(-1)^{i+j}{i\choose j}\),\(E^{-1}(i,j)={i\choose j}\)
最后我们需要计算 \(E(E^{-1}AE)^mE^{-1}p\),从右往左一个个乘上去即可,把式子写出来发现是差卷积的形式,所以一发 \(\tt NTT\) 解决问题,时间复杂度 \(O(n\log n)\)
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int M = 400005;
#define int long long
const int MOD = 998244353;
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,m,fac[M],inv[M],f[M],g[M],rev[M];
void init()
{
fac[0]=inv[0]=inv[1]=1;
for(int i=2;i<=n;i++) inv[i]=inv[MOD%i]*(MOD-MOD/i)%MOD;
for(int i=1;i<=n;i++) inv[i]=inv[i-1]*inv[i]%MOD;
for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%MOD;
}
int qkpow(int a,int b)
{
int r=1;
while(b>0)
{
if(b&1) r=r*a%MOD;
a=a*a%MOD;
b>>=1;
}
return r;
}
void NTT(int *a,int len,int op)
{
for(int i=0;i<len;i++)
{
rev[i]=(rev[i>>1]>>1)|((len/2)*(i&1));
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int s=2;s<=len;s<<=1)
{
int w=(op==1)?qkpow(3,(MOD-1)/s)
:qkpow(3,MOD-1-(MOD-1)/s),t=s/2;
for(int i=0;i<len;i+=s)
for(int j=0,x=1;j<t;j++,x=x*w%MOD)
{
int fe=a[i+j],fo=a[i+j+t];
a[i+j]=(fe+x*fo)%MOD;
a[i+j+t]=(fe-x*fo%MOD+MOD)%MOD;
}
}
if(op==1) return ;
int inv=qkpow(len,MOD-2);
for(int i=0;i<len;i++) a[i]=a[i]*inv%MOD;
}
signed main()
{
n=read()+1;m=read();init();
//
for(int i=0;i<n;i++) f[i]=read()*fac[i]%MOD;
for(int i=0;i<n;i++) g[i]=inv[i];
reverse(f,f+n);
int len=1;while(len<=2*n) len<<=1;
NTT(f,len,1);NTT(g,len,1);
for(int i=0;i<len;i++) f[i]=f[i]*g[i]%MOD;
NTT(f,len,-1);
for(int i=n;i<len;i++) f[i]=g[i]=0;
reverse(f,f+n);
//
for(int i=0;i<n;i++)
{
f[i]=f[i]*qkpow(qkpow(i+1,MOD-2),m)%MOD;
g[i]=(i&1)?MOD-inv[i]:inv[i];
}
reverse(f,f+n);
NTT(f,len,1);NTT(g,len,1);
for(int i=0;i<len;i++) f[i]=f[i]*g[i]%MOD;
NTT(f,len,-1);
for(int i=n;i<len;i++) f[i]=g[i]=0;
reverse(f,f+n);
//
for(int i=0;i<n;i++)
printf("%lld ",f[i]*inv[i]%MOD);
}
stairs
题目描述
解法
本来看不懂这题题解直接就扔了,结果 \(\tt imzzy\) 硬是给我讲懂了,张教主 \(\tt yyds\)!
首先简化问题,我们考虑数组 \(a\) 的一个极长相同的连续段 \([l,r]\),需要满足 \((r-l+1)\bmod a_l=0\) 才有解。并且通过这个可以得到一个长为 \(m\) 的数组 \(b\),其中 \(b_i\) 表示第 \(i\) 段连续段的长度,发现计数只需要知道:
- 每一个连续段是从大到小,还是从小到大。特别地,如果 \(b_i=1\) 那么只有一种方案。
- 连续段之间的相对关系,可以以连续段中最小的数为代表,用一个 \(1,2...m\) 的排列来描述这个关系。
计数有一个很烦的限制:相邻两个连续段不能拼接成一个大的连续段。可以容斥,设最后的连续段数量是 \(i\),那么容斥系数是 \((-1)^{m-i}\),设 \(f_i\) 表示最后剩下 \(i\) 个连续段的方案数,答案可以写成:
\(f_i\) 的计算是平凡的,因为根据我们的构造方案,两个 \(b\) 数组中的相邻连续段其实可以随意拼接。计算可以考虑分治 \(\tt NTT\),为了解决长度为 \(1\) 的连续段只有 \(1\) 种方案的特殊情况,我们记 \(dp_{a,b,x}\) 表示最左边的大连续段长度 =1/>1
、最右边的大连续段长度 =1/>1
、大连续段的总数是 \(x\),的方案数是多少。
暴力合并左右部分,内层是一个卷积的形式,代码实现没有任何细节,时间复杂度 \(O(n\log ^2n)\)
#include <cstdio>
#include <vector>
#include <iostream>
using namespace std;
const int M = 400005;
#define int long long
const int MOD = 998244353;
const int inv2 = (MOD+1)/2;
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,m,a[M],b[M],rev[M],A[M],B[M],fac[M];
struct node
{
vector<int> v[2][2];
void resize(int x)
{
for(int i=0;i<2;i++) for(int j=0;j<2;j++)
v[i][j].resize(x+1);
}
vector<int>* operator [] (int x) {return v[x];}
};
int qkpow(int a,int b)
{
int r=1;
while(b>0)
{
if(b&1) r=r*a%MOD;
a=a*a%MOD;
b>>=1;
}
return r;
}
void NTT(int *a,int len,int op)
{
for(int i=0;i<len;i++)
{
rev[i]=(rev[i>>1]>>1)|((len/2)*(i&1));
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int s=2;s<=len;s<<=1)
{
int w=(op==1)?qkpow(3,(MOD-1)/s)
:qkpow(3,MOD-1-(MOD-1)/s),t=s/2;
for(int i=0;i<len;i+=s)
for(int j=0,x=1;j<t;j++,x=x*w%MOD)
{
int fe=a[i+j],fo=a[i+j+t];
a[i+j]=(fe+x*fo)%MOD;
a[i+j+t]=(fe-x*fo%MOD+MOD)%MOD;
}
}
if(op==1) return ;
int inv=qkpow(len,MOD-2);
for(int i=0;i<len;i++) a[i]=a[i]*inv%MOD;
}
void add(int &x,int y) {x=(x+y)%MOD;}
node solve(int l,int r)
{
node s;s.resize(r-l+1);
if(r-l<=2)
{
if(r-l==1)
{
s[b[l]][b[r]][2]=(b[l]+1)*(b[r]+1);
s[1][1][1]=2;
}
if(r-l==2)
{
s[b[l]][b[r]][3]=(b[l]+1)*(b[l+1]+1)*(b[r]+1);
s[b[l]][1][2]=2*(b[l]+1);s[1][b[r]][2]+=2*(b[r]+1);
s[1][1][1]=2;
}
return s;
}
int mid=(l+r)>>1,len=1;
node X=solve(l,mid),Y=solve(mid+1,r);
while(len<=r-l+2) len<<=1;
for(int i=0;i<2;i++) for(int j=0;j<2;j++)
for(int x=0;x<2;x++) for(int y=0;y<2;y++)
{
vector<int> t;t.resize(r-l+2);
for(int p=0;p<len;p++) A[p]=B[p]=0;
for(int p=0;p<=mid-l+1;p++) A[p]=X[i][x][p];
for(int p=0;p<=r-mid;p++) B[p]=Y[y][j][p];
NTT(A,len,1);NTT(B,len,1);
for(int p=0;p<len;p++) A[p]=A[p]*B[p]%MOD;
NTT(A,len,-1);
for(int p=0;p<=r-l+1;p++) t[p]=A[p];
//
for(int p=0;p<=r-l+1;p++)
add(s[i][j][p],t[p]);
for(int p=0;p<=r-l;p++)
{
if(x+y==2) add(s[i][j][p],t[p+1]*inv2);
if(x+y==1) add(s[i][j][p],t[p+1]);
if(x+y==0) add(s[i][j][p],t[p+1]*2);
}
}
return s;
}
signed main()
{
n=read();fac[0]=1;
for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%MOD;
for(int i=1;i<=n;i++) a[i]=read();
for(int l=1,r=1;l<=n;l=r)
{
for(r=l;a[l]==a[r];r++);
if((r-l)%a[l]) {puts("0");return 0;}
for(int j=1;j<=(r-l)/a[l];j++)
b[++m]=(a[l]>1);
}
if(m==1) {puts("1");return 0;}
node r=solve(1,m);int ans=0;
for(int i=1;i<=m;i++)
{
int t=0;
for(int a=0;a<2;a++) for(int b=0;b<2;b++)
add(t,r[a][b][i]);
if((m-i)&1) add(ans,MOD-t*fac[i]%MOD);
else add(ans,t*fac[i]);
}
printf("%lld\n",ans);
}