并查集
合法的简单路径只要三种情况,要么全是0边,要么全是1边,或者是先0后1的边。
于是我们可以把合法路径分成两种类型,一种是只过0的边或者只过1的边,一种是先过1再过0的边。
对于第一种情况,我们可以把某个点所在的0的联通块或者1的联通块大小统计出来,合法的第一种路径为联通块大小-1。
对于第二种情况,一定存在一个点他的一个邻边是0一个邻边是1,也就是中转点。那么根据乘法原理,合法的路径数就是联通块1大小x联通块2大小-1。
再综合起来看两种情况,对于第一种情况来说,一个点如果不能做中转点,那么它肯定是没有入度或者出度的点,那么他所在的0/1联通块必定有一个只有他本身。也就是1,所以第一种情况可以看成特殊的第二种情况。
综上来说,我们用某个点的两种联通块大小相乘再减去1,所得到的答案的意义为:从某个点出发经过改点中转(0->1)的路径数与从改点出发到达某个点且路径上全为0边或者全为1边的路径数之和。
最后所有点统计答案即可。
#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
#define full(a, b) memset(a, b, sizeof a)
using namespace std;
typedef long long ll;
inline int lowbit(int x){ return x & (-x); }
inline int read(){
int X = 0, w = 0; char ch = 0;
while(!isdigit(ch)) { w |= ch == '-'; ch = getchar(); }
while(isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48), ch = getchar();
return w ? -X : X;
}
inline int gcd(int a, int b){ return a % b ? gcd(b, a % b) : b; }
inline int lcm(int a, int b){ return a / gcd(a, b) * b; }
template<typename T>
inline T max(T x, T y, T z){ return max(max(x, y), z); }
template<typename T>
inline T min(T x, T y, T z){ return min(min(x, y), z); }
template<typename A, typename B, typename C>
inline A fpow(A x, B p, C lyd){
A ans = 1;
for(; p; p >>= 1, x = 1LL * x * x % lyd)if(p & 1)ans = 1LL * x * ans % lyd;
return ans;
}
const int N = 200005;
int parent[2][N], size[2][N];
int find(int i, int p){
while(p != parent[i][p]) parent[i][p] = parent[i][parent[i][p]], p = parent[i][p];
return p;
}
bool isConnect(int i, int p, int q){
return find(i, p) == find(i, q);
}
void merge(int i, int p, int q){
int pRoot = find(i, p), qRoot = find(i, q);
if(pRoot == qRoot) return;
if(size[i][qRoot] < size[i][pRoot]) swap(qRoot, pRoot);
parent[i][pRoot] = qRoot;
size[i][qRoot] += size[i][pRoot];
}
int main(){
int n = read();
for(int i = 0; i <= n; i ++){
parent[0][i] = parent[1][i] = i;
size[0][i] = size[1][i] = 1;
}
for(int i = 0; i < n - 1; i ++){
int x = read(), y = read(), c = read();
if(isConnect(c, x, y)) continue;
merge(c, x, y);
}
ll ans = 0;
for(int i = 1; i <= n; i ++){
ans += size[0][find(0, i)] * 1LL *size[1][find(1, i)] - 1;
}
cout << ans << endl;
return 0;
}