题目:传送门
题意
给出一个长度为 n 的排列 ai
规定一个区间 [l,r] 是 fair 的,当且仅当区间最小值等于 l,最大值等于 r;
求 fair 区间的个数。
1 <= n <= 1e6
思路
对于每个 i,若 ai <= i,则表示 i 这个点可以作为某些 fair 区间的右端点,那么就找 l[i],表示 i 的左边第一个满足 a[idx] > i 的 idx + 1;
对于每个 i,若 ai >= i,则表示 i 这个点可以作为某些 fair 区间的左端点,那么就找 r[i],表示 i 的右边第一个满足 a[idx] < i 的 idx - 1;
那么对于 i 这个点来说,它能作为 fair 区间的左端点的合法区间就是 [i, r[i]]; 能作为 fair 区间的右端点的合法区间就是 [l[i], i];
那我们可以先把能作为 fair 区间的左端点的合法区间存起来;
假设 i 这个点能作为 fair 区间的左端点,且合法区间为 [l, r],那我们可以存 { l, i, 1 } 和 { r, i, -1 } 这两个三元组,按照第一元素从小到大排序
然后我们可以枚举 i:1~n,计算 i 这个点作为 fair 区间的右端点时,满足条件的左端点的个数,累加起来就是答案了。
#include <bits/stdc++.h> #define LL long long #define rep(i, j, k) for(int i = j; i <= k; i++) #define dep(i, j, k) for(int i = k; i >= j; i--) #define pb push_back #define make make_pair #define lb(x) (x & (-x)) using namespace std; const int N = 1e6 + 5; int n, a[N], pos[N], l[N], r[N], t[N], cnt; struct note { int val, id, x; }tmp[N * 2]; bool cmp(note a, note b) { return a.val < b.val; } void add(int pos, int x) { for(int i = pos; i <= n; i += lb(i)) t[i] += x; } int sum(int pos) { int res = 0; for(int i = pos; i ; i -= lb(i)) { res += t[i]; } return res; } void solve() { scanf("%d", &n); rep(i, 1, n) scanf("%d", &a[i]), pos[a[i]] = i; /// l[i] 左边第一个比 i 大的 a[idx] 的上一位,即 l[i] = idx + 1 rep(i, 1, n) { l[i] = i; if(a[i] <= i) while(l[i] > 1 && a[l[i] - 1] <= i) l[i] = l[l[i] - 1]; /// 因为要以 i 为右端点 } /// r[i] 右边第一个比 i 小的 a[idx] 的后一位,即 r[i] = idx - 1; dep(i, 1, n) { r[i] = i; if(a[i] >= i) while(r[i] < n && a[r[i] + 1] >= i) r[i] = r[r[i] + 1]; } rep(i, 1, n) { /// 把能作为 fair 区间左端点的合法区间存起来 if(pos[i] >= i && l[i] <= pos[i] && pos[i] <= r[i]) { tmp[++cnt] = {pos[i], i, 1}; tmp[++cnt] = {r[i] + 1, i, -1}; } } sort(tmp + 1, tmp + 1 + cnt, cmp); /// 按第一元素从小到大排序; int now = 1; LL ans = 0LL; rep(i, 1, n) { while(now <= cnt && tmp[now].val <= i) { /// tmp[now].val > i,则就不能作为 以i这个点为右端点的fair区间的左端点了 add(tmp[now].id, tmp[now].x); now++; } if(l[i] <= pos[i] && pos[i] <= i) { /// 能作为 fair 区间的右端点; /// 那就算满足的左端点,若左端点 > pos[i] 那 pos[i] 没被选上,当然不行,故就算区间 [l[i], pos[i]] 的左端点的个数即为满足条件的左端点的个数 ans += sum(pos[i]) - sum(l[i] - 1); } } printf("%lld ", ans); } int main() { // int _; scanf("%d", &_); // while(_--) solve(); solve(); return 0; }