luogu5075 [JSOI2012]分零食
题目链接:luogu
首先有一个很显然的dp,记(f_{i,j})为分给前(i)个人(j)块糖的答案,那么有(f_{i,j}=sum_{k=0}^jf_{i-1,k} imes g_{j-k}).其中(g_i)表示分给一个人(i)块糖时的快乐度(特别的,(g_0=0))
这个dp直接做很没有前途,但是发现它的转移是类似卷积的形式,于是考虑往生成函数的方向思考。
记(g_i)的生成函数为(G(x)),(f_{i,j})的生成函数为(F_i(x)),那么每次转移其实就是(F_i(x)=F_{i-1}(x)G(x)),初始条件为(F_0(x)=1), 最后的答案为(sum_{i=1}^m[x^m]F_i(x)).那么有(F_i(x)=G^i(x)).
事实上答案可以看成是等比数列求和的形式,于是可以用等比数列求和转化后使用多项式求逆+快速幂,可以做到一个(log),但是模数是一个比较麻烦的问题。
考虑一个不依赖模数的做法:记(F_n(x)=sum_{i=1}^n G^i(x)). 使用类似快速幂的方法求(F_n(x)).
- (n)为偶数时
已知(F_{frac{n}{2}}(x),G^{frac{n}{2}}(x)).
[egin{aligned}
F_n(x)=&sum_{i=1}^{frac{n}{2}}G^i(x)+sum_{i=frac{n}{2}+1}^nG^i(x)\
=&F_{frac{n}{2}}(x)+G^{frac{n}{2}}(x)sum_{i=1}^{frac{n}{2}}F^i(x)\
=&F_{frac{n}{2}}(x)+G^{frac{n}{2}}(x)F_{frac{n}{2}}(x)\
G^n(x)=&G^{frac{n}{2}}(x)G^{frac{n}{2}}(x)
end{aligned}
]
- (n)为奇数时
发现我们上面求的其实是(n-1)的情况,把(n)的贡献单独加上去即可。
[G^n(x)=G^{n-1}(x)G^1(x)\
F_n(x)=F_{n-1}(x)+G^n(x)
]
#include<iostream>
#include<string.h>
#include<string>
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<bitset>
#include<math.h>
#include<stack>
#include<queue>
#include<set>
#include<map>
using namespace std;
typedef long long ll;
typedef long double db;
typedef vector<int> vi;
typedef pair<int,int> pii;
const int N=100000+100;
const db pi=acos(-1.0);
#define lowbit(x) (x)&(-x)
#define sqr(x) (x)*(x)
#define rep(i,a,b) for (register int i=a;i<=b;i++)
#define per(i,a,b) for (register int i=a;i>=b;i--)
#define go(u,i) for (register int i=head[u];i;i=sq[i].nxt)
#define fir first
#define sec second
#define mp make_pair
#define pb push_back
#define eps 1e-8
int maxd;
inline int read()
{
int x=0,f=1;char ch=getchar();
while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
return x*f;
}
namespace My_Math{
#define N 100000
int fac[N+100],invfac[N+100];
int add(int x,int y) {return x+y>=maxd?x+y-maxd:x+y;}
int dec(int x,int y) {return x<y?x-y+maxd:x-y;}
int mul(int x,int y) {return 1ll*x*y%maxd;}
ll qpow(ll x,int y)
{
ll ans=1;
while (y)
{
if (y&1) ans=mul(ans,x);
x=mul(x,x);y>>=1;
}
return ans;
}
int getinv(int x) {return qpow(x,maxd-2);}
int C(int n,int m)
{
if ((n<m) || (n<0) || (m<0)) return 0;
return mul(mul(fac[n],invfac[m]),invfac[n-m]);
}
void math_init()
{
fac[0]=invfac[0]=1;
rep(i,1,N) fac[i]=mul(fac[i-1],i);
invfac[N]=getinv(fac[N]);
per(i,N-1,1) invfac[i]=mul(invfac[i+1],i+1);
}
#undef N
}
using namespace My_Math;
int n,m,O,S,U,f1[N<<2],f[N<<2],g[N<<2],h[N<<2];
namespace Polynomial{
struct complex{
double x,y;
complex(double _x=0.0,double _y=0.0) {x=_x;y=_y;}
};
complex operator +(complex a,complex b)
{
return complex(a.x+b.x,a.y+b.y);
}
complex operator -(complex a,complex b)
{
return complex(a.x-b.x,a.y-b.y);
}
complex operator *(complex a,complex b)
{
return complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
}
complex A[N<<2],B[N<<2];
int r[N<<2];
void fft(int lim,complex *a,int typ)
{
rep(i,0,lim-1)
if (i<r[i]) swap(a[i],a[r[i]]);
for (int mid=1;mid<lim;mid<<=1)
{
int len=(mid<<1);
complex wn=complex(cos(pi/mid),sin(pi/mid)*typ);
for (int sta=0;sta<lim;sta+=len)
{
complex w=complex(1,0);
for (int j=0;j<mid;j++,w=w*wn)
{
complex x=a[sta+j],y=a[sta+j+mid]*w;
a[sta+j]=x+y;a[sta+j+mid]=x-y;
}
}
}
}
int calcr(int len)
{
int lim=1,cnt=0;
while (lim<len) {lim<<=1;cnt++;}
rep(i,0,lim-1)
r[i]=(r[i>>1]>>1)|((i&1)<<(cnt-1));
return lim;
//rep(i,0,lim-1) cout << r[i] << " ";cout << endl;
}
void mul(int *a,int *b,int *c,int lim)
{
//rep(i,0,lim-1) cout << a[i] << " ";cout << endl;
//rep(i,0,lim-1) cout << b[i] << " ";cout << endl;
rep(i,0,lim-1)
{
A[i]=complex(a[i],0);
B[i]=complex(b[i],0);
}
fft(lim,A,1);fft(lim,B,1);
rep(i,0,lim-1) A[i]=A[i]*B[i];
fft(lim,A,-1);
rep(i,1,m) c[i]=(int)(A[i].x/lim+0.5)%maxd;
//rep(i,1,m) cout << c[i] << " ";cout << endl << endl;
}
void solve(int k,int lim)
{
if (k==1)
{
rep(i,0,lim-1) f[i]=g[i]=f1[i];
return;
}
solve(k>>1,lim);
mul(f,g,h,lim);
rep(i,0,lim-1) g[i]=add(g[i],h[i]);
mul(f,f,f,lim);
if (k&1)
{
mul(f,f1,f,lim);
rep(i,0,lim-1)
g[i]=add(g[i],f[i]);
}
}
}
using namespace Polynomial;
int calc(int x) {return (1ll*O*sqr(x)+S*x+U)%maxd;}
int main()
{
m=read();maxd=read();n=read();O=read();S=read();U=read();
rep(i,1,m) f1[i]=calc(i);
int lim=calcr(m<<1);
solve(min(n,m),lim);
printf("%d",g[m]);
return 0;
}