题意:一棵 $n$ 个节点的树,有三种对 $(u,v)$ 路径上的点权的操作和一种询问。
操作1:将 $(u,v)$ 路径上的点的权置为 $w$。
操作2:将 $(u,v)$ 路径上的点的权加上 $w$。
操作3:将 $(u,v)$ 路径上的点的权乘上 $w$。
询问:求 $(u,v)$ 路径上的点的权的立方和(结果对 $10^9+7$ 取模)。
题解:裸树剖+线段树。
维护立方和:设区间内各点权为 $a$,区间和、平方和、立方和分别为 $\sum a,\ \sum a^2,\ \sum a^3$。当区间内每个数都加上 $w$ 时,新的和为 $\sum a + len \times w$($len$ 为区间长度),新的平方和为 $\sum (a+w)^2$,新的立方和为 $\sum (a+w)^3$。
将平方和展开即 $\sum (a^2+2aw+w^2)$。若区间和、平方和、立方和分别用 $sum_1[rt],\ sum_2[rt],\ sum_3[rt]$ 维护(这些值暂时还没有加上 $w$),
那么新的平方和 $\sum (a+w)^2 = sum_2[rt] + 2 \times sum_1[rt] \times w + w^2 \times len$。
同理,立方和展开为 $\sum \left(a^3+3(a^2w+aw^2)+w^3\right)$,那么新的立方和 $\sum (a+w)^3 = sum_3[rt] + 3 \times (sum_2[rt] \times w + sum_1[rt] \times w^2) + w^3 \times len$。
所以这样就可以在线段树 pushdown 的时候维护立方和了。注意要先更新立方和,再更新平方和,最后更新区间和(因为立方和的更新要用到旧的平方和与区间和,平方和的更新要用到旧的区间和)。另外要注意懒标记之间的先后顺序:先处理赋值(置零)标记,再处理乘法标记,最后处理加法标记。
#include <bits/stdc++.h>
using namespace std;
// Debug helper: prints "<expression text> is <value>" and a newline to stderr.
// The argument is parenthesized so expressions containing operators with lower
// precedence than << (e.g. `a & b`) still stream correctly.
#define debug(x) cerr << #x << " is " << (x) << '\n';
// Shorthand for 64-bit arithmetic.
typedef long long LL;
// NOTE: this macro silently widens every later `int` in the file (function
// parameters, locals and the arrays below) to long long; that is why the entry
// point is declared as `signed main()`.
#define int long long
// Capacity of vertex-indexed arrays; segment-tree arrays use 4 * N nodes.
const int N = 1e5 + 5;
// Modulus applied to every stored sum and lazy tag.
const int P = 1e9 + 7;
// n: vertex count, q: operation count, opt/x/y/z: fields of the operation being
// processed, cas: running "Case #" counter (never reset between test cases).
int n, q, opt, x, y, z, cas;
// Segment tree over dfs positions; for node rt:
//   sum1/sum2/sum3 - sum of values, squares and cubes in the segment (mod P);
//   tag - 1 when the children must be reset to 0 before mul/add are pushed;
//   mul/add - pending multiplier and addend, applied after the reset.
LL sum1[N << 2], sum2[N << 2], sum3[N << 2], tag[N << 2], mul[N << 2], add[N << 2];
// Tree / heavy-light decomposition state:
//   cnt - number of edges stored in `edge`, tot - dfs position counter,
//   w - input weight of each vertex, nw - weight at each dfs position,
//   id - dfs position of a vertex, top - head vertex of its heavy chain,
//   dep - depth (root = 1), pre - parent (root = -1), son - heavy child (0 = none),
//   siz - subtree size, head - first edge index of a vertex's list (0 = none).
int cnt, tot, w[N], nw[N], id[N], top[N], dep[N], pre[N], son[N], siz[N], head[N];
// One directed edge of the adjacency list; edges leaving the same vertex are
// chained through `next`.
struct Graph {
    int v;    // vertex this edge points to
    int next; // index of the previously added edge from the same source (0 ends the list)
};
// Undirected tree: every edge is stored once per direction.
Graph edge[N << 1];
// Prepends the directed edge u -> v to u's adjacency list.
// Edge slots are 1-based, so index 0 can mark the end of a list.
void addedge(int u, int v) {
    auto &slot = edge[++cnt];
    slot.next = head[u];
    slot.v = v;
    head[u] = cnt;
}
void init() {
tot = cnt = 0;
memset(head, 0, sizeof head);
memset(siz, 0, sizeof siz);
memset(nw, 0, sizeof nw);
memset(top, 0, sizeof top);
memset(pre, 0, sizeof pre);
memset(son, 0, sizeof son);
memset(dep, 0, sizeof dep);
memset(id, 0, sizeof id);
}
// Child indices of node rt in the implicit (heap-layout) segment tree:
// 2*rt and 2*rt+1. Parenthesized so the expansion stays correct inside larger
// expressions such as `ls + 1` or `2 * rs`.
#define ls (rt << 1)
#define rs (rt << 1 | 1)
void pushup(int rt) {
sum1[rt] = (sum1[ls] + sum1[rs]) % P;
sum2[rt] = (sum2[ls] + sum2[rs]) % P;
sum3[rt] = (sum3[ls] + sum3[rs]) % P;
}
// Applies "multiply every value in node rt's segment by z".
// The sums of values, squares and cubes scale by z, z^2 and z^3. The pending
// tags compose as (x * mul + add) * z = x * (mul * z) + add * z.
void downMul(int rt, int z) {
    auto &cube = sum3[rt];
    auto &square = sum2[rt];
    auto &plain = sum1[rt];
    cube = cube * z % P * z % P * z % P;
    square = square * z % P * z % P;
    plain = plain * z % P;
    mul[rt] = mul[rt] * z % P;
    add[rt] = add[rt] * z % P;
}
void downAdd(int rt, int len, int z) {
sum3[rt] = (1LL * sum3[rt] % P + 3LL * (sum2[rt] * z % P + 1LL * sum1[rt] * z % P * z % P) % P + 1LL * len * z % P * z % P * z % P) % P;
sum2[rt] = (1LL * sum2[rt] % P + 2LL * sum1[rt] % P * z % P + 1LL * len * z % P * z % P) % P;
sum1[rt] = (sum1[rt] + 1LL * len * z % P) % P;
add[rt] = (add[rt] + z) % P;
}
void pushdown(int rt, int l, int r, int mid) {
if (tag[rt] == 1) {
sum1[ls] = sum2[ls] = sum3[ls] = 0;
sum1[rs] = sum2[rs] = sum3[rs] = 0;
mul[ls] = mul[rs] = 1;
add[ls] = add[rs] = 0;
tag[ls] = tag[rs] = 1;
tag[rt] = 0;
}
if (mul[rt] != 1) {
downMul(ls, mul[rt]);
downMul(rs, mul[rt]);
mul[rt] = 1;
}
if (add[rt] != 0) {
downAdd(ls, mid - l + 1, add[rt]);
downAdd(rs, r - mid, add[rt]);
add[rt] = 0;
}
}
// Builds the segment tree for node rt over dfs positions [l, r] from nw[],
// clearing all lazy tags. Each leaf stores its value, square and cube mod P.
void build(int rt, int l, int r) {
    add[rt] = tag[rt] = 0;
    mul[rt] = 1;
    if (l == r) {
        // Reduce into [0, P) BEFORE squaring. Previously only sum1 was reduced,
        // so large weights could overflow nw*nw and negative weights produced
        // mixed-sign sums.
        const auto v = (nw[l] % P + P) % P;
        sum1[rt] = v;
        sum2[rt] = v * v % P;
        sum3[rt] = sum2[rt] * v % P;
        return;
    }
    int mid = (l + r) >> 1;
    build(ls, l, mid);
    build(rs, mid + 1, r);
    pushup(rt);
}
void updateAdd(int rt, int l, int r, int x, int y, int z) {
if (x <= l && r <= y) {
sum3[rt] = (1LL * sum3[rt] % P + 3LL * (sum2[rt] * z % P + 1LL * sum1[rt] * z % P * z % P) % P + 1LL * (r - l + 1) * z % P * z % P * z % P) % P;
sum2[rt] = (1LL * sum2[rt] % P + 2LL * sum1[rt] % P * z % P + 1LL * (r - l + 1) * z % P * z % P) % P;
sum1[rt] = (sum1[rt] + 1LL * (r - l + 1) * z % P) % P;
add[rt] = (add[rt] + z) % P;
return ;
}
int mid = l + r >> 1;
pushdown(rt, l, r, mid);
if (x <= mid) updateAdd(ls, l, mid, x, y, z);
if (y > mid) updateAdd(rs, mid + 1, r, x, y, z);
pushup(rt);
}
void updateMul(int rt, int l, int r, int x, int y, int z) {
if (x <= l && r <= y) {
sum3[rt] = sum3[rt] * z % P * z % P * z % P;
sum2[rt] = sum2[rt] * z % P * z % P;
sum1[rt] = sum1[rt] * z % P;
mul[rt] = mul[rt] * z % P;
add[rt] = add[rt] * z % P;
return ;
}
int mid = l + r >> 1;
pushdown(rt, l, r, mid);
if (x <= mid) updateMul(ls, l, mid, x, y, z);
if (y > mid) updateMul(rs, mid + 1, r, x, y, z);
pushup(rt);
}
// Sets every value at positions [x, y] inside node rt's range [l, r] to z.
// The assignment is stored lazily as "reset children to 0 (tag = 1), multiply
// by 1, then add z", which pushdown replays on the children.
void updateTag(int rt, int l, int r, int x, int y, int z) {
    // Canonical residue in [0, P): keeps every product below 2^63 and all
    // stored sums non-negative, even for negative or very large z.
    z = (z % P + P) % P;
    if (x <= l && r <= y) {
        const long long len = r - l + 1;
        sum1[rt] = len * z % P;
        sum2[rt] = len * z % P * z % P;
        sum3[rt] = len * z % P * z % P * z % P;
        tag[rt] = 1, mul[rt] = 1, add[rt] = z;
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(rt, l, r, mid);
    if (x <= mid) updateTag(ls, l, mid, x, y, z);
    if (y > mid) updateTag(rs, mid + 1, r, x, y, z);
    pushup(rt);
}
// Returns the cube sum (mod P) over positions [x, y] within node rt's range [l, r].
LL query(int rt, int l, int r, int x, int y) {
    if (x <= l && r <= y) return sum3[rt];
    const int mid = (l + r) >> 1;
    pushdown(rt, l, r, mid);
    LL total = 0;
    if (x <= mid) total += query(ls, l, mid, x, y);
    if (y > mid) total += query(rs, mid + 1, r, x, y);
    return total % P;
}
// First heavy-light pass over u's subtree (f = parent, d = depth).
// It records parent, depth and subtree size, and picks u's heavy child: the
// first child, in adjacency order, with the largest subtree.
void dfs1(int u, int f, int d) {
    siz[u] = 1;
    dep[u] = d;
    pre[u] = f;
    int heaviest = -1;
    for (int e = head[u]; e != 0; e = edge[e].next) {
        const int child = edge[e].v;
        if (child != f) {
            dfs1(child, u, d + 1);
            siz[u] += siz[child];
            if (heaviest < siz[child]) {
                heaviest = siz[child];
                son[u] = child;
            }
        }
    }
}
// Second heavy-light pass: gives u the next dfs position, copies its weight to
// nw[] at that position, and records topf as the head of u's chain. The heavy
// child is visited first, so each chain gets consecutive positions.
void dfs2(int u, int topf) {
    top[u] = topf;
    id[u] = ++tot;
    nw[tot] = w[u];
    if (son[u] == 0) return;
    dfs2(son[u], topf);  // the heavy child continues the current chain
    for (int e = head[u]; e != 0; e = edge[e].next) {
        const int child = edge[e].v;
        if (child != pre[u] && child != son[u]) dfs2(child, child);  // a light child starts a new chain
    }
}
// Heavy-light path walk shared by all path operations: calls f(lo, hi) once per
// range of dfs positions [lo, hi]; together these ranges cover the path u - v.
template <class F>
void forEachPathSegment(int u, int v, F f) {
    while (top[u] != top[v]) {
        // Lift the endpoint whose chain head is deeper, consuming that chain segment.
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        f(id[top[u]], id[u]);
        u = pre[top[u]];
    }
    // Both endpoints now share a chain; the shallower one has the smaller position.
    if (dep[u] > dep[v]) swap(u, v);
    f(id[u], id[v]);
}

// Adds z to every vertex weight on the path u - v.
void treeAdd(int u, int v, int z) {
    forEachPathSegment(u, v, [&](int lo, int hi) { updateAdd(1, 1, n, lo, hi, z); });
}

// Multiplies every vertex weight on the path u - v by z.
void treeMul(int u, int v, int z) {
    forEachPathSegment(u, v, [&](int lo, int hi) { updateMul(1, 1, n, lo, hi, z); });
}

// Sets every vertex weight on the path u - v to z.
void treeTag(int u, int v, int z) {
    forEachPathSegment(u, v, [&](int lo, int hi) { updateTag(1, 1, n, lo, hi, z); });
}

// Returns the sum of cubed vertex weights on the path u - v, mod P.
LL TreeQuery(int u, int v) {
    LL ans = 0;
    forEachPathSegment(u, v, [&](int lo, int hi) { ans = (ans + query(1, 1, n, lo, hi)) % P; });
    return ans;
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int T;
cin >> T;
while (T--) {
cin >> n;
init();
for (int i = 1, u, v; i <= n - 1; i++) {
cin >> u >> v;
addedge(u, v), addedge(v, u);
}
for (int i = 1; i <= n; i++) cin >> w[i];
dfs1(1, -1, 1);
dfs2(1, 1);
build(1, 1, n);
cin >> q;
cout << "Case #" << ++cas << ":" << '
';
while (q--) {
cin >> opt;
if (opt == 1) {
cin >> x >> y >> z;
treeTag(x, y, z);
} else if (opt == 2) {
cin >> x >> y >> z;
treeAdd(x, y, z);
} else if (opt == 3) {
cin >> x >> y >> z;
treeMul(x, y, z);
} else if (opt == 4) {
cin >> x >> y;
cout << TreeQuery(x, y) << '
';
}
}
}
return 0;
}