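// 思路：对每个要被删的数，找出它左边和右边第一个比它小且不会被删的数的位置；然后按数值从小到大枚举要被删的数，每个数对答案的贡献是这两个位置之间（不含端点）当前仍未被删的数的个数。
// Approach: for each number to be deleted, find the positions of the nearest number on its left and on its right that is smaller than it and is never deleted. Then process the numbers to be deleted in increasing order of value; each one adds to the answer the count of numbers still present strictly between those two positions.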
// Reads n and k, then a sequence p[1..n] (assumed to be a permutation of 1..n),
// then k values that are kept; every position whose value is not kept is
// deleted. Deleted positions are processed in increasing order of p. Each one
// adds to the total the number of still-present positions strictly between its
// nearest smaller *kept* element on the left (or position 0) and on the right
// (or position n + 1), and is then marked as removed. Prints the total.
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

// Fenwick (binary indexed) tree of integer counts over positions 1..size.
class Fenwick {
public:
    // Creates a tree in which every position 1..size holds the value 1.
    // For an all-ones array, node i covers (i - lowbit(i), i], so its stored
    // sum is exactly lowbit(i); this builds in O(size) instead of doing
    // `size` separate point updates.
    explicit Fenwick(int size) : tree_(size + 1) {
        for (int i = 1; i <= size; ++i) tree_[i] = i & -i;
    }

    // Adds `delta` at position `pos` (1-based, must be >= 1).
    void add(int pos, int delta) {
        const int limit = static_cast<int>(tree_.size());
        for (; pos < limit; pos += pos & -pos) tree_[pos] += delta;
    }

    // Returns the sum over positions 1..pos (0 when pos <= 0).
    int prefix(int pos) const {
        int sum = 0;
        for (; pos > 0; pos -= pos & -pos) sum += tree_[pos];
        return sum;
    }

    // Returns the sum over positions lo..hi; an empty range (lo > hi) yields 0.
    int range(int lo, int hi) const {
        return lo > hi ? 0 : prefix(hi) - prefix(lo - 1);
    }

private:
    std::vector<int> tree_;
};

// For each position i in 1..n with keep_at[i] == 0, finds the nearest position
// j on one side of i (the left side when from_left is true, the right side
// otherwise) with keep_at[j] != 0 and p[j] < p[i]. Stores `none` when no such
// j exists. Entries for kept positions are left as `none`.
//
// Kept elements are pushed onto a stack of (value, position) pairs. Before a
// kept value is pushed, every entry with a larger value is popped: a smaller
// kept element closer to the scan front hides it from everything scanned
// later. Stack values are therefore non-decreasing from bottom to top, and the
// answer for a deleted element is the last entry whose value is smaller.
std::vector<int> nearest_smaller_kept(const std::vector<int>& p,
                                      const std::vector<char>& keep_at,
                                      int n, bool from_left, int none) {
    std::vector<int> result(n + 1, none);
    std::vector<std::pair<int, int>> stack;
    stack.reserve(n);

    for (int step = 0; step < n; ++step) {
        const int i = from_left ? step + 1 : n - step;
        const int value = p[i];
        if (keep_at[i]) {
            while (!stack.empty() && stack.back().first > value) stack.pop_back();
            stack.emplace_back(value, i);
        } else {
            // First entry with value >= p[i]; the entry just before it is the
            // last one with a strictly smaller value.
            const auto it = std::lower_bound(
                stack.begin(), stack.end(), value,
                [](const std::pair<int, int>& entry, int v) { return entry.first < v; });
            if (it != stack.begin()) result[i] = (it - 1)->second;
        }
    }
    return result;
}

}  // namespace

int main() {
    int n = 0;
    int k = 0;
    if (std::scanf("%d%d", &n, &k) != 2 || n < 0) return 1;

    std::vector<int> p(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        if (std::scanf("%d", &p[i]) != 1) return 1;
    }

    // kept_value[v] != 0 when value v is one of the k values to keep.
    // Out-of-range values are ignored instead of indexing past the array.
    std::vector<char> kept_value(n + 1, 0);
    for (int i = 0; i < k; ++i) {
        int b = 0;
        if (std::scanf("%d", &b) != 1) return 1;
        if (b >= 1 && b <= n) kept_value[b] = 1;
    }

    // keep_at[i] != 0 when the element at position i is kept.
    std::vector<char> keep_at(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        keep_at[i] = (p[i] >= 1 && p[i] <= n) ? kept_value[p[i]] : 0;
    }

    // Boundaries: position 0 on the left and n + 1 on the right when no
    // smaller kept element exists on that side.
    const std::vector<int> left = nearest_smaller_kept(p, keep_at, n, true, 0);
    const std::vector<int> right = nearest_smaller_kept(p, keep_at, n, false, n + 1);

    // Deleted positions, ordered by increasing value.
    std::vector<int> order;
    order.reserve(n);
    for (int i = 1; i <= n; ++i) {
        if (!keep_at[i]) order.push_back(i);
    }
    std::sort(order.begin(), order.end(), [&p](int a, int b) { return p[a] < p[b]; });

    // `present` counts the positions not yet removed; all start present.
    Fenwick present(n);
    long long total = 0;  // can reach ~n^2 / 2, so 64-bit
    for (const int pos : order) {
        total += present.range(left[pos] + 1, right[pos] - 1);
        present.add(pos, -1);
    }

    std::printf("%lld ", total);
    return 0;
}