• 【牛客7872 J】树上启发式合并


    【牛客7872 J】树上启发式合并

    题意

    树上启发式合并,求有多少点对满足,这两个点x和y相互之间不是祖先和后代的关系
    同时满足(val[x]+val[y]=2 * val[ lca(x,y) ])

    题解

    根据两个点不能互为祖先的要求可知:

    比较可行的方式是枚举这个作为lca的结点,对于一个作为lca的结点
    什么样的结点会以它为lca呢,当然是以它的不同的儿子为根结点的子树中的结点
    因此,统计答案的方式也比较巧妙,对于一个作为lca的结点u

    • 首先遍历它的第一个儿子v1的那棵子树,用一个mp数组记录当前已经遍历过的结点中每个数出现的次数
      遍历第1个儿子那棵子树时把mp维护好。
    • 然后从第2个儿子开始,先对每一个结点v,获取到当前mp[2*val[u]-val[v]]的大小
      这表示能和结点v一起组成符合条件的点对有多少。
    • 这样查询完第2个儿子上所有节点后,再把第2个儿子子树上的所有结点的mp值维护好,依次循环这样一个过程

    由于在做这个过程的时候必须保证mp值的准确,所以每次一个lca判断完后要清空该棵子树对mp值造成的影响。
    那么考虑什么样的结点不用清空呢,那就是该结点作为父亲结点的最后一个儿子维护答案时不用清空。
    那么我们怎样能使时间复杂度尽可能降低呢?那就是把所有儿子中最重的(子树大小最大的儿子)放在最后一个访问,这样就可以节省下清空它的时间复杂度,这就是启发式合并,运用最后一个儿子不需要清空的性质来降低时间复杂度。

    Code

    /****************************
    * Author : W.A.R            *
    * Date : 2020-10-31-20:44   *
    ****************************/
    /*
    */
    #include<stdio.h>
    #include<string.h>
    #include<math.h>
    #include<algorithm>
    #include<queue>
    #include<map>
    #include<unordered_map>
    #include<stack>
    #include<string>
    #include<set>
    #define mem(a,x) memset(a,x,sizeof(a))
    using namespace std;
    typedef long long ll;
    const int maxn=1e6+10;
    const ll mod=1e9+7;
    
    namespace Fast_IO{
        const int MAXL((1 << 18) + 1);int iof, iotp;
        char ioif[MAXL], *ioiS, *ioiT, ioof[MAXL],*iooS=ioof,*iooT=ioof+MAXL-1,ioc,iost[55];
        char Getchar(){
            if (ioiS == ioiT){
                ioiS=ioif;ioiT=ioiS+fread(ioif,1,MAXL,stdin);return (ioiS == ioiT ? EOF : *ioiS++);
            }else return (*ioiS++);
        }
        void Write(){fwrite(ioof,1,iooS-ioof,stdout);iooS=ioof;}
        void Putchar(char x){*iooS++ = x;if (iooS == iooT)Write();}
        inline int read(){
            int x=0;for(iof=1,ioc=Getchar();(ioc<'0'||ioc>'9')&&ioc!=EOF;)iof=ioc=='-'?-1:1,ioc=Getchar();
    		if(ioc==EOF)exit(0);
            for(x=0;ioc<='9'&&ioc>='0';ioc=Getchar())x=(x<<3)+(x<<1)+(ioc^48);return x*iof;
        }
        inline long long read_ll(){
            long long x=0;for(iof=1,ioc=Getchar();(ioc<'0'||ioc>'9')&&ioc!=EOF;)iof=ioc=='-'?-1:1,ioc=Getchar();
    		if(ioc==EOF)exit(0);
            for(x=0;ioc<='9'&&ioc>='0';ioc=Getchar())x=(x<<3)+(x<<1)+(ioc^48);return x*iof;
        }
        template <class Int>void Print(Int x, char ch = ''){
            if(!x)Putchar('0');if(x<0)Putchar('-'),x=-x;while(x)iost[++iotp]=x%10+'0',x/=10;
            while(iotp)Putchar(iost[iotp--]);if (ch)Putchar(ch);
        }
        void Getstr(char *s, int &l){
            for(ioc=Getchar();ioc==' '||ioc=='
    '||ioc=='	';)ioc=Getchar();
    		if(ioc==EOF)exit(0);
            for(l=0;!(ioc==' '||ioc=='
    '||ioc=='	'||ioc==EOF);ioc=Getchar())s[l++]=ioc;s[l] = 0;
        }
        void Putstr(const char *s){for(int i=0,n=strlen(s);i<n;++i)Putchar(s[i]);}
    }
    using namespace Fast_IO;
    struct node{int to,nxt;}e[maxn];
    int son[maxn],siz[maxn],cnt[maxn],head[maxn],val[maxn],ct;
    ll ans;
    unordered_map<int,int>mp;
    void addE(int u,int v){e[++ct].to=v;e[ct].nxt=head[u];head[u]=ct;}
    void dfs(int u,int fa){
    	siz[u]=1;
    	for(int i=head[u];i;i=e[i].nxt){
    		int v=e[i].to;if(v==fa)continue;
    		dfs(v,u);siz[u]+=siz[v];
    		if(siz[v]>siz[son[u]])son[u]=v;
    	}
    }
    void add(int u,int fa,int value){
    	mp[val[u]]+=value;
    	for(int i=head[u];i;i=e[i].nxt){
    		int v=e[i].to;
    		if(v==fa)continue;
    		add(v,u,value);
    	}
    }
    void calc(int u,int fa,int lca){
    	ans+=mp[2*val[lca]-val[u]];
    	for(int i=head[u];i;i=e[i].nxt){
    		int v=e[i].to;
    		if(v==fa)continue;
    		calc(v,u,lca);
    	}
    }
    void getAns(int u,int fa,bool heavy){
    	for(int i=head[u];i;i=e[i].nxt){
    		int v=e[i].to;
    		if(v==fa||v==son[u])continue;
    		getAns(v,u,0);
    	}
    	if(son[u])getAns(son[u],u,1);
    	for(int i=head[u];i;i=e[i].nxt){
    		int v=e[i].to;
    		if(v==fa||v==son[u])continue;
    		calc(v,u,u);
    		add(v,u,1);
    	}
    	mp[val[u]]++;
    	if(!heavy)add(u,fa,-1);
    }
    int main(){
    	int n=read();
    	for(int i=1;i<=n;i++)val[i]=read();
    	for(int i=1;i<n;i++){int u=read(),v=read();addE(u,v);addE(v,u);}
    	dfs(1,0);
    	getAns(1,0,0);
    	printf("%lld
    ",ans<<1);
    	return 0;
    }
    
    
  • 相关阅读:
    C语言第0次作业
    c语言博客作业02循环结构
    C语言博客作业04数组
    存储过程,函数参数默认值的一些问题
    路线查询
    C# 猜数字
    使用 Ext.Net TreePanel,TabPanel控件 布局
    SQL SERVER 2005 动态行转列SQL
    CROSS APPLY 和OUTER APPLY 的区别
    处理表重复记录(查询和删除)
  • 原文地址:https://www.cnblogs.com/wuanran/p/13907917.html
Copyright © 2020-2023  润新知