题面
解析
刚学了点分治,练一练模版题
过程就不多说了,主要说说细节
在每次查询下一棵子树时, 传进去的整棵子树大小是上一次的$siz$, 这个数据其实是错的, 但好像并不影响时间复杂度, 这样的话找重心就必须找最大子树最小的点了,否则会错。因此需要存一个当前最大子树最小的点的最大子树的大小, 以及当前的重心, 每次找重心之前,前者要赋为$inf$
在每次统计以当前点为$lca$的链的答案时, 要先把这个重心加入数组中, 再按$dis$排序, 然后用双指针法或是其他可行的方法统计答案, 不能暴力枚举点对,不然复杂度不对
代码:
#include<cstdio> #include<iostream> #include<algorithm> using namespace std; const int maxn = 40005, inf = 0x3f3f3f3f; inline int read() { int ret, f=1; char c; while((c=getchar())&&(c<'0'||c>'9'))if(c=='-')f=-1; ret=c-'0'; while((c=getchar())&&(c>='0'&&c<='9'))ret=(ret<<3)+(ret<<1)+c-'0'; return ret*f; } int n, k; int head[maxn], tot; struct edge{ int nxt, to, d; }e[maxn<<1]; void Addedge(int x, int y, int w) { e[++tot] = (edge){head[x], y, w}; head[x] = tot; } int cnt; int dis[maxn], siz[maxn], bel[maxn], p[maxn], ans, val, rt; bool vis[maxn]; void dfs(int x, int fa, int sum) { int ret = 0; siz[x] = 1; for(int i = head[x]; i; i = e[i].nxt) { int id = e[i].to; if(vis[id] || id == fa) continue; dfs(id, x, sum); siz[x] += siz[id]; ret = max(ret, siz[id]); } ret = max(ret, sum - siz[x]); if(ret < val) { val = ret; rt = x; } } bool cmp(int x, int y) { return dis[x] < dis[y]; } void get_dis(int x, int fa, int b) { p[++cnt] = x; bel[x] = b; for(int i = head[x]; i; i = e[i].nxt) { int id = e[i].to; if(id == fa || vis[id]) continue; dis[id] = dis[x] + e[i].d; get_dis(id, x, b); } } int num[maxn]; void calc(int x) { cnt = 0; p[++cnt] = x; dis[x] = 0; bel[x] = x; for(int i = head[x]; i; i = e[i].nxt) { int id = e[i].to; if(vis[id]) continue; dis[id] = e[i].d; get_dis(id, x, id); } sort(p + 1, p + cnt + 1, cmp); for(int i = 1; i <= cnt; ++i) ++ num[bel[p[i]]]; int l = 0, r = cnt; while(l < r - 1) { ++ l; -- num[bel[p[l]]]; while(dis[p[r]] + dis[p[l]] > k) { -- num[bel[p[r]]]; -- r; if(r <= l) break; } if(r <= l) break; ans += r - l - num[bel[p[l]]]; } if(r > l) -- num[bel[p[r]]]; } void solve(int x) { vis[x] = 1; calc(x); for(int i = head[x]; i; i = e[i].nxt) { int id = e[i].to; if(vis[id]) continue; val = inf; dfs(id, x, siz[id]); solve(rt); } } int main() { while(scanf("%d%d", &n, &k), n != 0) { for(int i = 1; i < n; ++i) { int u = read(), v = read(), w = read(); Addedge(u, v, w); Addedge(v, u, w); } val = inf; dfs(1, 0, n); solve(rt); printf("%d ", ans); ans = tot = 0; for(int i = 1; i <= n; ++i) head[i] = vis[i] = 0; } return 0; }