洛谷P3321「序列统计」
题目描述
小 (C) 有一个集合 (S),里面的元素都是小于 (m) 的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为 (n) 的数列,数列中的每个数都属于集合 (S) 。
小 (C) 用这个生成器生成了许多这样的数列。但是小 (C) 有一个问题需要你的帮助:给定整数 (x),求所有可以生成出的,且满足数列中所有数的乘积 (mod m) 的值等于 (x) 的不同的数列的有多少个。
小C认为,两个数列 (A) 和 (B) 不同,当且仅当 (exists i ; ext{s.t.} A_i eq B_i)。另外,小 (C) 认为这个问题的答案可能很大,因此他只需要你帮助他求出答案对 (1004535809) 取模的值就可以了。
输入格式
一行,四个整数,(n,m,x,left | S ight |),其中 (left | S ight |) 为集合 (S) 中元素个数。
第二行,(left | S ight |) 个整数,表示集合 (S) 中的所有元素。
输出格式
一行一个整数表示答案。
输入输出样例
输入
4 3 1 2
1 2
输出
8
说明/提示
【样例说明】
可以生成的满足要求的不同的数列有
((1,1,1,1),;(1,1,2,2),;(1,2,1,2),;(1,2,2,1),;(2,1,1,2),;(2,1,2,1),;(2,2,1,1),;(2,2,2,2))。
【数据规模和约定】
对于 (10\%) 的数据,(1leq nleq 1000);
对于 (30\%) 的数据,(3leq mleq 100);
对于 (60\%) 的数据,(3leq mleq 800);
对于 (100\%) 的数据,(1leq nleq 10^9,3leq mleq 8000,1leq x <m)。
(m) 为质数,输入数据保证集合 (S) 中元素不重复。
题解
(m) 的值域很小,考虑用Triple一题的思路,将集合中的数放到多项式上,问题就可以转化为一个式子:
这已经很像我们多项式乘法的式子了,但是唯一的不同是,这里的 (i) 和 (j) 是相乘的
考虑如何将乘法转化成加法?
可以利用高中数学里的对数
对数有个很好的性质:
不妨将 (i) 和 (j) 都用一个数来取对数,得到 (log^i) 和 (log^j)
在询问 (x) 地方的值的时候,相当于是询问 (log^x) 地方的值
现在问题转化为
问题又来了,要选取哪个底数来对所有的值取对数呢?
利用原根
如果连原根都不知道是什么的小朋友,可以去百度百科初步了解一下,不会原根,你怎么学的 (NTT)?
我们想把所有值取 (log) 要保证什么?
比如说我们取的底数是 (g)
我们需要 (1sim m-1) 的 (log_g^imod log_g^m) 互不相同
即 (1sim m-1) 的 (g^imod m) 互不相同
这不就是原根的第二个性质嘛
(1sim m-1) 的 (g^i) 正好一一对应了 (1sim m - 1) 的所有值
我们就可以把原题中集合 (S) 的每个值去取对数了
然后我们就可以得到一个 (1sim m-1) 的多项式,由于没有常数项难以处理,直接变成 (0sim m-2) 的多项式来处理
对于取模,就很好处理了
两个长度为 (m-2) 的多项式相乘,对于得到的多项式的 (m-1) 次项以后,同时也对答案造成了贡献,次数模上模数加上贡献即可
我们要选 (n) 个数,而且没有像Triple一样“不能选重复的限制”,不用什么乱七八糟的容斥,所以直接 (f(x)^k) 即可
我们可以不用像多项式快速幂那么麻烦的快速幂,而且 (a_0) 不一定为 (1),也用不了
可以像普通实数快速幂一样,每次只留前 (m-2) 位,只不过效率是 (nlog^{2n})
代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
typedef long long ll;
typedef unsigned long long ull;
using namespace std;
const int maxn = 2e5 + 50, INF = 0x3f3f3f3f, mod = 1004535809, inv3 = 334845270;
inline int read () {
register int x = 0, w = 1;
register char ch = getchar ();
for (; ch < '0' || ch > '9'; ch = getchar ()) if (ch == '-') w = -1;
for (; ch >= '0' && ch <= '9'; ch = getchar ()) x = x * 10 + ch - '0';
return x * w;
}
inline void write (register int x) {
if (x / 10) write (x / 10);
putchar (x % 10 + '0');
}
int n, m, g, X, s, len = 1, bit;
bool vis[maxn];
int res[maxn], tmp[maxn], ans[maxn];
int f[maxn], id[maxn], rev[maxn];
inline int gqpow (register int a, register int b, register int ans = 1) {
for (; b; b >>= 1, a = 1ll * a * a % m)
if (b & 1) ans = 1ll * ans * a % m;
return ans;
}
inline int Get_g (register int m) {
for (register int i = 0; i < m; i ++) {
memset (vis, 0, sizeof 4 * m);
for (register int k = 1, tmp; k <= m - 1; k ++) {
tmp = gqpow (i, k);
if (vis[tmp]) goto end;
else vis[tmp] = 1;
}
return i;
end:;
}
return -1;
}
inline int qpow (register int a, register int b, register int ans = 1) {
for (; b; b >>= 1, a = 1ll * a * a % mod)
if (b & 1) ans = 1ll * ans * a % mod;
return ans;
}
inline void NTT (register int len, register int * a, register int opt) {
for (register int i = 1; i < len; i ++) if (i < rev[i]) swap (a[i], a[rev[i]]);
for (register int d = 1; d < len; d <<= 1) {
register int w1 = qpow (opt, (mod - 1) / (d << 1));
for (register int i = 0; i < len; i += d << 1) {
register int w = 1;
for (register int j = 0; j < d; j ++, w = 1ll * w * w1 % mod) {
register int x = a[i + j], y = 1ll * w * a[i + j + d] % mod;
a[i + j] = (x + y) % mod, a[i + j + d] = (x - y + mod) % mod;
}
}
}
}
inline void Calc (register int * a, register int * b) {
memset (res, 0, 4 * len), memset (tmp, 0, 4 * len);
for (register int i = 0; i < m; i ++) res[i] = a[i], tmp[i] = b[i], a[i] = 0;
NTT (len, res, 3), NTT (len, tmp, 3);
for (register int i = 0; i < len; i ++) res[i] = 1ll * res[i] * tmp[i] % mod;
NTT (len, res, inv3);
register int inv = qpow (len, mod - 2);
for (register int i = 0; i < len; i ++) res[i] = 1ll * res[i] * inv % mod, a[i % (m - 1)] = (a[i % (m - 1)] + res[i]) % mod;
}
int main () {
n = read(), m = read(), X = read(), s = read(), g = Get_g (m);
for (register int i = 0; i <= m - 2; i ++) id[gqpow (g, i)] = i;
while (len < m << 1) len <<= 1, bit ++;
for (register int i = 1; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << bit - 1);
for (register int i = 1; i <= s; i ++) {
register int x = read();
if (x) f[id[x]] ++;
}
ans[0] = 1;
while (n) {
if (n & 1) Calc (ans, f);
n >>= 1, Calc (f, f);
}
printf ("%d
", ans[id[X]]);
return 0;
}