传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=3992
【题解】
很容易写出一个朴素 DP（逐个确定序列中的元素，并记录当前乘积模 $m$ 的值），但它的复杂度与 $n$ 线性相关，$n$ 很大时无法通过。
我们想到用原根把乘法改成加法。
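记 $g$ 为模 $m$ 的原根，令 $a_i \equiv g^{b_i} \pmod{m}$，则 $a_1 a_2 \cdots a_n \equiv g^{b_1 + b_2 + \cdots + b_n} \pmod{m}$。
找到满足 $g^k \equiv x \pmod{m}$ 的 $k$，那么条件等价于 $b_1 + b_2 + \cdots + b_n \equiv k \pmod{m-1}$（$m-1$ 就是 $\varphi(m)$，因为 $m$ 是质数）。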
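考虑生成函数 $F(z) = \sum_{s \in S,\ s \neq 0} z^{\log_g s}$，答案即为 $F(z)^n \bmod (z^{m-1} - 1)$ 中 $z^k$ 项的系数。
注意这里的取模是循环意义下的：次数 $\ge m-1$ 的项要把指数减去 $m-1$，累加到对应的低次项上（即把乘积的后半部分移到前面）。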
这样就可以用 NTT（模数 1004535809，原根 3）配合快速幂计算了，复杂度 $O(m \log m \log n)$。
/*
 * BZOJ 3992 ([SDOI2015] sequence counting).
 *
 * Input:  n, m, X, |S|, followed by the |S| elements of S.
 * Output: the number of length-n sequences with every element taken from S
 *         whose product is congruent to X modulo m, reduced modulo 1004535809.
 *
 * Method: with a primitive root gg of m, every nonzero residue v equals
 * gg^{fp[v]}, so the product constraint becomes
 *     fp[a_1] + ... + fp[a_n] == fp[X]  (mod m-1).
 * The answer is the coefficient of z^{fp[X]} in F(z)^n mod (z^{m-1} - 1),
 * computed with NTT-based cyclic convolution and binary exponentiation.
 * Relies on m being prime (phi(m) = m - 1).
 */
#include <stdio.h>
#include <string.h>
#include <algorithm>
// #include <bits/stdc++.h>
using namespace std;

const int M = 8000 + 10, N = 2e5 + 10;
const int mod = 1004535809;  // NTT-friendly prime: mod - 1 = 479 * 2^21
const int G = 3;             // primitive root of mod

int n, m, X, S;
int fp[M];  // fp[v] = discrete logarithm of v (base gg) modulo m, for 1 <= v < m

// Coefficient array of a polynomial; index = exponent.
struct pa {
    int a[N];
} A, ans;

// Return a^b mod P by binary exponentiation.
inline int pwr(int a, int b, int P) {
    int ret = 1;
    while (b) {
        if (b & 1) ret = 1ll * ret * a % P;
        a = 1ll * a * a % P;
        b >>= 1;
    }
    return ret;
}

int t[2333], tn = 0;  // divisors of m-1 strictly between 1 and m-1, filled by getprt

// Return the smallest i >= 2 with i^d != 1 (mod m) for every divisor d of m-1
// with 1 < d < m-1; for prime m this is a primitive root. When m-1 has no such
// divisor (m = 2 or 3) the first candidate, 2, is returned.
inline int getprt(int m) {
    tn = 0;
    for (int i = 2; i < m - 1; ++i)
        if ((m - 1) % i == 0) t[++tn] = i;
    for (int i = 2; ; ++i) {
        bool ok = true;
        for (int j = 1; j <= tn; ++j)
            if (pwr(i, t[j], m) == 1) {
                ok = false;
                break;
            }
        if (ok) return i;
    }
    return -1;  // unreachable
}

namespace NTT {
const int M = 2e5 + 10;
int n, w[2][M], lst[M], invn;
int fa[M], fb[M];  // scratch buffers so callers' arrays are never copied wholesale

// Prepare a transform of length n = smallest power of two >= _n: root tables
// (w[0] forward, w[1] inverse), the bit-reversal permutation, and 1/n mod `mod`.
inline void init(int _n) {
    n = 1;
    while (n < _n) n <<= 1;
    w[0][0] = 1, w[1][0] = 1;
    int g = pwr(G, (mod - 1) / n, mod), invg = pwr(g, mod - 2, mod);
    for (int i = 1; i < n; ++i) {
        w[0][i] = 1ll * w[0][i - 1] * g % mod;
        w[1][i] = 1ll * w[1][i - 1] * invg % mod;
    }
    int len = 0;
    while ((1 << len) < n) ++len;
    for (int i = 0; i < n; ++i) {
        int r = 0;
        for (int j = 0; j < len; ++j)
            if (i & (1 << j)) r |= (1 << (len - j - 1));
        lst[i] = r;
    }
    invn = pwr(n, mod - 2, mod);
}

// In-place iterative NTT of a[0..n). op = 0: forward; op = 1: inverse (scaled by 1/n).
inline void DFT(int *a, int op) {
    int *o = w[op];
    for (int i = 0; i < n; ++i)
        if (i < lst[i]) swap(a[i], a[lst[i]]);
    for (int len = 2; len <= n; len <<= 1) {
        int half = (len >> 1);
        for (int *p = a; p != a + n; p += len) {
            for (int k = 0; k < half; ++k) {
                // o[n/len*k] is the k-th power of a primitive len-th root of unity.
                int tw = 1ll * o[n / len * k] * p[k + half] % mod;
                p[k + half] = p[k] - tw;
                if (p[k + half] < 0) p[k + half] += mod;
                p[k] = p[k] + tw;
                if (p[k] >= mod) p[k] -= mod;
            }
        }
    }
    if (op) {
        for (int i = 0; i < n; ++i) a[i] = 1ll * a[i] * invn % mod;
    }
}

// x = A * B mod (z^Mod - 1): cyclic convolution of length Mod.
// Inputs are taken by const reference (the old by-value signature copied two
// N-int structs per call); only the first n coefficients are copied into
// scratch buffers. Because of that copy, x may alias A.a or B.a.
inline void mul(int *x, const pa &A, const pa &B, int Mod) {
    memcpy(fa, A.a, sizeof(int) * n);
    DFT(fa, 0);
    if (&A == &B) {
        // Squaring: reuse the single forward transform.
        for (int i = 0; i < n; ++i) fa[i] = 1ll * fa[i] * fa[i] % mod;
    } else {
        memcpy(fb, B.a, sizeof(int) * n);
        DFT(fb, 0);
        for (int i = 0; i < n; ++i) fa[i] = 1ll * fa[i] * fb[i] % mod;
    }
    DFT(fa, 1);
    for (int i = 0; i < n; ++i) x[i] = 0;
    // Fold exponents >= Mod back onto [0, Mod) (reduction mod z^Mod - 1).
    for (int i = 0; i < n; ++i) {
        int np = i % Mod;
        x[np] += fa[i];
        if (x[np] >= mod) x[np] -= mod;
    }
}
}  // namespace NTT

int main() {
    scanf("%d%d%d%d", &n, &m, &X, &S);

    // Build the discrete-log table: fp[gg^i mod m] = i for i in [0, m-2].
    int gg = getprt(m), sum = 1;
    for (int i = 0; i < m - 1; ++i) {
        fp[sum] = i;
        sum = 1ll * sum * gg % m;
    }

    // F(z) = sum of z^{fp[s]} over nonzero s in S. A zero factor makes the
    // product 0, so it is skipped. NOTE(review): this assumes X != 0 (mod m);
    // X == 0 is not handled by this method.
    for (int i = 1, pt; i <= S; ++i) {
        scanf("%d", &pt);
        if (pt) A.a[fp[pt]] = ans.a[fp[pt]] = 1;
    }

    // The product of two degree < m-1 polynomials has degree < 2m-3, so 2m suffices.
    NTT::init(m + m);

    // ans already holds F^1; multiply in the remaining n-1 factors by binary powering.
    --n;
    while (n) {
        if (n & 1) NTT::mul(ans.a, ans, A, m - 1);
        NTT::mul(A.a, A, A, m - 1);
        n >>= 1;
    }

    printf("%d ", ans.a[fp[X]]);
    return 0;
}