lrb有一棵树,树的每个节点有个颜色。给一个长度为n的颜色序列,定义(s(i,j))为(i)到(j)的颜色数量。以及
[sum_i = sum_{j=1}^{n}s(i,j)
]
现在他想让你求出所有的(sum[i])
这题真是难,点分治神题
我们考虑一个性质,对于一个点(i),如果它的颜色在到根的路径中是第一次出现,那么对于和(i)不在一个子树的点(j),对(j)都有(i)的子树大小(size_i)的贡献
然后有了这个性质,就好做了
找完重心后预处理出来实际的(size),用(sum)来记录所有点的贡献,(s)是这个颜色的贡献
而我们不是用点去更新答案,是用颜色来更新答案,所以要枚举子树,去掉这个子树的贡献来统计答案
于是再有(X)表示除了这个子树的点数和,(co)表示这个点到根的颜色数
然后记录下这个点到根的所有颜色的(s)的和,(s)是要被减去的
那么(ans+=sum-s+co imes X),然后单独更新一下根就是(ans+=sum-s_{c_{rt}}+size_{rt})
Code
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
const int N = 1e5;
using namespace std;
int n,c[N + 5],size[N + 5],maxp[N + 5],rt,su,vis[N + 5],cnt[N + 5];
long long sum,s[N + 5],ros,X,ans[N + 5];
vector <int> d[N + 5];
void get_rt(int u,int fa)
{
size[u] = 1;
maxp[u] = 0;
vector <int>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (v == fa || vis[v])
continue;
get_rt(v,u);
size[u] += size[v];
maxp[u] = max(maxp[u],size[v]);
}
maxp[u] = max(maxp[u],su - size[u]);
if (maxp[u] < maxp[rt])
rt = u;
}
void get_size(int u,int fa)
{
size[u] = 1;
vector <int>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (v == fa || vis[v])
continue;
get_size(v,u);
size[u] += size[v];
}
}
void dfs(int u,int fa,int w)
{
cnt[c[u]]++;
if (cnt[c[u]] == 1)
{
s[c[u]] += w * size[u];
sum += w * size[u];
}
if (!cnt[c[rt]])
ros += w;
vector <int>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (v == fa || vis[v])
continue;
dfs(v,u,w);
}
cnt[c[u]]--;
}
void upd(int u,int fa,int co,int su)
{
cnt[c[u]]++;
if (cnt[c[u]] == 1)
{
co++;
su += s[c[u]];
}
ans[u] += sum - su + co * X;
if (!cnt[c[rt]])
ans[u] += ros;
vector <int>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (v == fa || vis[v])
continue;
upd(v,u,co,su);
}
cnt[c[u]]--;
}
void calc(int u)
{
vector <int>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (vis[v])
continue;
dfs(v,u,1);
}
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (vis[v])
continue;
dfs(v,u,-1);
X = size[u] - size[v];
upd(v,0,0,0);
dfs(v,u,1);
}
ans[u] += sum - s[c[u]] + size[u];
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (vis[v])
continue;
dfs(v,u,-1);
}
}
void solve(int u)
{
vis[u] = 1;
ros = 1;
get_size(u,0);
calc(u);
vector <int>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it);
if (vis[v])
continue;
maxp[0] = N;
su = size[v];
rt = 0;
get_rt(v,0);
solve(rt);
}
}
int main()
{
scanf("%d",&n);
for (int i = 1;i <= n;i++)
scanf("%d",&c[i]);
int u,v;
for (int i = 1;i < n;i++)
{
scanf("%d%d",&u,&v);
d[u].push_back(v);
d[v].push_back(u);
}
su = n;
maxp[0] = N;
get_rt(1,0);
get_size(rt,0);
solve(rt);
for (int i = 1;i <= n;i++)
printf("%lld
",ans[i]);
return 0;
}