题目大意:
一棵无根树,定义度数为1点为叶子节点,求所有两个叶子之间距离的平方和,树边上有边权。
样例:
4 1 4 1 4 3 1 2 4 1 12 4 1 2 3 1 4 2 4 3 1 36 5 1 2 1 1 3 4 2 4 3 2 5 2 138 10 1 2 10 10 2 7 3 2 8 3 9 3 9 8 2 7 9 1 6 4 3 4 5 2 3 4 4 4709
(看到类似的题解才有的思路)
其实这道题的难点在于平方和,如果去掉和这道题就很简单了,换根dp就行了。
然后发现其实平方和拆开来也可以换根。
先不考虑叶子,我们直接求每个点到所有叶子的距离平方和。
假设$dis[x]$表示$x$到当前点处理到的点的距离。
那么当前答案为$sum dis[i]^2$
往一个子树跑的时候,子树内的$dis$会减去边长,子树外的$dis$会加上边长。
于是和式变为$sum_{i in tree}(dis[i] - edge[i])^2 + sum _{i not in tree}(dis[i] + edge[i]) ^ 2$
把这个式子拆开,就变成$sum (dis[i])^2 - (2 * edge[i] * sum_{i in tree}dis[i]) + (2 * edge[i] * sum _{i not in tree}dis[i]) + (edge[i] * edge[i] * totcnt)$
$sum (dis[i])^2 $是当前节点的答案,$sum_{i in tree}dis[i], sum _{i not in tree}dis[i]$都可以通过预处理预处理出来。
然后碰到当前点是叶子点直接统计答案就可以了。
#include <bits/stdc++.h> #define int long long #define Mid ((l + r) / 2) #define lson (rt << 1) #define rson (rt << 1 | 1) using namespace std; int read() { char c; int num, f = 1; while(c = getchar(),!isdigit(c)) if(c == '-') f = -1; num = c - '0'; while(c = getchar(), isdigit(c)) num = num * 10 + c - '0'; return f * num; } const int N = 2e5 + 1009; int n, in[N], rt, cntl, totdis[N], revtotdis[N], cnt[N], ans; int head[N], nxt[N], ver[N], edge[N], tot = 1; void add(int x, int y, int w) { ver[++tot] = y; nxt[tot] = head[x]; head[x] = tot; edge[tot] = w; } void dfs(int x, int pre, int d) { totdis[x] = 0; cnt[x] = in[x] == 1; if(in[x] == 1) ans += d * d; for(int i = head[x]; i; i = nxt[i]) if(ver[i] != pre) { dfs(ver[i], x, d + edge[i]); totdis[x] += totdis[ver[i]] + cnt[ver[i]] * edge[i]; cnt[x] += cnt[ver[i]]; } } void dfs1(int x, int pre) { for(int i = head[x]; i; i = nxt[i]) if(ver[i] != pre) { revtotdis[ver[i]] = revtotdis[x] + (cntl - cnt[x]) * edge[i] + totdis[x] - totdis[ver[i]] - edge[i] * cnt[ver[i]] + edge[i] * (cnt[x] - cnt[ver[i]]); dfs1(ver[i], x); } } void dp(int x, int pre, int now) { if(in[x] == 1) ans += now; for(int i = head[x]; i; i = nxt[i]) if(ver[i] != pre) { dp(ver[i], x, now - 2 * edge[i] * (totdis[ver[i]] + edge[i] * cnt[ver[i]]) + 2 * edge[i] * (revtotdis[ver[i]] - edge[i] * (cntl - cnt[ver[i]])) + edge[i] * edge[i] * cntl); } } signed main() { n = read(); for(int i = 1; i < n; i++) { int x = read(), y = read(), w = read(); in[x]++; in[y]++; add(x, y, w); add(y, x, w); } for(int i = 1; i <= n; i++) if(in[i] != 1) rt = i; else cntl++; dfs(rt, rt, 0); dfs1(rt, rt); int tmp = ans; ans = 0; dp(rt, rt, tmp); printf("%lld ", ans / 2); return 0; }