题意:给定线段树,上面若干个节点坏了,求能表示出多少区间。
区间能被表示出当且仅当拆出来的log个节点都是好的。
解:每个区间在最浅的节点处计算答案。
对于每个节点维护从左边过来能有多少区间,从右边过来能有多少区间即可。
说起来水的一批为什么会被评成黑题啊。
1 #include <bits/stdc++.h> 2 3 typedef long long LL; 4 const int N = 20000010; 5 const LL MO = 998244353; 6 7 LL lc[N], rc[N], ans; 8 int tot, ls[N], rs[N]; 9 bool vis[N]; 10 11 void add(LL L, LL R, LL l, LL r, int &o) { 12 //printf("%lld %lld %lld %lld %d ", L, R, l, r, o); 13 if(!o) o = ++tot; 14 if(L <= l && r <= R) { 15 vis[o] = 1; 16 return; 17 } 18 LL mid = (l + r) >> 1; 19 if((r - l + 1) & 1) mid--; 20 if(L <= mid) add(L, R, l, mid, ls[o]); 21 if(mid < R) add(L, R, mid + 1, r, rs[o]); 22 } 23 24 void cal(LL l, LL r, int &o) { 25 if(!o) { 26 o = ++tot; 27 LL len = (r - l + 1) % MO; 28 ans = (ans + len * (len + 1) / 2) % MO; 29 //printf("[%lld %lld] ans += %lld = %lld ", l, r, len * (len + 1) / 2, ans); 30 lc[o] = rc[o] = (len - 1 + MO) % MO; 31 return; 32 } 33 if(l == r) { 34 ans += (!vis[o]); 35 //if(!vis[o]) printf("[%lld %lld] ans++ = %lld ", l, r, ans); 36 return; 37 } 38 LL mid = (l + r) >> 1; 39 if((r - l + 1) & 1) mid--; 40 cal(l, mid, ls[o]); 41 cal(mid + 1, r, rs[o]); 42 /// cal 43 ans = (ans + lc[rs[o]] * rc[ls[o]] % MO) % MO; 44 //printf("[%lld %lld] ans += %lld * %lld = %lld ", l, r, lc[rs[o]], rc[ls[o]], ans); 45 if(!vis[ls[o]]) { 46 ans = (ans + lc[rs[o]]) % MO; 47 //printf("[%lld %lld] ans += %lld = %lld ", l, r, lc[rs[o]], ans); 48 } 49 if(!vis[rs[o]]) { 50 ans = (ans + rc[ls[o]]) % MO; 51 //printf("[%lld %lld] ans += %lld = %lld ", l, r, rc[ls[o]], ans); 52 } 53 if(!vis[o]) ans++; 54 lc[o] = lc[ls[o]]; 55 if(!vis[ls[o]]) { 56 lc[o] = (lc[o] + lc[rs[o]] + 1) % MO; 57 } 58 rc[o] = rc[rs[o]]; 59 if(!vis[rs[o]]) { 60 rc[o] = (rc[o] + rc[ls[o]] + 1) % MO; 61 } 62 //printf("l = %lld r = %lld lc = %lld rc = %lld ", l, r, lc[o], rc[o]); 63 return; 64 } 65 66 int main() { 67 68 //printf("%d", sizeof(lc) * 3 / 1048576); 69 70 LL n; int m, root = 0; 71 scanf("%lld%d", &n, &m); 72 for(int i = 1; i <= m; i++) { 73 LL l, r; 74 scanf("%lld%lld", &l, &r); 75 add(l, r, 1ll, n, root); 76 } 77 cal(1, n, root); 78 printf("%lld ", ans); 79 return 0; 80 }