题目链接
题目大意
给定一棵包含 \(n\) 个节点的树,每个节点有个权值 \(a_i\)
求 \(\sum_{u=1}^n\sum_{v=1}^n\min(a_u,a_v)dis(u,v)\)
解题思路
对于节点 \(u\)
- 记权值大于等于 \(a_u\) 的节点有 \(a_{x1},a_{x2},a_{x3},...,a_{xcnt1}\)(此时 \(\min(a_u,a_x)=a_u\))
- 记权值小于 \(a_u\) 的节点有 \(a_{y1},a_{y2},...,a_{ycnt2}\)(此时 \(\min(a_u,a_y)=a_y\))
那么节点 \(u\) 对答案的贡献为:
- \(a_u\times(dep_u + dep_{x1} - 2\times dep_{lca})+a_u\times(dep_u + dep_{x2} - 2\times dep_{lca})+...\)
- \(a_{y1}\times(dep_u + dep_{y1} - 2\times dep_{lca})+a_{y2}\times(dep_u + dep_{y2} - 2\times dep_{lca})+...\)
即:
- \(a_u\times cnt1\times (dep_u -2\times dep_{lca}) + a_u\times(dep_{x1}+...+dep_{xcnt1})\)
- \((a_{y1}+...+a_{ycnt2})\times (dep_u - 2\times dep_{lca})+(a_{y1}\times dep_{y1}+...+a_{ycnt2}\times dep_{ycnt2})\)
定义 \(rt\) 为当前子树的根,那么 \(lca = rt\)
开四棵权值树状数组,分别用来维护 \(cnt\)、\(dep\)、\(a_i\) 、\(a_i\times dep_i\)
然后跑一遍 \(dsu~on~tree\) 即可
AC_Code
#include<bits/stdc++.h>
#define int long long
using namespace std;
/**
 * Reads one decimal integer from stdin into `res`.
 *
 * Every non-digit character in front of the number is skipped; if any of the
 * skipped characters is '-', the parsed value is negated. The first character
 * after the digits is consumed as well. When end of input is reached before a
 * digit is found, `res` is set to 0 and the function returns instead of
 * waiting forever.
 */
template<typename T>void read(T &res)
{
    res = 0;
    bool negative = false;
    // Keep getchar()'s full return value: EOF stays distinguishable from data
    // bytes and isdigit() is never handed a negative char.
    auto ch = getchar();
    while(ch != EOF && !isdigit(ch))
    {
        if(ch == '-') negative = true;
        ch = getchar();
    }
    if(ch == EOF) return;
    for(; ch != EOF && isdigit(ch); ch = getchar()) res = res * 10 + (ch - '0');
    if(negative) res = -res;
}
/**
 * Writes the decimal representation of the integer `x` to stdout, without a
 * trailing newline. Negative values are prefixed with '-'.
 */
template<typename T>void Out(T x)
{
    // Digits are peeled off the still-signed value and made positive one at a
    // time, so the most negative value of T is never negated as a whole
    // (that negation would overflow).
    char digits[64];
    unsigned len = 0;
    const bool negative = x < 0;
    do
    {
        const auto d = x % 10;   // lies in [-9, 0] when x is negative
        digits[len++] = static_cast<char>('0' + (negative ? -d : d));
        x /= 10;
    } while(x != 0);
    if(negative) putchar('-');
    // Digits were collected least-significant first; emit them in reverse.
    while(len > 0) putchar(digits[--len]);
}
// Size of every per-vertex array, and the modulus applied to the answer.
const int N = 200010;
const int mod = 998244353;

int n;        // number of vertices, read in main
int ans;      // running answer, kept modulo `mod`
int a[N];     // vertex weights; main replaces them by their 1-based rank in `vec`
int dep[N];   // depth assigned by dfs (dep[child] = dep[parent] + 1, dep[0] stays 0)
int sz[N];    // subtree sizes computed by dfs
int HH;       // child that calc/change must not descend into (0 = none)
int hson[N];  // child with the largest subtree, 0 for a leaf
int M;        // number of distinct weights; upper bound of the Fenwick indices
// One directed edge. Edges leaving the same vertex form a singly linked list
// threaded through `nex`; index 0 terminates the list.
struct Edge{
    int nex;   // index of the next edge in the same list
    int to;    // vertex this edge points to
};
Edge edge[2 * N];   // two directed entries per undirected tree edge
int head[N];        // head[u]: index of the first edge in u's list, 0 if empty
int TOT;            // number of edge slots used so far (slot 0 is never used)
// Inserts the directed edge u -> v at the front of u's edge list.
void add_edge(int u , int v)
{
    const int id = ++TOT;      // claim the next free slot
    edge[id].to = v;
    edge[id].nex = head[u];    // old first edge becomes the successor
    head[u] = id;
}
/**
 * Fenwick tree (binary indexed tree) over positions 1..M storing sums modulo
 * `mod`. Every stored value and every returned value lies in [0, mod).
 *
 * Four instances are declared right after the struct: tree1, tree2, tr1, tr2.
 */
struct TR{
    int tr[N];

    // Lowest set bit of x.
    int lowbit(int x){
        return x & (-x);
    }

    // Adds `val` (any sign, any magnitude) to position `pos`.
    void add(int pos , int val)
    {
        // lowbit(0) is 0, so a non-positive position could never advance.
        if(pos <= 0) return;
        // Reduce first: a single "+ mod" cannot repair values below -mod.
        val %= mod;
        if(val < 0) val += mod;
        for(; pos <= M; pos += lowbit(pos))
            tr[pos] = (tr[pos] + val) % mod;
    }

    // Sum of positions 1..pos modulo `mod`; 0 when pos <= 0.
    int query(int pos)
    {
        int res = 0;
        for(; pos > 0; pos -= lowbit(pos))
            res = (res + tr[pos]) % mod;
        return res;
    }

    // Sum of positions L..R modulo `mod`.
    int get_sum(int L , int R){
        return (query(R) - query(L - 1) + mod) % mod;
    }
} tree1 , tree2 , tr1 , tr2;
// Sorted list of distinct weights (filled and deduplicated in main).
std::vector<int> vec;

// 1-based position of the first element of `vec` that is not less than x
// (vec.size() + 1 when every element is smaller).
int get_id(int x){
    const auto first = vec.begin();
    const auto pos = std::lower_bound(first , vec.end() , x);
    return (pos - first) + 1;
}
// Depth-first pass from u (parent `far`): fills dep[], sz[] and hson[] for the
// whole subtree of u. hson[u] ends up as the child with the largest subtree.
void dfs(int u , int far)
{
    sz[u] = 1;
    dep[u] = dep[far] + 1;
    for(int e = head[u] ; e != 0 ; e = edge[e].nex)
    {
        const int child = edge[e].to;
        if(child != far)
        {
            dfs(child , u);
            sz[u] += sz[child];
            // sz[0] is 0, so the first visited child always becomes the candidate.
            if(sz[hson[u]] < sz[child]) hson[u] = child;
        }
    }
}
// Walks the subtree of u (parent `far`), never entering the child stored in
// HH, and adds each vertex to the four Fenwick trees scaled by `val`
// (callers pass +1 to insert and -1 to remove):
//   tree1: depth, tree2: weight * depth, tr1: count, tr2: weight.
void change(int u , int far , int val)
{
    const int weight = vec[a[u] - 1];   // original weight behind rank a[u]
    tree1.add(a[u] , dep[u] * val);
    tree2.add(a[u] , weight * dep[u] * val);
    tr1.add(a[u] , val);
    tr2.add(a[u] , val * weight);
    for(int e = head[u] ; e != 0 ; e = edge[e].nex)
    {
        const int child = edge[e].to;
        if(child != far && child != HH) change(child , u , val);
    }
}
// Adds to `ans` the contribution of every pair (x, w) where w ranges over the
// subtree of u (parent `far`, never entering the child stored in HH) and x
// ranges over the vertices currently held in the Fenwick trees. `rt` is used
// as the lowest common ancestor of each such pair, so the distance is
// dep[w] + dep[x] - 2 * dep[rt].
//
// All factors are reduced modulo `mod` before they are multiplied, which keeps
// every product inside 64 bits, and `ans` is left in [0, mod).
void calc(int u , int far , int rt)
{
    // Maps any signed value to its representative in [0, mod).
    const auto norm = [](int x) { return (x % mod + mod) % mod; };

    const int weight = norm(vec[a[u] - 1]);
    // Part of the distance that does not depend on the partner vertex.
    const int coef = norm(dep[u] - 2 * dep[rt]);

    // Partners whose weight is >= a[u]: the minimum is a[u] itself.
    const int cntGe = norm(tr1.get_sum(a[u] , M));
    const int depGe = norm(tree1.get_sum(a[u] , M));
    const int partGe = weight * ((cntGe * coef + depGe) % mod) % mod;

    // Partners whose weight is < a[u]: the minimum is the partner's weight.
    const int wdepLt = norm(tree2.get_sum(1 , a[u] - 1));
    const int wLt = norm(tr2.get_sum(1 , a[u] - 1));
    const int partLt = (wdepLt + wLt * coef) % mod;

    ans = norm(ans + partGe + partLt);

    for(int i = head[u] ; i ; i = edge[i].nex)
    {
        int v = edge[i].to;
        if(v == far || v == HH) continue ;
        calc(v , u , rt);
    }
}
// "dsu on tree" driver for the subtree of u (parent `far`).
//
// Light children are solved first and their data discarded; the heavy child is
// solved last and its data kept in the Fenwick trees. Each light subtree is
// then paired against the kept data (calc) and merged into it (change), and
// finally u itself is paired and inserted. When `op` is 0 the whole subtree is
// removed from the trees again before returning; otherwise it stays.
//
// All factors are reduced modulo `mod` before they are multiplied, which keeps
// every product inside 64 bits, and `ans` is left in [0, mod).
void dsu(int u , int far , int op)
{
    for(int i = head[u] ; i ; i = edge[i].nex)
    {
        int v = edge[i].to;
        if(v == far || v == hson[u]) continue ;
        dsu(v , u , 0);
    }
    // From here on calc/change skip the heavy child: its data is already stored.
    if(hson[u]) dsu(hson[u] , u , 1) , HH = hson[u];
    for(int i = head[u] ; i ; i = edge[i].nex)
    {
        int v = edge[i].to;
        if(v == far || v == HH) continue;
        // Pair first, insert afterwards, so a subtree is never paired with itself.
        calc(v , u , u) , change(v , u , 1);
    }

    // Maps any signed value to its representative in [0, mod).
    const auto norm = [](int x) { return (x % mod + mod) % mod; };

    // Contribution of u paired with every stored vertex x. u is their common
    // ancestor, so the distance dep[u] + dep[x] - 2 * dep[u] equals dep[x] - dep[u].
    const int weight = norm(vec[a[u] - 1]);
    const int coef = norm(-dep[u]);

    // Stored vertices with weight >= a[u]: the minimum is a[u].
    const int cntGe = norm(tr1.get_sum(a[u] , M));
    const int depGe = norm(tree1.get_sum(a[u] , M));
    const int partGe = weight * ((cntGe * coef + depGe) % mod) % mod;

    // Stored vertices with weight < a[u]: the minimum is their own weight.
    const int wdepLt = norm(tree2.get_sum(1 , a[u] - 1));
    const int wLt = norm(tr2.get_sum(1 , a[u] - 1));
    const int partLt = (wdepLt + wLt * coef) % mod;

    ans = norm(ans + partGe + partLt);

    // Insert u itself.
    tree1.add(a[u] , dep[u]);
    tree2.add(a[u] , vec[a[u] - 1] * dep[u]);
    tr1.add(a[u] , 1);
    tr2.add(a[u] , vec[a[u] - 1]);

    // With HH cleared, the removal below walks the entire subtree.
    HH = 0;
    if(!op) change(u , far , -1);
}
// Reads n, the n weights and the n - 1 edges, replaces each weight by its rank
// among the distinct weights, runs the two tree passes from vertex 1 and
// prints twice the accumulated sum modulo `mod` (each unordered pair is
// accumulated once, while the requested sum ranges over ordered pairs).
signed main()
{
    read(n);
    for(int i = 1 ; i <= n ; i ++) read(a[i]) , vec.push_back(a[i]);
    for(int i = 1 ; i < n ; i ++)
    {
        int u , v;
        read(u) , read(v);
        add_edge(u , v) , add_edge(v , u);
    }
    // Coordinate compression: vec becomes the sorted list of distinct weights.
    sort(vec.begin() , vec.end());
    vec.erase(unique(vec.begin() , vec.end()) , vec.end());
    for(int i = 1 ; i <= n ; i ++) a[i] = get_id(a[i]);
    M = vec.size();
    dfs(1 , 0);
    dsu(1 , 0 , 1);
    // The accumulator may hold a negative representative of the residue;
    // bring it into [0, mod) so a negative number is never printed.
    const int total = (ans % mod + mod) % mod;
    Out(total * 2 % mod) , puts("");
    return 0;
}