问题描述
求一个满足 $K$ 阶齐次线性递推数列 $a_i$ 的第 $n$ 项,即:$a_n = sum_{i=1}^k f_i imes a_{n-i}$.
分析
首先写成矩阵快速幂
$$left( egin{bmatrix} f_1 &f_2 &f_3 &f_4 & cdots &f_{k-2} &f_{k-1} \ 1 &0 &0 &0 & cdots &0 &0 \ 0 &1 &0 &0 & cdots &0 &0\ cdots & cdots& cdots & cdots & cdots & cdots & cdots\ 0 &0 &0 &0 & cdots &1 &0 end{bmatrix} ight) ^n imes egin{bmatrix} a_{k-1} \ a_{k-2} \ cdots \ a_{1} \ a_{0}end{bmatrix} =egin{bmatrix} a_{n+k-1} \ a_{n+k-2} \ cdots \ a_{n+1} \ a_{n}end{bmatrix}$$
所以我们只需要算出 $M^N imes A$,然后取最后一个数即可。
使用矩阵快速幂,复杂度 $O(k^3 log_2n)$.
Carlay-Hamilton定理
设有 $k$ 个特征值的矩阵 $A$的特征多项式为 $f(lambda ) =prod_{i=1}^k(lambda_i - x)$,则有 $f(A) = 0$,$0$ 为零矩阵。
用这个定理来优化递推
由前面的矩阵快速幂,我们只要求出 $M^n$就可以了。
我们考虑 $M$ 的特征多项式 $f(x)$,这是一个 $k$ 次多项式。我们对 $M^n$ 做带余除法 $M^n = f(M) imes g(M) + R(M)$。
由于 $f(M) = 0$,所以 $M^n equiv R(M) (mod f(M))$,$R(M)$ 是一个次数不超过 $k-1$ 的多项式。
也就是说,我们只要求出 $M^n \% f(M)$就可以了
但是要怎么求呢?我们考虑快速幂的过程(就是倍增)
假设我们现在已知 $g(M)=M^{2^i} \% f(M)$,现在要求 $h(M)= M^{2^{i+1}} \% f(M)$。
一个直接的想法是令 $H(M)=g(M) imes g(M)$。但是这样做 $H(x)$ 的次数是 $2k-2$次的。
那么我们考虑原本的递推关系,$a_n=sumlimits_{i=1}^{k}a_{n-i}*f_i$
不难得到 $M^n=sumlimits_{i=1} ^{k} M^{n-i} imes f_{i}$
所以我们可以用这个式子将多余的系数都向前压一位。
这样我们可以得到一个 $O(k^2 log_2 n)$ 的做法。
那么有没有优化的余地呢?我们从倍增的过程入手,可以发现 $H(M) = g(M) imes g(M)$ 的过程可以用FFT/NTT加速至 $O(k log_2k)$。
现在只要解决压系数就可以了,把 $H(M)$ 模 $f(M)$ 即可。
我们的推导一直用到这个特征多项式 $f(x)$,如何求得呢?
根据定义, $f(x) = det(xI - M)$,得到
$$f(x) = |x I - M| = egin{bmatrix} x- a_1 & -a_2 & -a_3 & cdots & -a_{k - 2} & -a_{k - 1} & -a_k \ -1 & x & 0 & cdots & 0 & 0 & 0 \ 0 & -1 & x & cdots & 0 & 0 & 0 \ 0 & 0 & -1 & cdots & 0 & 0 & 0 \ vdots & vdots & vdots & ddots & vdots & vdots & vdots \ 0 & 0 & 0 & cdots & -1 & x & 0 \ 0 & 0& 0 & cdots & 0 & -1 & xend{bmatrix}$$
对第一行进行展开,得到
$$f(x) = (x - a_1)M_{11} + (-a_2)M_{12} + cdots + (-a_k)M_{1n} = x ^ k - a_1 x ^ {k - 1} - a_2x ^ {k - 2} - cdots - a_k$$
代码1:
$O(k log_2k log_2n)$的做法
思路其实就是去做一个类似快速幂的操作,然后把乘法改成多项式下的,取模也改成多项式下的
// luogu-judger-enable-o2 #include<cstdio> #include<algorithm> using namespace std; typedef long long ll; const ll mod=998244353; const int N=65536+10; int n;int k;int rv[20][N];ll rt[20][20];int Len;ll tr1[N];ll tr2[N];long long st[N];long long xs[N]; ll sg[N];ll a[N];ll res[N];ll irg[N];ll q[N];ll rf[N];int DL=-1;ll ans=0;ll ret[N]; inline ll po(ll a,ll p){ll r=1;for(;p;p>>=1,a=a*a%mod)if(p&1)r=r*a%mod;return r;} inline void ntt(ll* a,int o,int len,int d)//ntt { for(int i=0;i<len;i++)if(i<rv[d][i])swap(a[i],a[rv[d][i]]); for(int k=1,j=1;k<len;k<<=1,j++) for(int s=0;s<len;s+=(k<<1)) for(int i=s,w=1;i<s+k;i++,w=w*rt[o][j]%mod) {ll a0=a[i];ll a1=a[i+k]*w%mod;a[i]=(a0+a1)%mod,a[i+k]=(a0+mod-a1)%mod;} if(o==1){ll inv=po(len,mod-2);for(int i=0;i<len;i++)(a[i]*=inv)%=mod;} } inline void poly_inv(ll* a,ll* b,int len)//求逆 { b[0]=po(a[0],mod-2); for(int k=1,j=0;k<=len;k<<=1,j++) { for(int i=0;i<k;i++)tr1[i]=a[i];for(int i=0;i<k;i++)tr2[i]=b[i]; ntt(tr1,0,k<<1,j);ntt(tr2,0,k<<1,j); for(int i=0;i<(k<<1);i++)b[i]=tr2[i]*(2+mod-tr1[i]*tr2[i]%mod)%mod; ntt(b,1,k<<1,j);for(int i=k;i<(k<<1);i++)b[i]=0; } } inline void poly_mod(ll* a)//取模 { int mi=(k<<1);while(a[--mi]==0);if(mi<k)return; for(int i=0;i<(Len<<1);i++)rf[i]=0;for(int i=0;i<=mi;i++)rf[i]=a[i]; reverse(rf,rf+mi+1);for(int i=mi-k+1;i<=mi;i++)rf[i]=0;ntt(rf,0,Len<<1,DL+1); for(int i=0;i<(Len<<1);i++)q[i]=(rf[i]*irg[i])%mod;ntt(q,1,(Len<<1),DL+1); for(int i=mi-k+1;i<=(Len<<1);i++)q[i]=0;reverse(q,q+mi-k+1);ntt(q,0,(Len<<1),DL+1); for(int i=0;i<(Len<<1);i++)(q[i]*=sg[i])%=mod;ntt(q,1,(Len<<1),DL+1); for(int i=0;i<k;i++)(a[i]+=mod-q[i])%=mod;for(int i=k;i<=mi;i++)a[i]=0; } int main() { for(int i=0;i<=15;i++) for(int j=0;j<(1<<(i+1));j++)rv[i][j]=(rv[i][j>>1]>>1)|((j&1)<<i); for(int t=2,j=1;j<=18;t<<=1,j++)rt[0][j]=po(3,(mod-1)/t); for(int t=2,j=1;j<=18;t<<=1,j++)rt[1][j]=po(332748118,(mod-1)/t); scanf("%d%d",&n,&k); for(Len=1;Len<=k;Len<<=1,DL++); //预处理 for(int i=1;i<=k;i++){scanf("%lld",&xs[i]);xs[i]=xs[i]<0?xs[i]+mod:xs[i];} for(int i=0;i<k;i++){scanf("%lld",&st[i]);st[i]=st[i]<0?st[i]+mod:st[i];} for(int i=1;i<=k;i++)sg[k-i]=mod-xs[i];sg[k]=1;for(int i=0;i<=k;i++)ret[i]=sg[i]; for(int i=0;i<=k;i++)rf[i]=sg[i];reverse(rf,rf+k+1);poly_inv(rf,irg,Len); for(int i=0;i<=k;i++)rf[i]=0;ntt(sg,0,Len<<1,DL+1);ntt(irg,0,Len<<1,DL+1);a[1]=1;res[0]=1; while(n)//快速幂 { if(n&1) { ntt(res,0,Len<<1,DL+1);ntt(a,0,Len<<1,DL+1); for(int i=0;i<(Len<<1);i++)(res[i]*=a[i])%=mod; ntt(res,1,Len<<1,DL+1);ntt(a,1,Len<<1,DL+1);poly_mod(res); }ntt(a,0,Len<<1,DL+1);for(int i=0;i<(Len<<1);i++)(a[i]*=a[i])%=mod; ntt(a,1,Len<<1,DL+1);poly_mod(a);n>>=1; } for(int i=0;i<k;i++)(ans+=res[i]*st[i])%=mod; printf("%lld",ans); return 0; }
代码2:
$O(k^2 log_2n)$的做法
#include<cstdio> #include<cstring> #include<cstdlib> #include<cctype> #include<cmath> #include<iostream> #include<algorithm> #include<vector> #include<set> #include<map> #include<queue> #include<stack> #include<cassert> typedef long long ll; typedef unsigned long long ull; using namespace std; const int P=1000000007; const int MAXN=4010; //2*k+10 int n,k,ans; int f[MAXN],h[MAXN]; struct Matrix{ //其实是多项式 int a[MAXN]; Matrix (){memset(a,0,sizeof a);} int& operator [] (const int &i) {return a[i];} int operator [] (const int &i) const {return a[i];} inline Matrix operator * (const Matrix &rhs) const { Matrix ret; for(int i=0;i<k;i++) for(int j=0;j<k;j++) (ret[i+j]+=1ll*a[i]*rhs[j]%P)%=P; for(int i=2*k-2;i>=k;ret[i--]=0) for(int j=1;j<=k;j++) //这里就是多项式取模优化的地方 (ret[i-j]+=1ll*ret[i]*f[j]%P)%=P; //可以认为是暴力向前压系数 return ret; } }res; Matrix ksm(Matrix a,int b) { Matrix ret; ret[0]=1; for(;b;a=a*a,b>>=1) if(b&1) ret=ret*a; return ret; } int main() { scanf("%d%d",&n,&k); for(int i=1;i<=k;i++) scanf("%d",&f[i]),f[i]=f[i]>0?f[i]:f[i]+P; for(int i=0;i<k;i++) scanf("%d",&h[i]),h[i]=h[i]>0?h[i]:h[i]+P; if(n<k) printf("%d ",h[n]); res[1]=1;ans=0; res=ksm(res,n); for(int i=0;i<k;i++) ans=(ans+1ll*res[i]*h[i]%P)%P; printf("%d ",ans); }
参考链接:
1. https://www.luogu.org/problemnew/solution/P4723
2. https://www.luogu.org/blog/Zhang-RQ/chang-ji-shuo-ji-ci-xian-xing-di-tui-chu-tan