题目描述
给出一棵 $n$ 个点的树,每条边的边权为1或0。求有多少点对 $(i,j)$ ,使得:$i$ 到 $j$ 的简单路径上存在点 $k$ (异于 $i$ 和 $j$ ),使得 $i$ 到 $k$ 的简单路径上0和1数目相等,$j$ 到 $k$ 的简单路径上0和1数目也相等。
输入
第1行包含一个整数N。
接下来N-1行,每行包含三个整数a_i、b_i和t_i,表示这条路上药材的类型。
输出
输出符合采药人要求的路径数目。
样例输入
7
1 2 0
3 1 1
2 4 0
5 2 0
6 3 1
5 7 1
样例输出
1
题解
树的点分治
求满足条件的路径数目,可以考虑点分治,每次求过根节点的方案数,再递归处理子树。
设 $f[x][0/1]$ 表示 $1$ 的数目比 $0$ 的数目多 $x$ ,且 否/是 有满足条件的 $k$ 节点的点数。
设 $g[x][0/1]$ 表示 $0$ 的数目比 $1$ 的数目多 $x$ ,且 否/是 有满足条件的 $k$ 节点的点数。
然后相应的答案贡献就是 $f[x][0]·g[x][1]+f[x][1]·g[x][0]+f[x][1]·g[x][1]$ 。
注意特殊处理数目相等的情况,以及分治中心作为路径端点的情况。
dfs整棵子树,求出答案,减去两端在同一子树内的答案,递归子树。
时间复杂度 $O(nlog n)$
#include <cstdio> #include <cstring> #include <algorithm> #define N 100010 using namespace std; int head[N] , to[N << 1] , val[N << 1] , next[N << 1] , cnt , vis[N] , si[N] , ms[N] , sum , root , md; long long f[N][2] , g[N][2] , ans; inline void add(int x , int y , int z) { to[++cnt] = y , val[cnt] = z , next[cnt] = head[x] , head[x] = cnt; } void getroot(int x , int fa) { int i; si[x] = 1 , ms[x] = 0; for(i = head[x] ; i ; i = next[i]) if(!vis[to[i]] && to[i] != fa) getroot(to[i] , x) , si[x] += si[to[i]] , ms[x] = max(ms[x] , si[to[i]]); ms[x] = max(ms[x] , sum - si[x]); if(ms[x] < ms[root]) root = x; } void calc(int x , int fa , int now , int cnt) { int i; if(now == 0) { if(cnt >= 2) ans ++ ; cnt ++ ; } for(i = head[x] ; i ; i = next[i]) if(!vis[to[i]] && to[i] != fa) calc(to[i] , x , now + 2 * val[i] - 1 , cnt); } void dfs(int x , int fa , int now , int l , int r) { int i; if(now >= l && now <= r) { if(now >= 0) f[now][1] ++ ; else g[-now][1] ++ ; } else { if(now >= 0) f[now][0] ++ ; else g[-now][0] ++ ; } l = min(l , now) , r = max(r , now) , md = max(md , max(-l , r)); for(i = head[x] ; i ; i = next[i]) if(!vis[to[i]] && to[i] != fa) dfs(to[i] , x , now + val[i] * 2 - 1 , l , r); } void solve(int x) { int i , j; vis[x] = 1 , md = 0 , calc(x , 0 , 0 , 0); dfs(x , 0 , 0 , 1 , -1) , ans += f[0][1] * (f[0][1] - 1) / 2 , f[0][0] = f[0][1] = 0; for(i = 1 ; i <= md ; i ++ ) ans += f[i][0] * g[i][1] + f[i][1] * g[i][0] + f[i][1] * g[i][1] , f[i][0] = f[i][1] = g[i][0] = g[i][1] = 0; for(i = head[x] ; i ; i = next[i]) { if(!vis[to[i]]) { md = 0 , dfs(to[i] , 0 , 2 * val[i] - 1 , 0 , 0) , ans -= f[0][1] * (f[0][1] - 1) / 2 , f[0][0] = f[0][1] = 0; for(j = 0 ; j <= md ; j ++ ) ans -= f[j][0] * g[j][1] + f[j][1] * g[j][0] + f[j][1] * g[j][1] , f[j][0] = f[j][1] = g[j][0] = g[j][1] = 0; sum = si[to[i]] , root = 0 , getroot(to[i] , 0) , solve(root); } } } int main() { int n , i , x , y , z; scanf("%d" , &n); for(i = 1 ; i < n ; i ++ ) scanf("%d%d%d" , &x , &y , &z) , add(x , y , z) , add(y , x , z); sum = n , root = 0 , ms[0] = n , getroot(1 , 0) , solve(root); printf("%lld " , ans); return 0; }