考虑一下把(0)看成(-1),那么就是找到一条边权和为(0)的路径,且这条路径可以被分成两段,边权和都是(0)
没有第二个限制就是点分裸题了
其实有了第二个限制还是点分裸题
考虑那个断点肯定会存在于当前分治重心的某一边,或者直接在分治重心上,我们在求每个点到分治重心的距离的时候判断一下上面能否有一个点成为断点就好了
具体做法就是开个桶,判断每个点的祖先是否有一个和它的分治重心的距离相等
之后就把点分成了两类,一类是到分治重心的路径上存在断点,一类是不存在的,显然不存在的只能和存在的拼成路径,同时还要满足距离和加起来为(0)就好了
于是开两个桶分别维护一下就好了
代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define LL long long
#define re register
#define maxn 100005
inline int read() {
char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
struct E{int v,nxt,w;}e[maxn<<1];
int n,m,num,now,rt,S;LL ans;
int head[maxn],sum[maxn],mx[maxn],vis[maxn],pre[maxn];
int f[maxn<<1],tax[maxn<<1][2],top[2],st[maxn<<1][2];
inline void add(int x,int y,int z) {e[++num].v=y;e[num].nxt=head[x];head[x]=num;e[num].w=z;}
void getroot(int x,int fa) {
sum[x]=1,mx[x]=0;
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]||e[i].v==fa) continue;
getroot(e[i].v,x);sum[x]+=sum[e[i].v];
mx[x]=max(mx[x],sum[e[i].v]);
}
mx[x]=max(mx[x],S-sum[x]);
if(mx[x]<now) rt=x,now=mx[x];
}
void getdis(int x,int fa,int o) {
pre[x]=pre[fa]+o;
if(f[pre[x]+n]) {
st[++top[1]][1]=x;
if(f[pre[x]+n]>1&&pre[x]==0) ans++;
}
else st[++top[0]][0]=x;
f[pre[x]+n]++;
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]||e[i].v==fa) continue;
getdis(e[i].v,x,e[i].w);
}
f[pre[x]+n]--;
}
void clear(int x,int fa) {
tax[n+pre[x]][0]=tax[n+pre[x]][1]=0;pre[x]=0;
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]||e[i].v==fa) continue;
clear(e[i].v,x);
}
}
void dfs(int x) {
vis[x]=1;f[0+n]=1;
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]) continue;
top[0]=top[1]=0;getdis(e[i].v,x,e[i].w);
for(re int j=1;j<=top[0];j++)
ans+=tax[n-pre[st[j][0]]][1];
for(re int j=1;j<=top[1];j++)
ans+=tax[n-pre[st[j][1]]][1]+tax[n-pre[st[j][1]]][0];
for(re int j=1;j<=top[0];j++) tax[n+pre[st[j][0]]][0]++;
for(re int j=1;j<=top[1];j++) tax[n+pre[st[j][1]]][1]++;
}f[0+n]=0;
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]) continue;
clear(e[i].v,x);
}
for(re int i=head[x];i;i=e[i].nxt) {
if(vis[e[i].v]) continue;
S=sum[e[i].v],now=n,getroot(e[i].v,0),dfs(rt);
}
}
int main() {
n=read();
for(re int x,y,z,i=1;i<n;i++) {
x=read(),y=read(),z=read();
if(!z) z=-1;add(x,y,z),add(y,x,z);
}
now=n,S=n,getroot(1,0);dfs(rt);
printf("%lld
",ans);
return 0;
}