Problem
Description
给出 (n) 个物品,第 (i) 个物品体积为 (a_i) 。
对于每个体积 (V) ,求选出 (3) 个物品,体积之和为 (V) 的方案总数。
选择顺序不同算同一种方案。
Range
(n) 保证不会读入到 (TLE) , (a_ile 4 imes 10^4) 。
Algorithm
多项式,生成函数。
Mentality
设生成函数 (A(x)) 为只选择一个物品的生成函数。其中 ([x^m]A(x)) 的系数代表了体积 (m) 有多少种选法。
同理设 (B(x)) 为选择两个相同物品的生成函数,设 (C(x)) 为选择三个相同物品的生成函数。
则对于最后的答案而言:
若选择的 (3) 个物品互不相同,则方案数为:
[frac{A^3(x)-3B(x)A(x)+2C(x)}{6}
]
因为根据容斥,(A^3(x)) 等于所有选择三个物品的方案数,(B(x)A(x)) 则是所有形如 ((a, a, b)) 的方案数,由于这种方案在 (A^3(x)) 会出现三次,所以要乘 (3) ,然后对于所有 ((a,a,a)) ,也即生成函数 (C(x)) 在 (B(x)A(x)) 中出现了 (3) 次,但实际上在 (A^3(x)) 只会被计算一次,所以还要加回 (2) 个来。
若选择 (2) 个物品,那么方案为:
[frac{A^2(x)-B(x)}{2}
]
这个很好理解。
选择一个物品的方案自然就是 (A(x)) 了。
(FFT) 即可。
Code
#include <cmath>
#include <complex>
#include <cstdio>
#include <iostream>
using namespace std;
#define LL long long
#define cp complex<double>
#define inline __inline__ __attribute__((always_inline))
inline LL read() {
LL x = 0, w = 1;
char ch = getchar();
while (!isdigit(ch)) {
if (ch == '-') w = -1;
ch = getchar();
}
while (isdigit(ch)) {
x = (x << 3) + (x << 1) + ch - '0';
ch = getchar();
}
return x * w;
}
const int Max_n = 4e5 + 5, Ml = 1.2e5;
const double pi = acos(-1);
cp ans[Max_n], A[Max_n], B[Max_n], C[Max_n];
namespace Input {
void main() {
int n = read();
for (int i = 1, x; i <= n; i++)
x = read(), A[x] += 1, B[x * 2] += 1, C[x * 3] += 1;
}
} // namespace Input
namespace Solve {
int bit, len, rev[Max_n];
void init() {
int bit = log2(Ml + 1) + 1;
len = 1 << bit;
for (int i = 0; i < len; i++)
rev[i] = rev[i >> 1] >> 1 | ((i & 1) << (bit - 1));
}
void dft(cp *f, int t) {
for (int i = 0; i < len; i++)
if (i < rev[i]) swap(f[i], f[rev[i]]);
for (int l = 1; l < len; l <<= 1) {
cp Wn(cos(t * pi / (double)l), sin(t * pi / (double)l));
for (int i = 0; i < len; i += (l << 1)) {
cp Wnk(1, 0);
for (int k = i; k < i + l; k++, Wnk *= Wn) {
cp x = f[k], y = f[k + l] * Wnk;
f[k] = x + y, f[k + l] = x - y;
}
}
}
}
void main() {
init();
dft(A, 1), dft(B, 1), dft(C, 1);
for (int i = 0; i < len; i++) {
ans[i] = (A[i] * A[i] * A[i] - A[i] * B[i] * 3.0 + 2.0 * C[i]) / 6.0;
ans[i] += (A[i] * A[i] - B[i]) / 2.0 + A[i];
}
dft(ans, -1);
for (int i = 0; i <= Ml; i++) ans[i] /= (double)len;
for (int i = 0; i <= Ml; i++) {
LL Ans = (LL)(ans[i].real() + 0.5);
if (Ans) printf("%d %lld
", i, Ans);
}
}
} // namespace Solve
int main() {
Input::main();
Solve::main();
}