一句话题意:给定一个集合$S$,求用其中的元素生成一个长度为$N$的序列,求满足$prod_{i=1}^{n}a_i=X a_iin S$的序列的方案数,其中$X$给定,且保证集合$S$中所有数都小于$m$且$m$为质数。
题解:
1.首先考虑DP
设dp[i][j]表示前i个数已经选完他们的乘积为j的方案数,转移方程很好想,显然是$dp[i][j]=sum _{l=0}^{m} sum_{r*l=j(mod m),rin S}dp[i-1][l]$
复杂度:$O(nm^2)$显然空间时间都吃不消。
2.然后考虑优化DP
首先看一眼DP转移式第一眼发现dp[i][j]的转移只与i-1有关,因此滚动,这样空间问题就会解决(时间???)
然后再看一眼DP转移方程,发现转移方程很灵性,它存在$dp[i*j (mod m)]=sum dp[i]*1 jin S$的形式,接下来是这道题最神的地方,因为m为质数,可以证明m必定存在一个原根,而质数原根的0到m-2次方在模m意义下分别对应着1到m-1这些数,也就是说我们可以把1到m-1这些数分别用原根的0到m-2次幂表示出来(有什么用呢???)
这样表示完之后,根据$g^{n}*g^{m}=g^{n+m}$的公式我们就可以把转移式中的乘法转化为加法啦。
注意这里有一个问题,就是当i+j大于m-2时怎么办,解决办法其实很简单,直接把i+j模上m-1就可以啦。证明 根据扩展欧拉定理$g^a=g^{a(mod phi(m))}(mod m) phi(m)=m-1$
这样转移方程变为$dp[i+j(mod m-1)]=sum _{G^jin S}dp[i]$,设计一个数组g,当且仅当$G^iin S$时g[i]=1。这样方程又变成了$dp[i+j (mod m-1)]=sum dp[i]*g[j]$显然是卷积的形式,可以将转移复杂度优化到$mlogm$但总时间复杂度为$nmlongm$还是过不了(难受啊)
最后我们可以发现,对于n次转移每次乘的g数组都是相同的因此我们可以直接求出g数组n次卷积后的结果再乘上dp数组,对于出g数组n次卷积的求法我们可以直接使用快速幂进行处理这样复杂度就变成了$mlogmlogn$可过
具体细节还是看代码吧(觉得这道题特别神奇,fft优化dp,处理i*j形式的卷积)(第一次没照题解打的NTT题.......)
1 #include<bits/stdc++.h> 2 #define ll long long 3 using namespace std; 4 ll mod=1004535809,n=1,lim,m,mm,x,slen,s[10000],G=3,inv;int rev[1000000],mp[1000000],j,l; 5 ll a[1000000],b[1000000],dp[1000000],g[1000000],ttt[1000000],invv[1000000]; 6 inline ll read(){ 7 ll x=0,f=1;char s=0; 8 while(s<'0'||s>'9'){if(s=='-')f=-1;s=getchar();} 9 while(s>='0'&&s<='9'){x=x*10+s-'0';s=getchar();} 10 return x*f; 11 } 12 inline ll ksm(ll a,ll b,ll p){a%=p;ll tmp=1;for(;b;b>>=1,a=a*a%p)if(b&1)tmp=tmp*a%p;return tmp;} 13 inline void init(){ 14 for(register int i=0;i<n;++i){ 15 int t=0;for(j=0;j<lim;++j) 16 if((i>>j)&1)t|=(1<<(lim-1-j)); 17 rev[i]=t; 18 } 19 for(l=2;l<=n;l*=2){ 20 ttt[l]=ksm(G,(mod-1)/l,mod); 21 invv[l]=ksm(ttt[l],mod-2,mod); 22 } 23 } 24 inline bool check(int x){ 25 for(register int i=2;i*i<=m-1;++i) 26 if(ksm(x,(m-1)/i,m)==1)return false; 27 return true; 28 } 29 inline ll findroot(){ 30 if(m==2)return 1; 31 for(register int i=2;i<=m;++i) 32 if(check(i))return i; 33 } 34 inline void NTT(ll *a,int flag){ 35 for(register int i=0;i<n;i++)if(rev[i]>i)swap(a[i],a[rev[i]]); 36 for(l=2;l<=n;l*=2){ 37 int m=l/2;ll tmp=ttt[l],omg; 38 if(flag==-1)tmp=invv[l]; 39 for(ll *p=a;p!=a+n;p+=l){ 40 omg=1;for(j=0;j<m;++j,omg=omg*tmp%mod){ 41 ll t=p[j+m]*omg%mod; 42 p[j+m]=p[j]-t;p[j+m]=(p[j+m]+mod)%mod; 43 p[j]=p[j]+t;p[j]%=mod; 44 } 45 } 46 } 47 } 48 inline void ksm2(ll tt){ 49 for(;tt;tt>>=1){ 50 if(tt&1){ 51 for(register int i=0;i<n;++i)b[i]=dp[i],a[i]=g[i]; 52 NTT(a,1),NTT(b,1);for(register int i=0;i<n;++i)a[i]=a[i]*b[i]%mod; 53 NTT(a,-1);for(register int i=0;i<m-1;++i)dp[i]=(a[i]*inv%mod+a[i+m-1]*inv%mod)%mod; 54 } 55 for(register int i=0;i<n;i++)b[i]=g[i],a[i]=g[i]; 56 NTT(a,1),NTT(b,1);for(register int i=0;i<n;++i)a[i]=a[i]*b[i]%mod; 57 NTT(a,-1);for(register int i=0;i<m-1;++i)g[i]=(a[i]*inv%mod+a[i+m-1]*inv%mod)%mod; 58 } 59 } 60 int main(){ll tt; 61 mm=read(),m=read(),x=read(),slen=read(); 62 ll tmp=findroot();for(register int i=0;i<m-1;++i)mp[ksm(tmp,i,m)]=i; 63 while(n<(m-1)*2)n<<=1,lim++;for(register int i=1;i<=slen;++i){ 64 tt=read();if(tt==0)continue; 65 g[mp[tt]]++; 66 }dp[0]=1; 67 init();inv=ksm(n,mod-2,mod);ksm2(mm);printf("%lld ",dp[mp[x]]); 68 }