题面
已知$f_i=(sum_{j=1}ka_j{v_j}i )mod 1004535809$
给定$v_1,v_2,ldots,v_k,f_1,f_2,ldots f_k$
求$f_n$
思路
我们考虑构造一个递推式,使得:
$f_n=sum_{i=1}^k c_i f_{n-i}$
我们把这个$f_n$挪到右边来,令$c_0=1$,得到:
$sum_{i=0}^k c_i f_{n-i} =0$
即:
$sum_{i=0}^k c_i sum_{j=1}^k a_j v_j^{n-i}=0$
这个式子的一个充分条件(可行条件)
$forall j in [1,k] sum_{i=0}^k c_i a_j v_j^{n-i}=0$
把$a_j$挪到前面去,除掉一部分$v_j$的幂,得到这个式子:
$forall j in [1,k] sum_{i=0}^k c_i v_j^{k-i}=0$
令$F(x)=sum c_{k-i} x^i$,那么我们发现${v}$数组是$F(x)$的所有0点
又因为$c_0=-1$,所以$F(x)=-prod_{i=1}^k (x-v_i)$
分治FFT求出$F(x)$,然后用$O((n-k)k)$递推(不会TLE)得到$f_n$即可
Code
代码里有一个技巧
因为一段区间得到的n+1个系数的多项式的最高次项一定是1,所以我们可以不保存他
这样分治FFT用长度为n的数组就能保存了
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define MOD 1004535809
#define ll long long
using namespace std;
inline int read(){
int re=0,flag=1;char ch=getchar();
while(!isdigit(ch)){
if(ch=='-') flag=-1;
ch=getchar();
}
while(isdigit(ch)) re=(re<<1)+(re<<3)+ch-'0',ch=getchar();
return re*flag;
}
ll qpow(ll a,ll b){
ll re=1;
while(b){
if(b&1) re=re*a%MOD;
a=a*a%MOD;b>>=1;
}
return re;
}
ll add(ll a,ll b){
a+=b;
return ((a>=MOD)?a-MOD:a);
}
ll dec(ll a,ll b){
a-=b;
return ((a<0)?a+MOD:a);
}
ll g=3,ginv;
namespace NTT{
int lim,cnt,r[400010];
ll A[400010],B[400010];
void ntt(ll *a,ll type){
int i,j,k,mid;ll x,y,w,wn,inv;
for(i=0;i<lim;i++) if(i<r[i]) swap(a[i],a[r[i]]);
for(mid=1;mid<lim;mid<<=1){
wn=qpow(((~type)?g:ginv),(MOD-1)/(mid<<1));
for(j=0;j<lim;j+=(mid<<1)){
w=1;
for(k=0;k<mid;k++,w=w*wn%MOD){
x=a[j+k];y=a[j+k+mid]*w%MOD;
a[j+k]=add(x,y);
a[j+k+mid]=dec(x,y);
}
}
}
if(~type) return;
inv=qpow(lim,MOD-2);
for(i=0;i<lim;i++) a[i]=a[i]*inv%MOD;
}
void init(int n){
int i;
lim=1;cnt=0;
while(lim<=n) lim<<=1,cnt++;
for(i=0;i<lim;i++) r[i]=((r[i>>1]>>1)|((i&1)<<(cnt-1))),A[i]=B[i]=0;
}
}
void mul(){
using namespace NTT;
ntt(A,1);ntt(B,1);int i;
for(i=0;i<lim;i++) A[i]=A[i]*B[i]%MOD;
ntt(A,-1);
}
ll c[100010];//黑科技数组
int n,k;ll v[100010],f[100010];
void solve(int l,int r){
if(l==r){
c[l]=MOD-v[l];
return;
}
int mid=(l+r)>>1,i;
solve(l,mid);solve(mid+1,r);
using namespace NTT;
init(r-l+1);
for(i=0;i<=mid-l;i++) A[i]=c[i+l];
for(i=0;i<r-mid;i++) B[i]=c[i+mid+1];
A[mid-l+1]=B[r-mid]=1;//把没记录的1加上
mul();
for(i=0;i<=r-l;i++) c[l+i]=A[i];//这里不保存1
}
int main(){
n=read();k=read();int i,j;
g=3;ginv=qpow(3,MOD-2);
for(i=1;i<=k;i++) v[i]=read();
for(i=1;i<=k;i++) f[i]=read();
solve(1,k);
for(i=0;i<k;i++) c[i]=c[i+1];
c[k]=1;
for(i=0;i<=k;i++) if(c[i]) c[i]=MOD-c[i];
for(i=0;i<=k/2;i++) swap(c[i],c[k-i]);
for(i=k+1;i<=n;i++){
ll w=0;
for(j=1;j<=k;j++) w+=c[j]*f[i-j]%MOD;
f[i]=w%MOD;
}
printf("%lld
",f[n]);
}