[BZOJ4765]普通计算姬
试题描述
"奋战三星期,造台计算机"。小G响应号召,花了三小时造了台普通计算姬。普通计算姬比普通计算机要厉害一些。普通计算机能计算数列区间和,而普通计算姬能计算树中子树和。更具体地,小G的计算姬可以解决这么个问题:给定一棵n个节点的带权树,节点编号为1到n,以root为根,设sum[p]表示以点p为根的这棵子树中所有节点的权值和。计算姬支持下列两种操作:
1 给定两个整数u,v,修改点u的权值为v。
2 给定两个整数l,r,计算sum[l]+sum[l+1]+....+sum[r-1]+sum[r]
尽管计算姬可以很快完成这个问题,可是小G并不知道它的答案是否正确,你能帮助他吗?
输入
第一行两个整数n,m,表示树的节点数与操作次数。
接下来一行n个整数,第i个整数di表示点i的初始权值。
接下来n行每行两个整数ai,bi,表示一条树上的边,若ai=0则说明bi是根。
接下来m行每行三个整数,第一个整数op表示操作类型。
若op=1则接下来两个整数u,v表示将点u的权值修改为v。
若op=2则接下来两个整数l,r表示询问。
N<=10^5,M<=10^5
0<=Di,V<2^31,1<=L<=R<=N,1<=U<=N
输出
对每个操作类型2输出一行一个整数表示答案。
输入示例
6 4 0 0 3 4 0 1 0 1 1 2 2 3 2 4 3 5 5 6 2 1 2 1 1 1 2 3 6 2 3 5
输出示例
16 10 9
数据规模及约定
见“输入”
题解
做法一:
我们先把 sum 数组求出来,然后节点 u 权值的修改对应着 u 到根节点这一条链的修改,于是可以树链剖分套数据结构完成这个操作;对于询问,它问的却是连续的一段编号;所以可以看成一个二维平面,每个节点 u 的坐标是 (dfs[u], u)(dfs[u] 表示节点 u 的树链剖分序),权值就是 sum[u];那么对于修改操作,就是 x 轴上连续的 log(n) 段,y 坐标没有限制;对于询问操作就是 y 轴上连续一段, x 坐标没限制;所以这个东西就可以用 kd 树维护了。
#include <iostream> #include <cstdio> #include <cstring> #include <cstdlib> #include <cctype> #include <algorithm> using namespace std; const int BufferSize = 1 << 16; char buffer[BufferSize], *Head, *Tail; inline char Getchar() { if(Head == Tail) { int l = fread(buffer, 1, BufferSize, stdin); Tail = (Head = buffer) + l; } return *Head++; } int read() { int x = 0, f = 1; char c = Getchar(); while(!isdigit(c)){ if(c == '-') f = -1; c = Getchar(); } while(isdigit(c)){ x = x * 10 + c - '0'; c = Getchar(); } return x * f; } #define maxn 100010 #define maxm 200010 #define UL unsigned long long #define LL long long int n, m, head[maxn], nxt[maxm], to[maxm], val[maxn]; void AddEdge(int a, int b) { to[++m] = b; nxt[m] = head[a]; head[a] = m; swap(a, b); to[++m] = b; nxt[m] = head[a]; head[a] = m; return ; } int rt, fa[maxn], siz[maxn], son[maxn], top[maxn], pos[maxn], ToT; UL sum[maxn]; void build(int u) { siz[u] = 1; sum[u] = val[u]; for(int e = head[u]; e; e = nxt[e]) if(to[e] != fa[u]) { fa[to[e]] = u; build(to[e]); siz[u] += siz[to[e]]; sum[u] += sum[to[e]]; if(!son[u] || siz[son[u]] < siz[to[e]]) son[u] = to[e]; } return ; } void gett(int u, int tp) { top[u] = tp; pos[u] = ++ToT; if(son[u]) gett(son[u], tp); for(int e = head[u]; e; e = nxt[e]) if(to[e] != fa[u] && to[e] != son[u]) gett(to[e], to[e]); return ; } int Rt, ch[maxn][2]; bool Cur; struct Node { int x[2], mx[2], mn[2], siz; UL sum; LL val, add; Node() {} Node(int x, int y, LL val): val(val), add(0) { this->x[0] = x; this->x[1] = y; } bool operator < (const Node& t) const { return x[Cur] != t.x[Cur] ? x[Cur] < t.x[Cur] : x[Cur^1] < t.x[Cur^1]; } } ns[maxn]; void maintain(int o) { for(int j = 0; j < 2; j++) ns[o].mx[j] = ns[o].mn[j] = ns[o].x[j]; ns[o].sum = ns[o].val; ns[o].siz = 1; for(int i = 0; i < 2; i++) if(ch[o][i]) { for(int j = 0; j < 2; j++) ns[o].mx[j] = max(ns[o].mx[j], ns[ch[o][i]].mx[j]), ns[o].mn[j] = min(ns[o].mn[j], ns[ch[o][i]].mn[j]); ns[o].sum += ns[ch[o][i]].sum; ns[o].siz += ns[ch[o][i]].siz; } ns[o].val += ns[o].add; ns[o].sum += ns[o].add * ns[o].siz; return ; } void build(int& o, int l, int r, bool cur) { if(l > r) return ; int mid = l + r >> 1; o = mid; Cur = cur; nth_element(ns + l, ns + mid, ns + r + 1); build(ch[o][0], l, mid - 1, cur ^ 1); build(ch[o][1], mid + 1, r, cur ^ 1); return maintain(o); } void pushdown(int o) { if(!ns[o].add) return ; for(int i = 0; i < 2; i++) if(ch[o][i]) { ns[ch[o][i]].add += ns[o].add; ns[ch[o][i]].val += ns[o].add; ns[ch[o][i]].sum += ns[o].add * ns[ch[o][i]].siz; } ns[o].add = 0; return ; } void upd(int o, int l, int r, int add) { pushdown(o); if(l <= ns[o].mn[0] && ns[o].mx[0] <= r) { ns[o].add += add; ns[o].val += add; ns[o].sum += (LL)add * ns[o].siz; return ; } if(l <= ns[o].x[0] && ns[o].x[0] <= r) ns[o].val += add; for(int i = 0; i < 2; i++) if(ch[o][i] && l <= ns[ch[o][i]].mx[0] && ns[ch[o][i]].mn[0] <= r) upd(ch[o][i], l, r, add); return maintain(o); } UL que(int o, int l, int r) { pushdown(o); if(l <= ns[o].mn[1] && ns[o].mx[1] <= r) return ns[o].sum; UL ans = (l <= ns[o].x[1] && ns[o].x[1] <= r) ? ns[o].val : 0; for(int i = 0; i < 2; i++) if(ch[o][i] && l <= ns[ch[o][i]].mx[1] && ns[ch[o][i]].mn[1] <= r) ans += que(ch[o][i], l, r); return ans; } void update(int u, int add) { while(u) upd(Rt, pos[top[u]], pos[u], add), u = fa[top[u]]; return ; } #define maxol 2100000 char Output[maxol]; int num[21], cnt, cntol; int main() { // freopen("common10.in", "r", stdin); // freopen("data.out", "w", stdout); n = read(); int q = read(); for(int i = 1; i <= n; i++) val[i] = read(); for(int i = 1; i <= n; i++) { int a = read(), b = read(); if(!a) rt = b; else AddEdge(a, b); } build(rt); gett(rt, rt); for(int i = 1; i <= n; i++) ns[i] = Node(pos[i], i, sum[i]); build(Rt, 1, n, 0); while(q--) { int tp = read(); if(tp == 1) { int u = read(), v = read(); update(u, v - val[u]); val[u] = v; } else { int l = read(), r = read(); UL tmp = que(Rt, l, r); cnt = 0; while(tmp) num[cnt++] = tmp % 10, tmp /= 10; for(int i = cnt - 1; i >= 0; i--) Output[cntol++] = num[i] + '0'; Output[cntol++] = ' '; } } Output[--cntol] = ' '; puts(Output); return 0; }
然而加了读入输出优化还是 T 飞。。。。。
解法二:
分块套分块。
我们先搞一个 dfs 序列,序列上存 val[i](即节点 i 的权值)。然后我们对这个序列分块,并维护两个信息:dfS[i] 表示位置 i 所在块的前缀和,dfSb[i] 表示前 i 个块的总和(即块的前缀和)。这样我们就可以 O(1) 询问区间和,O(sqrt(n)) 点修改了(想一想,为什么)。
然后我们再对正常编号的序列分块,并维护两个信息:tot[i][j] 表示第 i 块中所有 sum 使得对应 dfs 序列上位置 j 被计算了几次,Sb[i] 表示第 i 块中 sum 的总和。那么借助 tot 我们可以 O(n · sqrt(n)) 预处理 Sb,还可以 O(sqrt(n)) 支持点修改(想一想,为什么)。查询 [l, r] 时,对于被整个覆盖的块 i 直接累加 Sb[i] 就好了,对于没有被整个覆盖的块我们暴力找到这些点对应的 dfs 序上的区间,然后 O(1) 询问区间和,累加,就好了。
#include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <cctype> #include <algorithm> #include <cmath> using namespace std; int read() { int x = 0, f = 1; char c = getchar(); while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); } while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); } return x * f; } #define maxn 100010 #define maxm 200010 #define maxb 320 #define UL unsigned long long #define LL long long int n, m, head[maxn], nxt[maxm], to[maxm]; void AddEdge(int a, int b) { to[++m] = b; nxt[m] = head[a]; head[a] = m; swap(a, b); to[++m] = b; nxt[m] = head[a]; head[a] = m; return ; } int rt, clo, dl[maxn], dr[maxn], id[maxn], val[maxn]; void build(int u, int fa) { dl[u] = ++clo; id[clo] = u; for(int e = head[u]; e; e = nxt[e]) if(to[e] != fa) build(to[e], u); dr[u] = clo; return ; } UL dfS[maxn], dfSb[maxb], Sb[maxb]; int tot[maxb][maxn], bl[maxn], st[maxn], en[maxn]; UL que(int l, int r) { return dfS[r] - (l > st[bl[l]] ? dfS[l-1] : 0) + dfSb[bl[r]-1] - dfSb[bl[l]-1]; } int main() { n = read(); int q = read(); for(int i = 1; i <= n; i++) val[i] = read(); for(int i = 1; i <= n; i++) { int a = read(), b = read(); if(!a) rt = b; else AddEdge(a, b); } build(rt, 0); int m = (int)sqrt(n); for(int i = 1; i <= n; i++) { bl[i] = (i - 1) / m + 1; if(!st[bl[i]]) st[bl[i]] = i; en[bl[i]] = i; dfS[i] = (bl[i-1] == bl[i] ? dfS[i-1] : 0) + (UL)val[id[i]]; if(bl[i] != bl[i-1]) dfSb[bl[i]] = dfSb[bl[i-1]]; dfSb[bl[i]] += val[id[i]]; } // for(int i = 1; i <= bl[n]; i++) printf("[%d, %d] ", st[i], en[i]); for(int i = 1; i <= bl[n]; i++) { for(int j = st[i]; j <= en[i]; j++) tot[i][dl[j]]++, tot[i][dr[j]+1]--; for(int j = 1; j <= n; j++) tot[i][j] += tot[i][j-1], Sb[i] += (UL)tot[i][j] * val[id[j]]; } while(q--) { int tp = read(); if(tp == 1) { int u = read(), v = read(), dv = v - val[u]; val[u] = v; for(int i = dl[u]; i <= en[bl[dl[u]]]; i++) dfS[i] += dv; for(int i = bl[dl[u]]; i <= bl[n]; i++) dfSb[i] += dv; for(int i = 1; i <= bl[n]; i++) Sb[i] += (UL)tot[i][dl[u]] * dv; } else { int l = read(), r = read(); UL ans = 0; if(bl[l] == bl[r]) for(int i = l; i <= r; i++) ans += que(dl[i], dr[i]); else { for(int i = bl[l] + 1; i < bl[r]; i++) ans += Sb[i]; for(int i = l; i <= en[bl[l]]; i++) ans += que(dl[i], dr[i]); for(int i = st[bl[r]]; i <= r; i++) ans += que(dl[i], dr[i]); } printf("%llu ", ans); } } return 0; }