题目
有(S)个不同的数构成的序列(每个数可以使用任意次数),求有多少个不同的长度为(n)的序列,满足它们的乘积在模(m)意义下为(x),答案对1004535809取模
思路
设(f(i,j))表示填了前(i)个数字,模(m)为(j)的方案数,推状态转移方程的时候可以发现第一维完全没有必要一步一步走,即
[f(i+j,C) = prod_{A imes B=C}f(i,A) imes f(j,B)
]
可以运用类似快速幂的方法,将第一维压成(logn),时间复杂度为(O(m^2logn))
第二维是(A+B=C)就可以卷积了,考虑用原根将乘法变成加法,即(A imes B = C,(mod) (M))变成(g^{log_gA+ log_gB}equiv g^{log_gC},(mod) (M)),进而有(log_gA+log_gBequiv log_gC,(mod) (varphi(M)))
那么对所有数取个(log_g),就可以在模(varphi(M))意义下进行加法卷积了
Code
#include<bits/stdc++.h>
#define N 400005
using namespace std;
typedef long long ll;
const ll mod = 1004535809;
int n,m,x,s,mp[8005];
int r[N],len,L;
ll a[N],b[N];
ll p[N],q[N];
ll G=3,Gi,inv,g;
template <class T> inline T Max(T a,T b) { return a > b ? a : b; }
template <class T> inline T Min(T a,T b) { return a < b ? a : b; }
template <class T> void read(T &x)
{
char c;int sign=1;
while((c=getchar())>'9'||c<'0') if(c=='-') sign=-1; x=c-48;
while((c=getchar())>='0'&&c<='9') x=(x<<1)+(x<<3)+c-48; x*=sign;
}
ll qp(ll a,ll b,ll mo) { ll ret=1; for(;b;b>>=1,a=a*a%mo) if(b&1) ret=ret*a%mo; return ret; }
int get_g(int m)
{
for(int i=2;i<=m-1;++i)
{
int x=m-1,ok=1;
for(int j=2;j*j<=x;++j)
{
if(x%j==0)
{
if(qp(i,(m-1)/j,m)==1) ok=0;
while(x%j==0) x/=j;
}
}
if(x!=1&&x!=m-1) if(qp(i,(m-1)/x,m)==1) ok=0;
if(ok) return i;
}
return -1;
}
void NTT(ll *a,int opt)
{
for(int i=0;i<len;++i) if(i<r[i]) swap(a[i],a[r[i]]);
for(int i=1;i<len;i<<=1)
{
ll wn=qp(opt==1 ? G : Gi , (mod-1)/(i<<1),mod);
for(int j=0;j<len;j+=(i<<1))
{
ll w=1;
for(int k=0;k<i;++k,w=w*wn%mod)
{
ll x=a[j+k],y=a[j+k+i]*w%mod;
a[j+k]=(x+y)%mod;
a[j+k+i]=(x-y+mod)%mod;
}
}
}
if(opt==-1)
{
inv=qp(len,mod-2,mod);
for(int i=0;i<len;++i) a[i]=a[i]*inv%mod;
}
}
void mul(ll *a,ll *b,ll *c)
{
for(int i=0;i<m-1;++i) p[i]=a[i],q[i]=b[i];
for(int i=m-1;i<len;++i) p[i]=q[i]=0;
NTT(p,1);NTT(q,1);
for(int i=0;i<len;++i) p[i]=p[i]*q[i]%mod;
NTT(p,-1);
for(int i=0;i<m-1;++i) c[i]=(p[i] + p[i+m-1])%mod;
}
int main()
{
read(n);read(m);read(x);read(s);
g=get_g(m); Gi=qp(G,mod-2,mod);
for(int i=0;i<m-1;++i) mp[qp(g,i,m)]=i;//模意义取log
for(int i=1;i<=s;++i)
{
int x; read(x);
if(x) ++a[mp[x]];
}
b[mp[1]]=1;
len=1; L=0;
while(len < (m<<1)) len<<=1,++L;
for(int i=0;i<len;++i) r[i]=((r[i>>1]>>1)|((i&1)<<(L-1)));
while(n)
{
if(n&1) mul(a,b,b);
mul(a,a,a);
n>>=1;
}
cout<<(b[mp[x]]%mod+mod)%mod<<endl;
return 0;
}