来自FallDream的博客,未经允许,请勿转载,谢谢。
小C有一个集合S,里面的元素都是小于M的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为N的数列,数列中的每个数都属于集合S。
小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数x,求所有可以生成出的,且满足数列中所有数的乘积mod M的值等于x的不同的数列的有多少个。小C认为,两个数列{Ai}和{Bi}不同,当且仅当至少存在一个整数i,满足Ai≠Bi。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案mod 1004535809的值就可以了。1<=N<=10^9,3<=M<=8000,M为质数,1<=x<=M-1 输入数据保证集合S中元素不重复
学了FFT/NTT之后,没做过什么题,模板快忘了,本来是想写一题练练手,结果发现找到了这道神题。不会做,看题解,题解说 嗯,原根的各种性质,不会,补一波。嗯,生成函数,不会,看了好久懂了一点,还好这道题的生成函数很简单可以理解,然后抄了抄题解..
可以把乘积转换成对数相加,因为m是质数,所以可以找到m的原根p,用$p^{0}-p^{m-2}$表示所有数,题目转换为满足(对数相加取膜m-1等于(x的对数))的数量。
然后写出S的生成函数g(n)=a0+a1x+a2x^2+...+anx^n 其中ai是1当且仅当p^i可以选(在集合中)。因为没有限制条件,所以直接求生成函数的n次方,输出x的对数那一项系数就好了
注意题目中对数相加有一个取膜,所以每次做完要把>=m的项整到前面去。
复杂度mlognlogm
ps:1004535809的原根是3,找p的原根只要对p-1每个质因数p1-pk,判断x^((p-1)/pi)%p是否等于1就行了
#include<iostream> #include<cstdio> #include<cmath> #include<algorithm> #define mod 1004535809 #define MN 17000 using namespace std; inline int read() { int x = 0 , f = 1; char ch = getchar(); while(ch < '0' || ch > '9'){ if(ch == '-') f = -1; ch = getchar();} while(ch >= '0' && ch <= '9'){x = x * 10 + ch - '0';ch = getchar();} return x * f; } int g,k,m,n,X,pos[MN+5],prime[MN],num=0,N,inv,w[2][MN+5],f[MN+5],t[MN+5],c[MN+5]; int pow(int x,int k,int P) { int sum=1; for(int i=x;k;k>>=1,i=1LL*i*i%P) if(k&1) sum=1LL*sum*i%P; return sum; } bool check(int x) { for(int i=1;i<=num;i++) if(pow(x,(m-1)/prime[i],m)==1)return 0; return 1; } int getrt(int x) { if(x==2)return 1;--x; for(int i=2;x>1;i++) if(x%i==0) { prime[++num]=i; while(x%i==0)x/=i; } for(int i=2;;++i) if(check(i))return i; } void NTT(int *x,int b) { for(int i=0,j=0;i<N;i++) { if(i>j)swap(x[i],x[j]); for(int l=N>>1;(j^=l)<l;l>>=1); } for(int i=2;i<=N;i<<=1)for(int j=0;j<N;j+=i)for(int k=0;k<i>>1;k++) { int t=1LL*x[j+k+(i>>1)]*w[b][1LL*N/i*k]%mod; x[j+k+(i>>1)]=(1LL*x[j+k]-t+mod)%mod; x[j+k]=(1LL*x[j+k]+t)%mod; } if(b)for(int i=0;i<N;i++)x[i]=1LL*x[i]*inv%mod; } void mul(int*a,int*b) { for(int i=0;i<N;i++)c[i]=b[i]; NTT(a,0);NTT(c,0); for(int i=0;i<N;i++)a[i]=1LL*a[i]*c[i]%mod; NTT(a,1); for(int i=N-1;i>=m-1;i--) a[i-m+1]=(a[i-m+1]+a[i])%mod,a[i]=0; } int main() { k=read();m=read();X=read();n=read(); g=getrt(m); for(int i=1,j=g;i<m-1;i++,j=1LL*g*j%m)pos[j]=i; for(N=1;N<m;N<<=1);N<<=1;inv=pow(N,mod-2,mod); g=pow(3,(mod-1)/N,mod); for(int i=0,j=1;i<=N;i++,j=1LL*j*g%mod) w[0][i]=w[1][N-i]=j; for(int i=1;i<=n;i++)t[pos[read()]]=1;t[pos[0]]=0; for(f[0]=1;k;k>>=1,mul(t,t)) if(k&1)mul(f,t); printf("%d ",f[pos[X]]); return 0; }