题面
思路
路径问题?当然是点分治啊
最大最小值的差查询......嗯,虽然不知道为什么,但是树状数组求逆序对一定能处理!(暴论)
所以方法就确定了啦~(-_-||)
点分治,每次先dfs一遍,搞出来当前分治块里面每个点到分治中心的路径上边权最大最小值,如果最大值减去最小值大于K就丢掉
然后,我们把所有在本次分治中出现的值离散化,并且把每个点按照最大值排序
我们对于排好序的点序列中的每一个点,每次用树状数组求出它前面有多少个点的最小值在$maxn-K$到$maxn$范围内
这样可以做到不重复不遗漏(不流失不蒸发【大雾】)
然后不要忘记在进入儿子的分治块之前先去重一下
Code
比较恶心......
据说此题有LCT做法,我太蒻了不会qwq
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
inline int read(){
int re=0,flag=1;char ch=getchar();
while(ch>'9'||ch<'0'){
if(ch=='-') flag=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9') re=(re<<1)+(re<<3)+ch-'0',ch=getchar();
return re*flag;
}
int n,K,first[100010],cnte,siz[100010],sum,root,son[100010];
struct edge{
int to,next,w;
}a[200010];
inline void add(int u,int v,int w){
a[++cnte]=(edge){v,first[u],w};first[u]=cnte;
a[++cnte]=(edge){u,first[v],w};first[v]=cnte;
}
int vis[100010];
void getroot(int u,int f){
int i,v;siz[u]=1;son[u]=0;
for(i=first[u];~i;i=a[i].next){
v=a[i].to;if(v==f||vis[v]) continue;
getroot(v,u);
siz[u]+=siz[v];
son[u]=max(son[u],siz[v]);
}
son[u]=max(son[u],sum-siz[u]);
if(son[u]<son[root]) root=u;
}
struct node{
int u,sub,minn,maxn;
}x[100010];int cntx;
inline bool cmp(node l,node r){
return l.maxn<r.maxn;
}
void getinfo(int u,int f,int maxn,int minn,int sub){
int i,v;
if((~maxn)&&(~minn)){
if(maxn-minn<=K){
cntx++;
x[cntx]=(node){u,sub,minn,maxn};
}
else return;
}
else maxn=-2e9,minn=2e9;
for(i=first[u];~i;i=a[i].next){
v=a[i].to;if(v==f||vis[v]) continue;
getinfo(v,u,max(maxn,a[i].w),min(minn,a[i].w),(f==0)?v:sub);
}
}
int d[200010],cntd;
void lisan(){
int i;
cntd=0;
for(i=1;i<=cntx;i++){
d[++cntd]=x[i].minn;
d[++cntd]=x[i].maxn;
}
sort(d+1,d+cntd+1);
cntd=unique(d+1,d+cntd+1)-d-1;
for(i=1;i<=cntx;i++){
x[i].maxn=lower_bound(d+1,d+cntd+1,x[i].maxn)-d;
x[i].minn=lower_bound(d+1,d+cntd+1,x[i].minn)-d;
}
sort(x+1,x+cntx+1,cmp);
}
struct BIT{
int a[200010],len;
void clear(int llen){
for(int i=1;i<=len;i++) a[i]=0;
len=llen;
}
int lowbit(int x){
return x&(-x);
}
void add(int x,int val){
for(;x<=len;x+=lowbit(x)) a[x]+=val;
}
int sum(int x){
int re=0;
for(;x>0;x-=lowbit(x)) re+=a[x];
return re;
}
}T;
int getpos(int pos){
return lower_bound(d+1,d+cntd+1,pos)-d;
}
ll calc(int u){
int i,pos;ll re=0,sum=0;
T.clear(cntd);
for(i=1;i<=cntx;i++){
if(u&&x[i].sub!=u) continue;
sum++;
pos=getpos(d[x[i].maxn]-K);
re+=T.sum(x[i].maxn)-T.sum(pos-1);
T.add(x[i].minn,1);
}
return re+(u?0:sum);
}
ll ans=0;
void dfs(int u,int sz){
int i,v;vis[u]=1;cntx=0;
getinfo(u,0,-1,-1,-1);
lisan();
ans+=calc(0);
for(i=first[u];~i;i=a[i].next){
v=a[i].to;if(vis[v]) continue;
ans-=calc(v);
}
for(i=first[u];~i;i=a[i].next){
v=a[i].to;if(vis[v]) continue;
sum=((siz[v]>siz[u])?(sz-siz[u]):siz[v]);
root=0;son[0]=sum;
getroot(v,0);
dfs(root,sum);
}
}
int main(){
memset(first,-1,sizeof(first));
n=read();K=read();int i,t1,t2,t3;
for(i=1;i<n;i++){
t1=read();t2=read();t3=read();
add(t1,t2,t3);
}
sum=n;root=0;son[0]=n;
getroot(1,0);
dfs(root,n);
printf("%lld
",ans);
}