显然可以想到$n^3$的暴力,就是枚举端点,然后计算区间内不同的颜色的数量。
考虑用set维护这个区间,然后增量插入,可以把复杂度降到$n^2log{n}$,但还是太慢。
枚举右端点r,考虑在log的复杂度内计算出$f(l,r),l in [1,r]$。
我们记pos=右端点的颜色上一次出现的位置+1,如果没有出现过则记为1,这一点可以通过链表O(n)求出来。
然后我们发现只有[pos,r]这个区间内的l的f(l,r)会较f(l,r-1)加一,区间维护一下平方再区间修改区间查询,这就是线段树裸题了。
注意要离散化否则无法开桶。
#include <bits/stdc++.h> #include <iostream> #include <cstdio> #include <algorithm> #include <set> using namespace std; typedef long long ll; const ll N = 1e6 + 10; const ll mod = 1e9 + 7; struct Segment{ ll sum, sq, add; }tr[N << 2]; ll head[N], pre[N]; ll n, a[N], b[N], nn; ll ans; ll read() { ll ret = 0, f = 1; char ch = getchar(); while ('0' > ch || ch > '9'){ if (ch == '-') f = -1; ch = getchar(); } while ('0' <= ch && ch <= '9') { ret = (ret << 1) + (ret << 3) + ch - '0'; ch = getchar(); } return ret * f; } void push_up(ll p) { tr[p].sum = tr[p << 1].sum + tr[p << 1 | 1].sum; tr[p].sq = tr[p << 1].sq + tr[p << 1 | 1].sq; } void push_down(ll p, ll l, ll r) { ll mid = (l + r) >> 1; if (tr[p].add) { tr[p << 1].add += tr[p].add; tr[p << 1].sq += tr[p].add * tr[p].add * (mid - l + 1) + 2 * tr[p << 1].sum * tr[p].add; tr[p << 1].sum += (mid - l + 1) * tr[p].add; tr[p << 1 | 1].add += tr[p].add; tr[p << 1 | 1].sq += tr[p].add * tr[p].add * (r - mid) + 2 * tr[p << 1 | 1].sum * tr[p].add; tr[p << 1 | 1].sum += (r - mid) * tr[p].add; tr[p].add = 0; } } void change(ll p, ll l, ll r, ll x, ll y) { if (r < x || l > y) return; if (x <= l && r <= y) { tr[p].add++; tr[p].sq += r - l + 1 + 2 * tr[p].sum; tr[p].sum += r - l + 1; return; } push_down(p, l, r); ll mid = (l + r) >> 1; change(p << 1, l, mid, x, y); change(p << 1 | 1, mid + 1, r, x, y); push_up(p); } int main() { // freopen("sequence.in", "r", stdin); // freopen("sequence.out", "w", stdout); n = read(); for (int i = 1; i <= n; i++) { a[i] = read(); b[++nn] = a[i]; } sort(b + 1, b + 1 + nn); nn = unique(b + 1, b + 1 + nn) - b - 1; for (int i = 1; i <= n; i++) { a[i] = lower_bound(b + 1, b + 1 + nn, a[i]) - b; } for (int i = 1; i <= n; i++) { pre[i] = head[a[i]]; head[a[i]] = i; } for (int i = 1; i <= n; i++) { ll pos = pre[i] + 1; change(1, 1, n, pos, i); ans = (ans + tr[1].sq) % mod; } cout << ans; return 0; }