题目大意:
给定长度为 (n) 的数组 (S),你需要统计有多少个四元组 ((a,b,c,d)) 满足:(1leq a<bleq n;1leq c<dleq n;S_a<S_b;S_c>S_d),且 (a,b,c,d) 互不相等。
正文:
利用容斥的思想,把所有 (S_a<S_b,S_c>S_d) 的情况全部求出来,但是有些情况是不合法的,如 (a=c,a=d,b=c,b=d),将这些方法减去。关于如何把 (S_a<S_b,S_c>S_d) 的情况全部求出来,可以用树状数组实现,在此之前还要离散化。
代码:
struct node
{
ll val, i;
}s[N];
ll c[N], pos[N], tot;
ll ls, rs, l[N], sl[N], r[N], sr[N], ans;
void add(ll x)
{
for (; x <= n; x += x & -x) c[x]++;
}
ll ask(ll x)
{
ll ans = 0;
for (; x; x -= x & -x) ans += c[x];
return ans;
}
bool cmp (node a, node b)
{
return a.val > b.val;
}
int main()
{
scanf ("%lld", &n);
for (int i = 1; i <= n; i++)
scanf ("%lld", &s[i].val), s[i].i = i;
s[0].val = -1;
sort (s + 1, s + 1 + n, cmp);
for (int i = 1; i <= n; i++)
if(s[i].val != s[i - 1].val)
pos[s[i].i] = ++tot;
else pos[s[i].i] = tot;
for (int i = 1; i <= n; i++)
{
sl[i] = ask(pos[i] - 1);
l[i] = ask(n) - ask(pos[i]);
ls += l[i];
add(pos[i]);
}
memset(c, 0, sizeof c);
for (int i = n; i >= 1; i--)
{
sr[i] = ask(pos[i] - 1);
r[i] = ask(n) - ask(pos[i]);
rs += r[i];
add(pos[i]);
}
ans = ls * rs;
for (int i = 1; i <= n; i++)
ans -= l[i] * r[i] + sl[i] * sr[i] + l[i] * sl[i] + r[i] * sr[i];
printf("%lld", ans);
return 0;
}