(color{#0066ff}{题目描述})
给一棵树,每条边有权。求一条简单路径,权值和等于 K,且边的数量最小。
(color{#0066ff}{输入格式})
第一行:两个整数 n,k。
第二至 n 行:每行三个整数,表示一条无向边的两端和权值 (注意点的编号从 0 开始)
(color{#0066ff}{输出格式})
一个整数,表示最小边数量。
如果不存在这样的路径,输出 -1。
(color{#0066ff}{输入样例})
4 3
0 1 1
1 2 2
1 3 4
(color{#0066ff}{输出样例})
2
(color{#0066ff}{数据范围与提示})
n≤200000,K≤1000000。
(color{#0066ff}{题解})
点分治
每次选出重心,对子树进行操作
搜每棵子树,把所有的dis和dep放进一个数组里(pair存)
开一个数组t,t[i]代表,到当前重心的距离为i的最少边数
每次搜完一棵子树,就用t和当前的结果更新ans
然后把当前子树的信息加入t
每次work的时候要把t清空(赋成极大值)
#include<bits/stdc++.h>
using namespace std;
#define LL long long
LL in() {
char ch; int x = 0, f = 1;
while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
return x * f;
}
struct node {
int to, dis;
node *nxt;
node(int to = 0, int dis = 0, node *nxt = NULL): to(to), dis(dis), nxt(nxt) {}
void *operator new (size_t) {
static node *S = NULL, *T = NULL;
return (S == T) && (T = (S = new node[1024]) + 1024), S++;
}
};
const int N = 2e5 + 10;
const int M = 1e6 + 10;
const int inf = 0x7f7f7f7f;
int siz[N], t[M], f[N], tmp[N];
using std::pair;
using std::make_pair;
pair<int, int> ls[N];
node *head[N];
bool vis[N];
int n, k, num, sum, root, cnt;
int ans = inf;
void add(int from, int to, int dis) {
head[from] = new node(to, dis, head[from]);
}
void getdis(int x, int fa, int dis, int dep) {
ls[++cnt] = make_pair(dis, dep);
tmp[++num] = dis;
for(node *i = head[x]; i; i = i->nxt)
if(i->to != fa && !vis[i->to])
getdis(i->to, x, dis + i->dis, dep + 1);
}
void getroot(int x, int fa) {
f[x] = 0;
siz[x] = 1;
for(node *i = head[x]; i; i = i->nxt) {
if(i->to == fa || vis[i->to]) continue;
getroot(i->to, x);
siz[x] += siz[i->to];
f[x] = std::max(f[x], siz[i->to]);
}
f[x] = std::max(f[x], sum - siz[x]);
if(f[x] < f[root]) root = x;
}
void calc(int x) {
num = 0;
for(node *i = head[x]; i; i = i->nxt) {
if(vis[i->to]) continue;
cnt = 0;
getdis(i->to, 0, i->dis, 1);
for(int j = 1; j <= cnt; j++)
if(ls[j].first <= k)
ans = std::min(ans, ls[j].second + t[k - ls[j].first]);
for(int j = 1; j <= cnt; j++)
if(ls[j].first <= k)
t[ls[j].first] = std::min(t[ls[j].first], ls[j].second);
}
for(int i = 1; i <= num; i++) if(tmp[i] <= k) t[tmp[i]] = inf;
}
void work(int x) {
vis[x] = true;
calc(x);
for(node *i = head[x]; i; i = i->nxt) {
if(vis[i->to]) continue;
root = 0, sum = siz[i->to];
getroot(i->to, 0);
work(root);
}
}
int main() {
n = in(), k = in();
int x, y, z;
for(int i = 1; i < n; i++) {
x = in(), y = in(), z = in();
add(x + 1, y + 1, z), add(y + 1, x + 1, z);
}
for(int i = 1; i <= k; i++) t[i] = inf;
f[0] = sum = n;
getroot(1, 0);
work(root);
printf("%d", ans == inf? -1 : ans);
return 0;
}