• 牛客练习赛81D 小Q与树


    dsu on tree

    题目链接

    点我跳转

    题目大意

    给定一棵包含 \(n\) 个节点的树,每个节点有一个权值 \(a_i\),求
    \(\sum_{u=1}^n\sum_{v=1}^n\min(a_u,a_v)\,dis(u,v)\) 对 \(998244353\) 取模的结果,其中 \(dis(u,v)\) 为 \(u,v\) 之间路径上的边数

    解题思路

    对于节点 \(u\)

    • 记权值大于等于 \(a_u\) 的节点为 \(x_1,x_2,\dots,x_{cnt_1}\)(此时 \(\min(a_u,a_x)=a_u\))
    • 记权值小于 \(a_u\) 的节点为 \(y_1,y_2,\dots,y_{cnt_2}\)(此时 \(\min(a_u,a_y)=a_y\))

    那么节点 \(u\) 对答案的贡献为:

    1. \(a_u\times(dep_u + dep_{x1} - 2\times dep_{lca})+a_u\times(dep_u + dep_{x2} - 2\times dep_{lca})+...\)
    2. \(a_{y1}\times(dep_u + dep_{y1} - 2\times dep_{lca})+a_{y2}\times(dep_u + dep_{y2} - 2\times dep_{lca})+...\)

    即:

    1. \(a_u\times cnt_1\times (dep_u-2\times dep_{lca})+a_u\times\sum_{i=1}^{cnt_1}dep_{x_i}\)
    2. \(\left(\sum_{j=1}^{cnt_2}a_{y_j}\right)\times (dep_u-2\times dep_{lca})+\sum_{j=1}^{cnt_2}a_{y_j}\times dep_{y_j}\)

    定义 \(rt\) 为当前子树的根。合并 \(rt\) 的一棵轻儿子子树时,树状数组中存放的都是 \(rt\) 其他儿子子树中的节点,因此任意一对节点的 \(lca = rt\);最后计算 \(rt\) 本身时同样有 \(lca = rt\)

    开四棵以离散化后的权值为下标的树状数组,分别维护节点个数 \(cnt\)、深度和 \(\sum dep_i\)、权值和 \(\sum a_i\) 以及 \(\sum a_i\times dep_i\)

    然后跑一遍 \(dsu~on~tree\) 即可

    AC_Code

    #include<bits/stdc++.h>
    #define int long long
    using namespace std;
    // Reads one signed decimal integer from stdin into res.
    // Non-digit characters before the number are skipped; a '-' among them makes
    // the result negative. The character following the number is consumed.
    // If EOF is reached before any digit, res becomes 0 instead of looping forever.
    template<typename T>void read(T &res)
    {
    	bool negative = false;
    	int ch = getchar();  // int, not char: must hold EOF and keep isdigit's argument valid
    	while(ch != EOF && !isdigit(ch))
    	{
    		if(ch == '-') negative = true;
    		ch = getchar();
    	}
    	res = 0;
    	while(ch != EOF && isdigit(ch))
    	{
    		res = res * 10 + (ch - '0');
    		ch = getchar();
    	}
    	if(negative) res = -res;
    }
    // Writes x to stdout in decimal, preceded by '-' when negative; no separator.
    template<typename T>void Out(T x)
    {
    	if(x < 0)
    	{
    		putchar('-');
    		x = -x;
    	}
    	char digits[48];
    	int len = 0;
    	do
    	{
    		digits[len++] = static_cast<char>('0' + x % 10);
    		x /= 10;
    	} while(x > 0);
    	// Digits were produced least-significant first; emit them in reverse.
    	while(len > 0) putchar(digits[--len]);
    }
    const int N = 2e5 + 10;        // array capacity (presumably max n = 2e5, plus slack)
    const int mod = 998244353;     // the answer is printed modulo this value
    int n;                         // number of vertices
    int ans;                       // accumulated answer over unordered pairs
    int a[N];                      // vertex weights; replaced by 1-based compressed ranks after input
    int dep[N];                    // depth, root = 1 (filled by dfs)
    int sz[N];                     // subtree size (filled by dfs)
    int HH;                        // child whose subtree change/calc traversals skip (0 = none)
    int hson[N];                   // heavy child, 0 for leaves (filled by dfs)
    int M;                         // number of distinct weights = Fenwick tree size
    // Adjacency lists in "forward star" form: edge[] stores every directed edge,
    // head[u] is the index of the most recently added edge leaving u (0 = none),
    // and edge indices start at 1 so that 0 can terminate a list.
    struct Edge{
    	int nex;  // next edge leaving the same vertex, 0 ends the list
    	int to;   // target vertex
    } edge[N << 1];
    int head[N] , TOT;
    // Prepends the directed edge u -> v to u's edge list.
    void add_edge(int u , int v)
    {
    	const int id = ++TOT;
    	edge[id] = Edge{head[u] , v};
    	head[u] = id;
    }
    // Fenwick tree (binary indexed tree) over compressed weight ids 1..M.
    // Every stored entry and every query result is kept in [0, mod).
    // The four instances hold, per weight id: tree1 = sum of depths,
    // tree2 = sum of weight * depth, tr1 = vertex count, tr2 = sum of weights.
    struct TR{
    	int tr[N];
    	int lowbit(int x){
    		return x & (-x);
    	}
    	// Adds val at position pos. val may have any sign and magnitude
    	// (removals pass -weight*depth, which can be far below -mod).
    	void add(int pos , int val)
    	{
    		val = (val % mod + mod) % mod;  // C++ '%' keeps the sign, so normalize first
    		for(; pos <= M ; pos += lowbit(pos))
    			tr[pos] = (tr[pos] + val) % mod;
    	}
    	// Sum over positions 1..pos, modulo mod.
    	int query(int pos)
    	{
    		int res = 0;
    		for(; pos > 0 ; pos -= lowbit(pos))
    			res = (res + tr[pos]) % mod;
    		return res;
    	}
    	// Sum over positions L..R modulo mod (0 for the empty range L = R + 1).
    	int get_sum(int L , int R){
    		return (query(R) - query(L - 1) + mod) % mod;
    	}
    } tree1 , tree2 , tr1 , tr2;
    vector<int>vec;  // distinct original weights, sorted ascending (built in main)
    // Returns the 1-based rank of x in vec: one plus the number of entries below x.
    int get_id(int x){
    	auto first_not_less = lower_bound(vec.begin() , vec.end() , x);
    	return (first_not_less - vec.begin()) + 1;
    }
    // First pass: fills dep, sz and hson for every vertex of u's subtree.
    // far is u's parent; the root is called with far = 0, and dep[0] is never
    // written, so the root gets depth 1.
    void dfs(int u , int far)
    {
    	sz[u] = 1;
    	dep[u] = dep[far] + 1;
    	for(int e = head[u] ; e != 0 ; e = edge[e].nex)
    	{
    		const int child = edge[e].to;
    		if(child == far) continue;
    		dfs(child , u);
    		sz[u] += sz[child];
    		// Keep the child with the strictly largest subtree as heavy child.
    		if(sz[hson[u]] < sz[child]) hson[u] = child;
    	}
    }
    void change(int u , int far , int val)
    {
    	tree1.add(a[u] , dep[u] * val);
    	tree2.add(a[u] , vec[a[u] - 1] * dep[u] * val);
    	tr1.add(a[u] , val);
    	tr2.add(a[u] , val * vec[a[u] - 1]);
    	for(int i = head[u] ; i ; i = edge[i].nex)
    	{
    		int v = edge[i].to;
    		if(v == far || v == HH) continue ;
    		change(v , u , val);
    	}
    }
    // Reduces x into [0, mod). C++ '%' keeps the sign of its left operand, and
    // the quantities below (Fenwick sums, dep_u - 2 * dep_rt) may be negative.
    int norm_mod(int x){
    	return (x % mod + mod) % mod;
    }
    // For every vertex v currently stored in the four Fenwick trees, adds
    // min(a_u, a_v) * (dep_u + dep_v - 2 * dep_rt) to ans. rt must be the LCA of
    // u and every stored vertex. Keeps ans in [0, mod), and keeps every
    // intermediate product below mod^2 so nothing overflows.
    void add_contribution(int u , int rt)
    {
    	const int mi = norm_mod(vec[a[u] - 1]);           // original weight of u
    	const int lift = norm_mod(dep[u] - 2 * dep[rt]);  // dep_u - 2 * dep_rt
    	// Stored vertices with weight >= a_u: the minimum is a_u.
    	// tr1 counts them, tree1 sums their depths.
    	int cnt = norm_mod(tr1.get_sum(a[u] , M));
    	int sum = norm_mod(tree1.get_sum(a[u] , M));
    	ans = (ans + mi * ((cnt * lift + sum) % mod)) % mod;
    	// Stored vertices with weight < a_u: the minimum is their own weight.
    	// tr2 sums their weights, tree2 sums weight * depth.
    	cnt = norm_mod(tr2.get_sum(1 , a[u] - 1));
    	sum = norm_mod(tree2.get_sum(1 , a[u] - 1));
    	ans = (ans + sum + cnt * lift % mod) % mod;
    }
    // Calls add_contribution(x, rt) for every vertex x of u's subtree (far is
    // u's parent), skipping the subtree of HH. Used while merging a light child
    // subtree of rt, so rt is the LCA of x and every vertex already stored.
    void calc(int u , int far , int rt)
    {
    	add_contribution(u , rt);
    	for(int i = head[u] ; i ; i = edge[i].nex)
    	{
    		int v = edge[i].to;
    		if(v == far || v == HH) continue ;
    		calc(v , u , rt);
    	}
    }
    // DSU on tree (small-to-large merging) over u's subtree (far = parent).
    // Every unordered pair of vertices in the subtree is added to ans exactly
    // once, at their LCA. On return, the Fenwick trees still contain u's whole
    // subtree if op != 0; if op == 0 the subtree is removed from them again.
    void dsu(int u , int far , int op)
    {
    	// 1. Light children first; each one removes its own data (op = 0).
    	for(int i = head[u] ; i ; i = edge[i].nex)
    	{
    		int v = edge[i].to;
    		if(v == far || v == hson[u]) continue ;
    		dsu(v , u , 0);
    	}
    	// 2. Heavy child last; its vertices stay in the trees (op = 1).
    	if(hson[u])
    	{
    		dsu(hson[u] , u , 1);
    		HH = hson[u];  // traversals below must not revisit the kept heavy subtree
    	}
    	// 3. Merge each light subtree: pair its vertices with everything stored
    	//    so far (their LCA is u), then insert them.
    	for(int i = head[u] ; i ; i = edge[i].nex)
    	{
    		int v = edge[i].to;
    		if(v == far || v == HH) continue;
    		calc(v , u , u);
    		change(v , u , 1);
    	}
    	// 4. Pair u itself with the rest of its subtree (LCA = u), then insert u.
    	add_contribution(u , u);
    	const int w = norm_mod(vec[a[u] - 1]);
    	tree1.add(a[u] , dep[u]);
    	tree2.add(a[u] , w * dep[u] % mod);
    	tr1.add(a[u] , 1);
    	tr2.add(a[u] , w);
    	HH = 0;
    	if(!op) change(u , far , -1);  // HH is 0 here, so the whole subtree is removed
    }
    // Reads n, the n weights and n - 1 undirected edges; prints the sum over all
    // ordered pairs (u, v) of min(a_u, a_v) * dis(u, v), modulo mod.
    signed main()
    {
    	read(n);
    	for(int i = 1 ; i <= n ; i ++) read(a[i]) , vec.push_back(a[i]);
    	for(int i = 1 ; i <  n ; i ++)
    	{
    		int u , v;
    		read(u) , read(v);
    		add_edge(u , v) , add_edge(v , u);
    	}
    	// Coordinate-compress the weights: a[i] becomes its 1-based rank in vec.
    	sort(vec.begin() , vec.end());
    	vec.erase(unique(vec.begin() , vec.end()) , vec.end());
    	for(int i = 1 ; i <= n ; i ++) a[i] = get_id(a[i]);
    	M = vec.size();
    	dfs(1 , 0);
    	dsu(1 , 0 , 1);
    	// dsu counts each unordered pair once; the double sum counts it twice.
    	// Normalize first: C++ '%' can leave ans negative, which would print a
    	// negative answer.
    	Out((ans % mod + mod) % mod * 2 % mod) , puts("");
    	return 0;
    }
    
  • 相关阅读:
    PAT乙级1014.福尔摩斯的约会 (20)(20 分)
    PAT乙级1013.数素数
    PAT乙级1012.数字分类 (20)(20 分)
    PAT乙级1011.A+B和C (15)(15 分)
    PAT乙级1025.反转链表 (25)
    PAT乙级1020.月饼(20)
    PAT乙级1015.德才论(25)
    PAT乙级1010.一元多项式求导(25)
    PAT乙级1009.说反话(20)
    PAT乙级1008.数组元素循环右移问题(20)
  • 原文地址:https://www.cnblogs.com/GsjzTle/p/14958220.html
Copyright © 2020-2023  润新知