题目
题目链接:https://ac.nowcoder.com/acm/contest/11174/E
给出一棵 (n) 个点的树,点有点权,定义一条路径的权值为路径上最大点权乘最小点权。求所有路径的权值之和 (mod {998244353})。
(nleq 10^5;) 点权 (<998244353)。
思路
建出点分树,每次考虑经过点分树上一个节点的路径的权值和。
我们先 dfs 一遍求出当前点分中心到点分树内所有节点路径上的最大权值和最小权值,然后将所有二元组按照最小权值从小到大排序。
依次枚举所有二元组。假设枚举到第 (i) 个,考虑计算它和它前面的所有二元组合并形成的路径的贡献。显然合并后的路径最小权值就是前面二元组的最小权值,而最大权值可能是前面的或者当前的。
对于前面的一个二元组 (j)(即 (j<i)):
- 如果 (j) 的最大权值小于等于 (i) 的最大权值,那么贡献即为 (minv_j imes maxv_i)。维护一个树状数组,下标为 (k) 表示最大权值为 (k) 的路径的最小权值之和即可。
- 如果 (j) 的最大权值大于 (i) 的最大权值,那么贡献为 (minv_j imes maxv_j)。再维护一个树状数组,下标为 (k) 表示最大权值为 (k) 的路径的 (minv imes maxv) 之和即可。
注意我们会把同一个子树内的两条路径合并,这样的贡献需要删掉。
时间复杂度 (O(nlog^2 n))。
代码
#include <bits/stdc++.h>
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
const int N=100010,MOD=998244353;
int n,rt,tot,ans,a[N],d[N],maxp[N],siz[N],head[N];
bool vis[N];
vector<pair<int,int> > b,c;
struct edge
{
int next,to;
}e[N*2];
struct BIT
{
int c[N];
void add(int x,int v)
{
for (int i=x;i<=n;i+=i&-i)
c[i]=(c[i]+v)%MOD;
}
int query(int x)
{
int res=0;
for (int i=x;i;i-=i&-i)
res=(res+c[i])%MOD;
return res;
}
}bit1,bit2;
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
void findrt(int x,int fa,int sum)
{
siz[x]=1; maxp[x]=0;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa && !vis[v])
{
findrt(v,x,sum); siz[x]+=siz[v];
maxp[x]=max(maxp[x],siz[v]);
}
}
maxp[x]=max(maxp[x],sum-siz[x]);
if (!rt || maxp[x]<maxp[rt]) rt=x;
}
void dfs(int x,int fa,int minv,int maxv)
{
siz[x]=1;
minv=min(minv,a[x]); maxv=max(maxv,a[x]);
b.push_back(mp(minv,maxv));
c.push_back(mp(minv,maxv));
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa && !vis[v])
{
dfs(v,x,minv,maxv);
siz[x]+=siz[v];
}
}
}
void work(vector<pair<int,int> > b,ll f)
{
sort(b.begin(),b.end());
for (int i=0;i<b.size();i++)
{
int minv=b[i].fi,maxv=b[i].se;
ans=(ans+f*bit1.query(maxv)*d[maxv])%MOD;
ans=(ans+f*(bit2.query(n)-bit2.query(maxv)))%MOD;
bit1.add(maxv,d[minv]);
bit2.add(maxv,1LL*d[maxv]*d[minv]%MOD);
}
for (int i=0;i<b.size();i++)
{
int minv=b[i].fi,maxv=b[i].se;
bit1.add(maxv,-d[minv]);
bit2.add(maxv,-1LL*d[maxv]*d[minv]%MOD);
}
}
void calc(int x,int sum)
{
b.push_back(mp(a[x],a[x]));
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (!vis[v])
{
dfs(v,x,a[x],a[x]);
work(c,-1); c.clear();
}
}
work(b,1); b.clear();
}
void solve(int x,int sum)
{
calc(x,sum); vis[x]=1;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (!vis[v])
{
rt=0; findrt(v,0,siz[v]);
solve(rt,siz[v]);
}
}
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d",&n);
for (int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
d[i]=a[i]; ans=(ans+1LL*a[i]*a[i])%MOD;
}
sort(d+1,d+1+n);
tot=unique(d+1,d+1+n)-d-1;
for (int i=1;i<=n;i++)
a[i]=lower_bound(d+1,d+1+tot,a[i])-d;
tot=0;
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
rt=0; findrt(1,0,n);
solve(rt,n);
cout<<ans;
return 0;
}