• 【洛谷P2664】 树上游戏 点分治


    code:

    #include <bits/stdc++.h>   
     
    #define N 200009   
     
    #define ll long long 
     
    #define setIO(s) freopen(s".in","r",stdin)   
     
    using namespace std; 
     
    ll Sum[N];
     
    int n,edges,root,sn;     
     
    int val[N],hd[N],to[N<<1],nex[N<<1],size[N],mx[N],vis[N],A[N];   
     
    inline void add(int u,int v) 
    {
        nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;  
    }     
     
    void getroot(int u,int ff) 
    {
        size[u]=1,mx[u]=0; 
     
        for(int i=hd[u];i;i=nex[i]) 
        {
            int v=to[i];   
     
            if(v==ff||vis[v])    continue;
     
            getroot(v,u);  
     
            size[u]+=size[v];   
     
            mx[u]=max(mx[u],size[v]);  
        }
     
        mx[u]=max(mx[u],sn-size[u]); 
     
        if(mx[u]<mx[root])   root=u;  
     
    }        
      
    int ou;     
     
    ll tmp,tot,bu[N];           
     
    map<int,ll>cn[N];     
     
    map<int,ll>::iterator it;   
     
    int dep[N],cnt[N],siz[N];   
     
    void getnode(int top,int u,int ff,int cur) 
    {               
    
        if(!cnt[val[u]])   
     
            ++cur;      
     
        ++cnt[val[u]];  
     
        Sum[top]+=(ll)cur;
     
        siz[u]=1;      
     
        for(int i=hd[u];i;i=nex[i])  
        {
            int v=to[i]; 
     
            if(v==ff||vis[v])     continue;    
     
            getnode(top,v,u,cur);  
     
            siz[u]+=siz[v];  
        }      
     
        --cnt[val[u]];                            
    }            
    void get_col(int top,int u,int ff) 
    {                 
        if(!cnt[val[u]])     
     
            cn[top][val[u]]+=(ll)siz[u];             
     
        ++cnt[val[u]];     
     
        for(int i=hd[u];i;i=nex[i]) 
        {
     
            int v=to[i]; 
     
            if(v==ff||vis[v])   continue;      
     
            get_col(top,v,u); 
     
        } 
        --cnt[val[u]];        
    }          
    void calc_v(int u,int ff) 
    {          
        ll tt=bu[val[u]];           
     
        tmp=tmp-bu[val[u]]+ou;            
     
        bu[val[u]]=ou;                    
    
        Sum[u]+=tmp;       
    
        for(int i=hd[u];i;i=nex[i]) 
        {
            int v=to[i]; 
     
            if(vis[v]||v==ff)    continue;          
     
            calc_v(v,u);                       
        }  
    
        tmp=tmp-bu[val[u]]+tt;    
    
        bu[val[u]]=tt;    
    }
    void clr(int u,int ff)
    {
        cn[u].clear();    
        bu[val[u]]=0;         
        for(int i=hd[u];i;i=nex[i])
        {
            int v=to[i]; 
            if(v==ff||vis[v])   continue;   
            clr(v,u);     
        }
    }
    void calc(int u) 
    {         
        tot=0;       
     
        getnode(u,u,0,0);       
    
        for(int i=hd[u];i;i=nex[i])           
        {
            int v=to[i];   
     
            if(vis[v])    continue;     
     
            // memset(cnt,0,sizeof(cnt));  
     
            get_col(v,v,u);    
     
            for(it=cn[v].begin();it!=cn[v].end();it++)    
            {
                tot+=it->second; 
     
                bu[it->first]+=it->second;                
            }
        }               
    
        for(int i=hd[u];i;i=nex[i])    
        {   
            int v=to[i];   
     
            if(vis[v])    continue;   
     
            tmp=tot;       
     
            ou=siz[u]-siz[v]; 
     
            for(it=cn[v].begin();it!=cn[v].end();it++)    
            {
                bu[it->first]-=it->second;   
     
                tmp-=it->second;            
            }
    
            ll tt=bu[val[u]];             
     
            tmp=tmp-bu[val[u]]+ou; 
     
            bu[val[u]]=ou;  
     
            calc_v(v,u);     
      
            bu[val[u]]=tt;   
     
            for(it=cn[v].begin();it!=cn[v].end();it++) bu[it->first]+=it->second;      
      
        }          
     
        clr(u,0);     
    }
    void dfs(int u) 
    {
        calc(u);      
     
        vis[u]=1;     
    
        for(int i=hd[u];i;i=nex[i]) 
        {
            int v=to[i];   
     
            if(vis[v])    continue;     
     
            root=0,sn=size[v],getroot(v,u),dfs(root);   
        }
    }
    int main() 
    { 
        // setIO("input");   
     
        int i,j;        
     
        scanf("%d",&n);     
     
        for(i=1;i<=n;++i)     scanf("%d",&val[i]), A[i]=val[i]; 
     
        sort(A+1,A+1+n); 
     
        for(i=1;i<=n;++i)     val[i]=lower_bound(A+1,A+1+n,val[i])-A;     
     
        for(i=1;i<n;++i) 
        {
            int u,v; 
     
            scanf("%d%d",&u,&v),add(u,v),add(v,u);  
        }          
     
        sn=mx[0]=n,root=0,getroot(1,0),dfs(root);   
     
        for(i=1;i<=n;++i)     printf("%lld
    ",Sum[i]);        
     
        return 0; 
    }
    

      

  • 相关阅读:
    JDK中Unsafe类详解
    JAVA并发理论与实践
    关于FastJSON
    指数退避算法
    MySQL多表关联查询效率高点还是多次单表查询效率高,为什么?
    App开放接口api安全性—Token签名sign的设计与实现
    使用Jmeter进行http接口性能测试
    短信验证登录实现流程
    使用 Postman 取得 Token 打另一隻 API
    SpringMVC拦截器HandlerInterceptor使用
  • 原文地址:https://www.cnblogs.com/guangheli/p/11997385.html
Copyright © 2020-2023  润新知