Description
Given (N) integers in the range ([-50\, 000, 50\, 000]), how many ways are there to pick three integers (a_ i), (a_ j), (a_ k), such that (i), (j), (k) are pairwise distinct and (a_ i + a_ j = a_ k)? Two ways are different if their ordered triples ((i, j, k)) of indices are different.
Input
The first line of input consists of a single integer (N) ((1 leq N leq 200\, 000)). The next line consists of (N) space-separated integers (a_1, a_2, dots , a_ N).
Output
Output an integer representing the number of ways.
Sample Input 1
4
1 2 3 4
Sample Output 1
4
Sample Input 2
6
1 1 3 3 4 6
Sample Output 2
10
题意
给出n个数字,范围从([-50\, 000, 50\, 000]),问有多少组不同的((i, j, k))满足(a_ i + a_ j = a_ k)
题解
在多项式中,两项相乘,指数相加,所以我们可以用fft来解决此题.
首先,不考虑i,j,k必须不同,那么对于每个数,我们先加上50000使其变为正数,并把这一项指数对应的系数+1,这样我们对这个数组做一次卷积,遍历一遍原数组即可得到答案.
由于(i,j,k)必须不同,我们要减去自己和自己对答案产生贡献的,开一个数组(b),(b[(x+M)*2])也就是在计算答案时要被减去的自己和自己相加的.
由于可能存在0,我们还要考虑0的影响,设0的个数为cnt0个,那么0+0=0和0+ai=ai都被重复计算了,总共多计算了(2*cnt0*(cnt0-1)+2*cnt0*(n-cnt0))次,(第一个乘以2表示等号右边可以选择两个0中任意一个,第二个乘以2表示ai和0可以交换)所以要减去(2*cnt0*(n-1))
代码
#include <bits/stdc++.h>
using namespace std;
const double pi = acos(-1.0);
const int N = 1e6 + 50;
typedef long long ll;
struct cp {
double r, i;
cp(double r = 0, double i = 0): r(r), i(i) {}
cp operator + (const cp &b) {
return cp(r + b.r, i + b.i);
}
cp operator - (const cp &b) {
return cp(r - b.r, i - b.i);
}
cp operator * (const cp &b) {
return cp(r * b.r - i * b.i, r * b.i + i * b.r);
}
};
void change(cp a[], int len) {
for (int i = 1, j = len / 2; i < len - 1; i++) {
if (i < j) swap(a[i], a[j]);
int k = len / 2;
while (j >= k) {
j -= k;
k /= 2;
}
if (j < k) j += k;
}
}
void fft(cp a[], int len, int op) {
change(a, len);
for (int h = 2; h <= len; h <<= 1) {
cp wn(cos(-op * 2 * pi / h), sin(-op * 2 * pi / h));
for (int j = 0; j < len; j += h) {
cp w(1, 0);
for (int k = j; k < j + h / 2; k++) {
cp u = a[k];
cp t = w * a[k + h / 2];
a[k] = u + t;
a[k + h / 2] = u - t;
w = w * wn;
}
}
}
if (op == -1) {
for (int i = 0; i < len; i++) {
a[i].r /= len;
}
}
}
const int M = 50000;
ll num[N];
ll cnt[N];
cp a[N], b[N];
ll ans[N];
int main() {
int n;
scanf("%d", &n);
ll cnt0 = 0;
ll len1 = 0;
for (int i = 0; i < n; i++) {
scanf("%lld", &num[i]);
if (num[i] == 0) cnt0++;
cnt[num[i] + M]++;
len1 = max(len1, num[i] + M + 1);
}
//printf("%lld
", len1);
ll len = 1;
while (len < 2 * len1) len <<= 1;//必须补两倍
//printf("%lld
", len);
for (int i = 0; i < len1; i++) {
a[i] = cp(cnt[i], 0);
}
for (int i = len1; i < len; i++) {
a[i] = cp(0, 0);
}
fft(a, len, 1);
for (int i = 0; i < len; i++) {
a[i] = a[i] * a[i];
}
fft(a, len, -1);
for (int i = 0; i < len; i++) {
ans[i] = (ll)(a[i].r + 0.5);
}
for (int i = 0; i < n; i++) {//删去自己和自己的
ans[(num[i] + M) * 2]--;
}
ll res = 0;
for (int i = 0; i < n; i++) {
res += ans[num[i] + M * 2];
}
res -= 2 * cnt0 * (n - 1);
printf("%lld
", res);
return 0;
}