Description
小C有一个集合(S),里面的元素都是小于(M)的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为(N)的数列,
数列中的每个数都属于集合(S)。小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:
给定整数(x),求所有可以生成出的,且满足数列中所有数的乘积(mod M)的值等于(x)的不同的数列的有多少个。
小C认为,两个数列({A_i})和({B_i})不同,当且仅当至少存在一个整数(i),满足(A_i eq B_i)。
另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案(mod 1004535809)的值就可以了。
Input
一行,四个整数,(N、M、x、|S|),其中(|S|)为集合(S)中元素个数。
第二行,(|S|)个整数,表示集合(S)中的所有元素。
(1 leq N leq 10^9,3 leq M leq 8000),M为质数
(0 leq x leq M-1),输入数据保证集合S中元素不重复(x in [1,m-1])
集合中的数$ in [0,m-1]$
Output
一行,一个整数,表示你求出的种类数(mod 1004535809)的值。
Solution
看到这题。。首先很容易列出一个DP转移方程
令(F_{i,j})表示选了(i)个数字,当前乘积为(j)的种类数
我们发现它非常不优美,复杂度高达$ O (n * m^2) $
我们发现这个式子可以倍增。。于是很轻松的干掉一个n,它的复杂度变成了$ O(log n * m^2) $
这貌似还是有点多。。考虑如何干掉一个 $ m $
咦。。这个模数貌似有点熟悉。。考虑NTT
不过这是乘法。。我们做不了NTT 。。。
考虑原根
设(p)为(m)的原根。。那么(p)的幂次可以表示出([1,m))的所有数字————原根定义
于是DP方程变成了这样
注意。。此时(F_{i,j})表示选到第(i)个数,大小为(p^j)次的方案数
再一变
我们发现这玩意长得像个卷积。。可以用NTT了
于是复杂度变成了 $ O(m log n log m)$
Code
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int n,m,x,lens,g[2000000],a[2000010],b[2000010];
int f[2000010],ans[2000010];
int fpow(int x,int k,int Mod)
{
int ans=1;
while (k)
{
if (k&1) ans=1LL*ans*x%Mod;
x=1LL*x*x%Mod;
k>>=1;
}
return ans;
}
namespace GetRoot //求原根
{
int prime[1000000],cnt;
bool check(int x,int p)
{
for (int i=1;i<=cnt;i++)
if (fpow(x,(p-1)/(prime[i]),p)==1) return 0;
return 1;
}
int find(int p)
{
int x=p-1;
for (int i=2;i*i<=x;i++)
{
if (x%i==0)
{
prime[++cnt]=i;
while (x%i==0) x/=i;
}
}
if (x!=1) prime[++cnt]=x;
for (int i=2;;i++)
if (check(i,p)) return i;
}
}
namespace NTT
{
const int Mod=1004535809,p=3;
int n=1;
void NTT(int *a,int inv)
{
int lim=0;
while ((1<<lim)<n) lim++;
for (int i=0;i<n;i++)
{
int t=0;
for (int j=0;j<lim;j++)
if ((i>>j) & 1) t|=1<<(lim-j-1);
if (i<t) swap(a[i],a[t]);
}
for (int l=2;l<=n;l*=2)
{
int m=l/2,p0=fpow(inv?fpow(p,Mod-2,Mod):p,(Mod-1)/l,Mod);
for (int *buf=a;buf!=a+n;buf+=l)
{
int pn=1;
for (int i=0;i<m;i++)
{
int t=1LL*pn*buf[i+m]%Mod;
buf[i+m]=(buf[i]-t+Mod)%Mod;
buf[i]=(buf[i]+t)%Mod;
pn=1LL*pn*p0%Mod;
}
}
}
}
void Union(int *a,int *c,int len)
{
while (n<2*len) n<<=1;
for (int i=0;i<n;i++) b[i]=0;
for (int i=0;i<len;i++) b[i]=c[i];
NTT(a,0);NTT(b,0);
for (int i=0;i<n;i++) a[i]=1LL*a[i]*b[i]%Mod;
NTT(a,1);
int invn=fpow(n,Mod-2,Mod);
for (int i=0;i<n;i++) a[i]=1LL*a[i]*invn%Mod;
for (int i=len-1;i<n;i++) a[i%(len-1)]=(a[i%(len-1)]+a[i])%Mod,a[i]=0;
}
}
void init()
{
int t=GetRoot::find(m);
for (int i=0,k=1;i<m-1;i++,k=1LL*k*t%m) g[k]=i;
x=g[x];
for (int i=1;i<=lens;i++)
if (a[i]) f[g[a[i]]]++; //若a[i]=0.就直接舍弃。
}
void solve() //倍增优化
{
int k=n;
ans[0]=1;
while (k)
{
if (k&1) NTT::Union(ans,f,m);
NTT::Union(f,f,m);
k>>=1;
}
printf("%d
",ans[x]);
}
int main()
{
scanf("%d%d%d%d",&n,&m,&x,&lens);
for (int i=1;i<=lens;i++) scanf("%d",&a[i]);
init();
solve();
return 0;
}