好久不做点分治的题了,正好在联赛之前抓紧复习一下.
先把边权为 $0$ 的置为 $-1$.
定义几个状态:
$f[dis][0/1],g[dis][0/1]$
其中 $f$ 代表在当前遍历的子树内的答案.
其中 $f[dis][0]$ 表示到根节点距离为 $dis$,没有遇到平衡点的个数,$f[dis][1]$ 表示遇到平衡点的个数.
然后就把 $f$ 和 $g$ 用乘法原理乘一下就可以了.
注意要及时清空.
#include <cstdio> #include <vector> #include <algorithm> #define N 200004 #define ll long long #define bu(i) bu[i+n] #define f(i,j) f[i+n][j] #define g(i,j) g[i+n][j] #define setIO(s) freopen(s".in","r",stdin) using namespace std; ll ans=0; int n,edges,sz,maxdep,root; int hd[N],to[N],nex[N],val[N],mx[N],size[N],vis[N],f[N][2],g[N][2],bu[N<<1],tmp[N]; void add(int u,int v,int c) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v,val[edges]=c; } void getroot(int u,int ff) { size[u]=1,mx[u]=0; for(int i=hd[u];i;i=nex[i]) if(to[i]!=ff&&!vis[to[i]]) getroot(to[i],u),size[u]+=size[to[i]],mx[u]=max(mx[u],size[to[i]]); mx[u]=max(mx[u],sz-size[u]); if(mx[u]<mx[root]) root=u; } void dfs(int x,int ff,int dep) { if(bu(dep)) ++f(dep,1); else ++f(dep,0); ++bu(dep); maxdep=max(maxdep,dep<0?-dep:dep); for(int i=hd[x];i;i=nex[i]) if(!vis[to[i]]&&to[i]!=ff) dfs(to[i],x,dep+val[i]); --bu(dep); } void calc(int x) { int i,j,maxx=0; for(i=hd[x];i;i=nex[i]) { if(vis[to[i]]) continue; maxdep=0; dfs(to[i],x,val[i]); maxx=max(maxx,maxdep); ans+=f(0,1); for(j=-maxdep;j<=maxdep;++j) { if(j==0) ans+=(ll)g(j,0)*f(j,0); ans+=(ll)g(j,0)*f(-j,1); ans+=(ll)g(j,1)*f(-j,0); ans+=(ll)g(j,1)*f(-j,1); } for(j=-maxdep;j<=maxdep;++j) g(j,0)+=f(j,0),g(j,1)+=f(j,1); for(j=-maxdep;j<=maxdep;++j) f(j,0)=f(j,1)=0; } for(j=-maxx;j<=maxx;++j) g(j,0)=g(j,1)=0; } void solve(int x) { vis[x]=1; calc(x); for(int i=hd[x];i;i=nex[i]) if(!vis[to[i]]) root=0,sz=size[to[i]],getroot(to[i],x),solve(root); } int main() { int i,j; // setIO("input"); scanf("%d",&n); for(i=1;i<n;++i) { int a,b,t; scanf("%d%d%d",&a,&b,&t); if(t==0) t=-1; add(a,b,t),add(b,a,t); } mx[0]=n,sz=n,getroot(1,0),solve(root),printf("%lld ",ans); return 0; }