这是一道裸的点分治题。然而我在循环初始化数组 $s$ 时把 $i \le k$ 写成了 $i \le n$,因此 WA 了很长时间。
// Centroid decomposition ("point divide and conquer") on a weighted tree.
// Input: n k, then n-1 lines "a b c", each an edge between the 0-indexed
// vertices a and b with weight c.
// Output: the minimum number of edges on a path whose weights sum to exactly
// k, or -1 if no such path exists.
// NOTE(review): assumes edge weights are non-negative — the dist > k pruning
// and the s[k - dist] indexing both rely on it (as did the original).

#include <algorithm>
#include <cstdio>
#include <vector>

namespace {

constexpr int kInf = 2147483647;

// Reads a (possibly negative) decimal integer from stdin, skipping leading
// non-digit characters. Returns 0 if EOF is hit before any digit.
int getint() {
    int sign = 1;
    int value = 0;
    int c = std::getchar();  // int, not char, so EOF can be detected
    for (; c != EOF && (c < '0' || c > '9'); c = std::getchar())
        if (c == '-') sign = -1;
    for (; c >= '0' && c <= '9'; c = std::getchar())
        value = value * 10 + (c - '0');
    return value * sign;
}

// Forward-star adjacency list: point[x] is the first edge id leaving x,
// E[id].nxt the next one; id 0 terminates a list (E[0] is a dummy).
struct Edge {
    int nxt, to, w;
};

std::vector<Edge> E;
std::vector<int> point;
std::vector<char> vis;        // vertex has already served as a centroid
std::vector<int> sz;          // subtree sizes from the most recent fdrt()
std::vector<long long> dist;  // path weight from the current centroid
std::vector<int> deep;        // edge count from the current centroid
// s[d]: fewest edges from the current centroid to a vertex at weight d among
// the subtrees processed so far; n + 1 means "no such vertex".
std::vector<int> s;
int n, k, ans;
int rtm = kInf;  // smallest "largest remaining piece" found by fdrt()
int root;        // vertex achieving rtm

// Adds the directed edge x -> y with weight z.
void ins(int x, int y, int z) {
    E.push_back({point[x], y, z});
    point[x] = static_cast<int>(E.size()) - 1;
}

// Computes subtree sizes of the unvisited component containing x (rooted at
// x, entered from fa) whose total size is sh, and records in root/rtm the
// vertex whose removal leaves the smallest largest piece (the centroid).
void fdrt(int x, int fa, int sh) {
    sz[x] = 1;
    int ma = 0;
    for (int e = point[x]; e; e = E[e].nxt) {
        int v = E[e].to;
        if (vis[v] || v == fa) continue;
        fdrt(v, x, sh);
        sz[x] += sz[v];
        ma = std::max(ma, sz[v]);
    }
    // The piece on the parent's side has sh - sz[x] vertices (was sh - ma).
    ma = std::max(ma, sh - sz[x]);
    if (ma < rtm) {
        rtm = ma;
        root = x;
    }
}

// Walks a subtree of the current centroid, pairing each vertex with the best
// already-recorded partner s[k - dist] and filling dist/deep for children.
void work(int x, int fa) {
    // Weights are non-negative, so nothing at or below x can reach k any
    // more; stopping here also keeps dist small enough not to overflow.
    if (dist[x] > k) return;
    ans = std::min(ans, deep[x] + s[k - dist[x]]);
    for (int e = point[x]; e; e = E[e].nxt) {
        int v = E[e].to;
        if (vis[v] || v == fa) continue;
        dist[v] = dist[x] + E[e].w;
        deep[v] = deep[x] + 1;
        work(v, x);
    }
}

// Records the subtree's (dist, deep) pairs into s. Uses the same pruning as
// work(), so it visits exactly the vertices whose dist work() computed.
void sfill(int x, int fa) {
    if (dist[x] > k) return;
    s[dist[x]] = std::min(s[dist[x]], deep[x]);  // <= k: s[k] matters with 0-weight edges
    for (int e = point[x]; e; e = E[e].nxt) {
        int v = E[e].to;
        if (vis[v] || v == fa) continue;
        sfill(v, x);
    }
}

// Resets exactly the s entries sfill() touched, avoiding an O(k) clear.
void emp(int x, int fa) {
    if (dist[x] > k) return;
    s[dist[x]] = n + 1;
    for (int e = point[x]; e; e = E[e].nxt) {
        int v = E[e].to;
        if (vis[v] || v == fa) continue;
        emp(v, x);
    }
}

// Processes all paths through centroid x (component size sh), then recurses
// into the centroids of the remaining pieces.
void dfs(int x, int sh) {
    vis[x] = 1;
    // The centroid itself is a valid path endpoint (0 edges, weight 0). It
    // must be set here: sfill() never records x, s starts out as n + 1, and
    // emp() may have reset s[0] when zero-weight edges exist.
    s[0] = 0;
    for (int e = point[x]; e; e = E[e].nxt) {
        int v = E[e].to;
        if (vis[v]) continue;
        dist[v] = E[e].w;
        deep[v] = 1;
        work(v, x);   // query against earlier subtrees first...
        sfill(v, x);  // ...then add this subtree, so both ends never share one
    }
    for (int e = point[x]; e; e = E[e].nxt) {
        int v = E[e].to;
        if (vis[v]) continue;
        emp(v, x);
    }
    for (int e = point[x]; e; e = E[e].nxt) {
        int v = E[e].to;
        if (vis[v]) continue;
        // sz[] comes from the fdrt() that found x: a child side keeps its
        // subtree size; the side toward that fdrt's root has sh - sz[x].
        int piece = sz[v] < sz[x] ? sz[v] : sh - sz[x];
        rtm = kInf;
        fdrt(v, x, piece);
        dfs(root, piece);
    }
}

}  // namespace

int main() {
    n = getint();
    k = getint();
    if (n < 1 || k < 0) {  // no vertices, or no non-negative path can weigh k
        std::puts("-1");
        return 0;
    }
    E.assign(1, Edge{0, 0, 0});
    E.reserve(2 * static_cast<std::size_t>(n));
    point.assign(n + 1, 0);
    vis.assign(n + 1, 0);
    sz.assign(n + 1, 0);
    dist.assign(n + 1, 0);
    deep.assign(n + 1, 0);
    s.assign(static_cast<std::size_t>(k) + 1, n + 1);  // indices 0..k, all "none"
    for (int i = 1; i < n; ++i) {
        int a = getint() + 1;  // input is 0-indexed
        int b = getint() + 1;
        int c = getint();
        ins(a, b, c);
        ins(b, a, c);
    }
    ans = n;  // a tree path has at most n - 1 edges, so n means "not found"
    rtm = kInf;
    fdrt(1, -1, n);
    dfs(root, n);  // start from the centroid just computed (was dfs(1, n))
    std::printf("%d\n", ans == n ? -1 : ans);
    return 0;
}
把循环上界改回 $i \le k$ 之后就 AC 了。