题目描述
给你一个数列:
[f_n=egin{cases}
a^n&1leq nleq k\
sum_{i=1}^k(a-1)f_{n-i}&n>k
end{cases}
]
记(g_i)为当(k=i)时(f_n)的值,求
[sum_{i=1}^mg_i imes {19260817}^i
]
对于(60\%)的数据:(mleq 200,nleq {10}^9)
对于另外(40\%)的数据:(mleq {10}^9,nleq 3 imes {10}^6)
题解
第一部分
直接按常系数线性递推的通用做法来做。
可以不用FFT。
时间复杂度:(O(m^3log n))或(O(m^2log mlog n))
第二部分
因为当(igeq n)时(g_i=a^n),所以我们只需要求(g_1ldots g_{n-1})
[egin{align}
f_n&=af_{n-1}-(a-1)f_{n-m-1}\
F(x)&=axF(x)-(a-1)x^{k+1}F(x)+ax-ax^{k+1}\
(1-ax+(a-1)x^{k+1})F(x)&=ax-ax^{k+1}\
F(x)&=frac{ax-ax^{k+1}}{1-ax+(a-1)x^{k+1}}\
&=(ax-ax^{k+1})sum_{i=0}^inftysum_{j=0}^iinom{i}{j}{(1-a)}^jx^{j(k+1)}a^{i-j}x^{i-j}\
end{align}
]
记
[G(x)=sum_{i=0}^inftysum_{j=0}^iinom{i}{j}{(1-a)}^jx^{j(k+1)}a^{i-j}x^{i-j}\
]
那么
[egin{align}
F(x)&=(ax-ax^{k+1})G(x)\
[x^n]F(x)&=a[x^{n-1}]G(x)-a[x^{n-k-1}]G(x)\
[x^n]G(x)&=[x^n]sum_{i=0}^inftysum_{j=0}^iinom{i}{j}{(1-a)}^ja^{i-j}x^{jk+i}\
&=sum_{j}sum_{i=n-jk}inom{i}{j}{(1-a)}^ja^{i-j}\
&=sum_{j}inom{n-jk}{j}{(1-a)}^ja^{n-j(k+1)}\
end{align}
]
观察到对于所有的(k),(j)的取值总共有(O(nlog n))种,所以可以暴力枚举(k,j)。
时间复杂度:(O(nlog n+log m))
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<utility>
#include<iostream>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
int rd()
{
int s=0,c;
while((c=getchar())<'0'||c>'9');
s=c-'0';
while((c=getchar())>='0'&&c<='9')
s=s*10+c-'0';
return s;
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
char str[100];
sprintf(str,"%s.in",s);
freopen(str,"r",stdin);
sprintf(str,"%s.out",s);
freopen(str,"w",stdout);
#endif
}
const ll p=998244353;
const ll vv=19260817;
ll fp(ll a,ll b)
{
ll s=1;
for(;b;b>>=1,a=a*a%p)
if(b&1)
s=s*a%p;
return s;
}
int n,k,m;
ll pw[100010];
ll c[510];
ll d[510];
ll e[510];
int len;
void mul()
{
static ll f[510];
for(int i=0;i<=2*len;i++)
f[i]=0;
for(int i=0;i<len;i++)
for(int j=0;j<len;j++)
f[i+j]=(f[i+j]+d[i]*d[j])%p;
for(int i=0;i<2*len;i++)
d[i]=f[i];
}
void mod()
{
for(int i=2*len;i>=len;i--)
if(d[i])
{
ll v=d[i];
for(int j=0;j<=len;j++)
d[i-len+j]=(d[i-len+j]-v*c[j])%p;
}
}
void pow(int n)
{
if(!n)
return;
pow(n>>1);
mul();
if(n&1)
{
for(int i=2*len;i>=1;i--)
d[i]=d[i-1];
d[0]=0;
}
mod();
}
ll calc1(int x)
{
len=x;
memset(d,0,sizeof d);
d[0]=1;
c[x]=1;
for(int i=0;i<x;i++)
c[i]=-k+1;
pow(n-1);
ll ans=0;
for(int i=1;i<=x;i++)
ans=(ans+pw[i]*d[i-1])%p;
return ans;
}
void solve1()
{
ll ans=0;
pw[0]=1;
for(int i=1;i<=m;i++)
pw[i]=pw[i-1]*k%p;
for(int i=m;i>=1;i--)
ans=(ans+calc1(i))*vv%p;
ans=(ans+p)%p;
printf("%lld
",ans);
}
int fac[3000010];
int inv[3000010];
int ifac[3000010];
int s1[3000010];
int s2[3000010];
int getc(int x,int y)
{
return (ll)fac[x]*ifac[y]%p*ifac[x-y]%p;
}
int gao(int n,int m)
{
int s=0;
for(int i=0;i*(m+1)<=n;i++)
s=(s+(ll)fac[n-i*m]*s1[i]%p*s2[n-i*(m+1)])%p;
return s;
}
void solve2()
{
inv[1]=fac[0]=fac[1]=ifac[0]=ifac[1]=1;
for(int i=2;i<=3000000;i++)
fac[i]=(ll)fac[i-1]*i%p;
s1[0]=1;
for(int i=1;i<=3000000;i++)
s1[i]=(ll)s1[i-1]*(1-k)%p;
s2[0]=1;
for(int i=1;i<=3000000;i++)
s2[i]=(ll)s2[i-1]*k%p;
for(int i=2;i<=3000000;i++)
{
inv[i]=(ll)-p/i*inv[p%i]%p;
ifac[i]=(ll)ifac[i-1]*inv[i]%p;
s1[i]=(ll)s1[i]*ifac[i]%p;
s2[i]=(ll)s2[i]*ifac[i]%p;
}
ll ans=0;
if(n<=m)
{
for(int i=n-1;i>=1;i--)
ans=(ans+gao(n-1,i)-gao(n-i-1,i))*vv%p;
ans=ans*k%p;
ll v=fp(k,n)*(fp(vv,m+1)-fp(vv,n))%p*fp(vv-1,p-2)%p;
ans=(ans+v)%p;
ans=(ans+p)%p;
}
else
{
for(int i=m;i>=1;i--)
ans=(ans+gao(n-1,i)-gao(n-i-1,i))*vv%p;
ans=ans*k%p;
ans=(ans+p)%p;
}
printf("%lld
",ans);
}
int main()
{
open("a");
scanf("%d%d%d",&m,&k,&n);
if(m<=200)
solve1();
else
solve2();
return 0;
}