HDU-4609 3-idiots FFT
题意
给定(n) 根木棍,每根木棍有一长度(a[i]) ,问任意选出三根木棍,可以组成三角形的概率
[T leq 100,3leq n leq 10^5 ,1leq a_i leq 10^5
]
分析
此题属于(FFT)入门模板题
这题刚开始看有点像是生成函数(好像就是
如果暴力枚举两根可能组成的长度,显然(n^2) 的复杂度不够
我们可以把问题抽象成两个多项式相乘,最后乘出来的指数或者说第几项就可以表示出能够组成的长度,系数就表示个数。
事实上就是在求卷积,也就是(FFT) 的工作。
注意实现上要删掉两个相同的组成,以及对剩下的除以2
特别注意(FFT) 传入的(len) 必须是(2^k)
代码
struct Complex {
double x, y;
Complex(double _x = 0.0, double _y = 0.0) {
x = _x;
y = _y;
}
Complex operator - (const Complex& b) const {
return Complex(x - b.x, y - b.y);
}
Complex operator + (const Complex& b) const {
return Complex(x + b.x, y + b.y);
}
Complex operator * (const Complex& b) const {
return Complex(x * b.x - b.y * y, x * b.y + y * b.x);
}
};
void change(Complex y[], int len) {
int k;
for (int i = 1, j = len / 2; i < len - 1; i++) {
if (i < j) swap(y[i], y[j]);
k = len / 2;
while (j >= k) {
j -= k;
k /= 2;
}
if (j < k) j += k;
}
}
void fft(Complex y[], int len, int on) {
change(y, len);
for (int h = 2; h <= len; h <<= 1) {
Complex wn(cos(-on * 2 * PI / h), sin(-on * 2 * PI / h));
for (int j = 0; j < len; j += h) {
Complex w(1, 0);
for (int k = j; k < j + h / 2; k++) {
Complex u = y[k];
Complex t = w * y[k + h / 2];
y[k] = u + t;
y[k + h / 2] = u - t;
w = w * wn;
}
}
}
if (on == -1) {
for (int i = 0; i < len; i++)
y[i].x /= len;
}
}
Complex x[maxn << 2];
ll num[maxn << 2];
ll sum[maxn << 2];
int a[maxn];
int main() {
int T = readint();
int n;
while (T--) {
n = readint();
memset(num, 0, sizeof num);
for (int i = 0; i < n; i++) a[i] = readint(), num[a[i]]++;
sort(a, a + n);
int len1 = a[n - 1] + 1;
int len = 1;
while (len < 2 * len1) len <<= 1;
for (int i = 0; i < len1; i++) x[i] = Complex(num[i], 0);
for (int i = len1; i < len; i++) x[i] = Complex(0, 0);
fft(x, len, 1);
for (int i = 0; i < len; i++) x[i] = x[i] * x[i];
fft(x, len, -1);
for (int i = 0; i < len; i++) num[i] = (ll)(x[i].x + 0.5);
len = 2 * a[n - 1];
for (int i = 0; i < n; i++) num[a[i] + a[i]]--;
for (int i = 1; i <= len; i++) num[i] /= 2;
sum[0] = 0;
for (int i = 1; i <= len; i++) sum[i] = sum[i - 1] + num[i];
ll cnt = 0;
for (int i = 0; i < n; i++) cnt += sum[a[i]];
ll tot = (ll)n * (n - 1) * (n - 2) / 6;
printf("%.7f
", 1.0 - 1.0 * cnt / tot);
}
}