爬爬爬
这其实算是ABC的F们中比较难的一个
洛谷题目页面传送门 & AtCoder题目页面传送门
有一棵树(T=(V,E),|V|=n,|E|=n-1),节点(i)的颜色为(a_iin[1,n])。求(forall iin[1,n]),经过至少(1)个颜色为(i)的节点的路径个数。
(ninleft[1,2 imes10^5 ight])。
先考虑对于某个(i),该如何求答案。注意到,若正面求答案,那么需要容斥一大堆,或者点分治(我还不会),比较麻烦。考虑反面求答案,求不经过任何颜色为(i)的节点的路径个数,最后用总路径数(dfrac{n(n+1)}2)减去它即可。考虑“不经过任何颜色为(i)的节点的路径个数”怎么算,显然,若将所有颜色为(i)的节点删除,会分离成若干CC,设它们的大小分别为(s_1,s_2,cdots,s_m),那么答案就是(sumlimits_{j=1}^mdfrac{s_j(s_j+1)}2),因为一条路径不经过任何颜色为(i)的节点当且仅当它的两端点属于同一个CC。
如果只对于一个(i)求答案,那么显然可以直接删除所有该删的点,然后随便DFS一下就(mathrm O(n))出答案了。但现在需要在较短时间内求出对于所有(i)的答案,就有那么一丝丝难度了。不难发现,(forall i),任意一个CC的深度最小的点要么没有父亲(等于根,令(1)为根),要么父亲颜色为(i)(反证法,如果不成立,那么显然CC可以再往上扩展)。于是我们可以用这个唯一的(唯一性可以反证,如果不唯一,也就是(2)个同辈人没有长辈连接,一定不连通)深度最小的点作为整个CC的代表。
又不难发现,任意一个非(1)节点,都只能作为一个(i)剖出来的CC的代表,因为它们都有一个固定的父亲,也就有一个固定的父亲的颜色,而节点(1)却可以作为任意一个(i)剖出来的CC的代表,需要特殊开数组处理。所以总CC数为(mathrm O(n))。考虑一遍DFS求出所有CC的大小。考虑如何算一个CC的大小。先假设整个以(x)为根的子树全属于此CC,然后看有多少是被删了的或属于其他CC的,减去即可。显然,在子树内的所有CC代表要么一个为祖先一个为晚辈,此时以它们为根的子树一个包含另一个,要么无关,此时以它们为根的子树不相交,CC代表们的父亲(颜色为(i))亦然。又显然,所有子树内的以颜色为(i)的节点为根的子树内的节点都不属于此CC,于是我们只需要将此CC的大小减去所有不为其他颜色为(i)的节点的晚辈的颜色为(i)的节点的子树大小即可。考虑从下往上贡献,即DFS过程中时刻维护递归栈,每到一个节点就将最深的一个颜色相同的节点(这个可以时刻维护每种颜色的递归栈(mathrm O(1))做到)在栈中的下一个节点的CC大小减去它的子树大小(正确性显然)。对于根代表多CC的情况,特判一下即可。
时间复杂度(mathrm O(n))。
下面是AC代码:
#include<bits/stdc++.h>
using namespace std;
#define pb push_back
#define ppb pop_back
typedef long long ll;
const int N=200000;
int n;//节点个数
int a[N+1];//颜色
vector<int> nei[N+1];//邻接表
int sz[N+1]/*子树大小*/,cc_sz[N+1]/*CC大小(非根)*/,cc_sz1[N+1]/*CC大小(根,下标为颜色)*/;
int fa[N+1];//父亲
vector<int> stk/*总递归栈*/,stk_c[N+1]/*颜色们的递归栈*/;
void dfs(int x=1){//求CC们的大小
// cout<<x<<"
";
sz[x]=1;
stk_c[a[x]].pb(stk.size());
stk.pb(x);//压栈
for(int i=0;i<nei[x].size();i++){
int y=nei[x][i];
if(y==fa[x])continue;
fa[y]=x;
dfs(y);
sz[x]+=sz[y];
}
stk.ppb();
stk_c[a[x]].ppb();//准备回溯,弹出
if(stk_c[a[x]].size()>1)cc_sz[stk[stk_c[a[x]].back()+1]]-=sz[x];//贡献(非根)
else/*特判*/ cc_sz1[a[x]]-=sz[x];//贡献(根)
}
ll ans[N+1];
int main(){
cin>>n;
for(int i=1;i<=n;i++)cin>>a[i];
for(int i=1;i<n;i++){
int x,y;
cin>>x>>y;
nei[x].pb(y);nei[y].pb(x);
}
stk.pb(0);
for(int i=1;i<=n;i++)stk_c[i].pb(0);
dfs();
for(int i=1;i<=n;i++)cc_sz[i]+=sz[i],cc_sz1[i]+=n;
// for(int i=1;i<=n;i++)cout<<cc_sz[i]<<" ";puts("");
// for(int i=1;i<=n;i++)cout<<cc_sz1[i]<<" ";puts("");
for(int i=1;i<=n;i++)ans[i]=1ll*n*(n+1)/2;
for(int i=2;i<=n;i++)ans[a[fa[i]]]-=1ll*cc_sz[i]*(cc_sz[i]+1)/2;
for(int i=1;i<=n;i++)ans[i]-=1ll*cc_sz1[i]*(cc_sz1[i]+1)/2;//减去反面答案
for(int i=1;i<=n;i++)cout<<ans[i]<<"
";
return 0;
}