题目描述:这里
发现还是点对之间的问题,于是还是上点分
只不过是怎么做的问题
首先对每条边边权给成1和-1(即把原来边权为0的边边权改为-1),那么合法的路径总权值一定为0!
还是将路径分为经过当前根节点和不经过当前根节点的,对不经过当前点的递归处理
那么我们讨论经过当前根节点的路径算法即可
可以发现,如果一条路径经过当前根节点,那么当前根节点可以将这条路径分成两部分,且这两部分权值互为相反数!
(这是很显然的,如果不互为相反数的话那么加起来肯定不是0啊)
接下来我们分析中转站的问题:
其实中转站的含义就是在路径上找到一个点,使得这个点左右两部分权值和均为0
那么这里需要dp处理
我们用两个dp数组处理,分别为$f[i][0/1],g[i][0/1]$,其中g是f的前缀和,f[i][0]表示长度为i的路径,0/1用来记录路径上是否存在距离为i的节点
之所以要记录这一点,是因为起点和终点不能作为休息站,所以当某个权值第一次出现时我们要记录在f[i][0]中,在另半条路径中找到g[-i][1]的点数才能保证休息站的存在
那么答案就由几部分组成:
第一:之前的几棵子树中某一个点到根的路径权值为0,当前子树中某一个点到根的路径权值也为0,那么方案数就是$f[0][0]*g[0][0]$
第二:之前的几棵子树中某一个点到根的路径权值为i,当前子树中某一个点权值为-i,那么答案即为$f[i][1]*g[-i][0]+f[i][0]*g[-i][1]+f[i][1]*g[i][-1]$(注意这里的i不一定是正值!!!)
还是比较好理解,因为不管这个权值之前是否出现过,只需互为相反数然后累计即可,因为左右一定能找到一个合法的位置
但是注意一点:
我们认为根节点的g[0][0]初值为1,这样在累计以根节点为端点的情况的时候是有效的,但对于根节点恰好是端点而且根节点被迫成为休息站的情况是有问题的,这种情况是不合法的,所以事实上累计两边权值都是0的时候应该用的是$f[0][0]*(g[0][0]-1)$!!
然后算完把f累计到g里做个前缀和就可以了
#include <cstdio> #include <cmath> #include <cstring> #include <cstdlib> #include <iostream> #include <algorithm> #include <queue> #include <stack> #define ll long long using namespace std; const int inf=0x3f3f3f3f; struct Edge { int next; int to; int val; }edge[200005]; int head[100005]; int siz[100005]; int maxp[100005]; bool vis[100005]; int dis[100005]; int dep[100005]; int has[200005]; ll g[200005][2]; ll f[200005][2]; ll ans=0; int maxdep=0; int cnt=1; int s,rt; int n; void init() { memset(head,-1,sizeof(head)); cnt=1; } void add(int l,int r,int w) { edge[cnt].next=head[l]; edge[cnt].to=r; edge[cnt].val=w; head[l]=cnt++; } void get_rt(int x,int fa) { siz[x]=1,maxp[x]=0; for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(to==fa||vis[to])continue; get_rt(to,x); siz[x]+=siz[to]; maxp[x]=max(maxp[x],siz[to]); } maxp[x]=max(maxp[x],s-siz[x]); if(maxp[x]<maxp[rt])rt=x; } void get_dis(int x,int fa) { maxdep=max(maxdep,dep[x]); if(has[dis[x]])f[dis[x]][1]++; else f[dis[x]][0]++; has[dis[x]]++; for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(vis[to]||to==fa)continue; dep[to]=dep[x]+1; dis[to]=dis[x]+edge[i].val; get_dis(to,x); } has[dis[x]]--; } void solve(int x) { vis[x]=1; g[n][0]=1; int maxx=0; for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(vis[to])continue; dis[to]=n+edge[i].val,maxdep=1,dep[to]=1; get_dis(to,0); maxx=max(maxx,maxdep); ans+=(g[n][0]-1)*f[n][0]; for(int j=-maxdep;j<=maxdep;j++)ans+=g[n-j][1]*f[n+j][1]+g[n-j][0]*f[n+j][1]+g[n-j][1]*f[n+j][0]; for(int j=n-maxdep;j<=n+maxdep;j++) { g[j][0]+=f[j][0]; g[j][1]+=f[j][1]; f[j][0]=f[j][1]=0; } } for(int j=n-maxx;j<=n+maxx;j++)g[j][0]=g[j][1]=0; for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(vis[to])continue; rt=0,maxp[rt]=inf,s=siz[to]; get_rt(to,0); solve(rt); } } int main() { scanf("%d",&n); init(); for(int i=1;i<n;i++) { int x,y,z; scanf("%d%d%d",&x,&y,&z); if(!z)z--; add(x,y,z),add(y,x,z); } maxp[rt]=s=n; get_rt(1,0); solve(rt); printf("%lld ",ans); return 0; }