大致题意: 给你一棵树,每条边为黑色或红色, 求有多少个三元组((x,y,z)),使得路径((x,y),(x,z),(y,z))上都存在至少一条红色边。
容斥
我们可以借助容斥思想,用总方案数减去不合法方案数,就可以得到合法方案数。
一个不合法方案,就要使得路径((x,y),(x,z),(y,z))中,至少存在一条路径是全黑的。
如果我们删去树上的红色边,只留下黑色的边。则可以发现,一个不合法方案,满足至少存在两个点在同一个连通块内。
计算答案
考虑用并查集,统计每一个连通块的点数(s_i)。
然后,我们枚举连通块(i)使有至少两个点存在于这个连通块中,则可以分两类讨论:
- 第三个点在这个连通块中,方案数为(C_{s_i}^2cdot(n-s_i))。
- 第三个点不在这个连通块中,方案数为(C_{s_i}^3)。
这样就能计算出不合法方案数了。
最后用(C_n^3)减去不合法方案数即为答案。
代码
#include<bits/stdc++.h>
#define Tp template<typename Ty>
#define Ts template<typename Ty,typename... Ar>
#define Reg register
#define RI Reg int
#define Con const
#define CI Con int&
#define I inline
#define W while
#define N 50000
#define X 1000000007
using namespace std;
int n,s[N+5],fa[N+5];
I int getfa(CI x) {return fa[x]^x?fa[x]=getfa(fa[x]):x;}
int main()
{
RI i,x,y,t=0;char op;for(scanf("%d",&n),i=1;i<=n;++i) fa[i]=i;//初始化并查集
for(i=1;i^n;++i) scanf("%d%d",&x,&y),cin>>op,op=='b'&&(x=getfa(x))^(y=getfa(y))&&(fa[x]=y);//只留黑色边
for(i=1;i<=n;++i) ++s[getfa(i)];for(i=1;i<=n;++i) i==getfa(i)&&//统计连通块点数,枚举连通块算答案
(t=((1LL*s[i]*(s[i]-1)>>1)%X*(n-s[i])+1LL*s[i]*(s[i]-1)%X*(s[i]-2)%X*(X+1)/6+t)%X);//分两类情况统计不合法方案数
return printf("%d",(1LL*n*(n-1)%X*(n-2)%X*(X+1)/6-t+X)%X),0;//容斥求出合法方案数
}