• 点分治_学习笔记+题目清单


    1.模板题 洛谷P3806

    注意对limit加限制<=1e7,不然会RE

    #include<bits/stdc++.h>
    #define ll long long
    #define rep(i,a,n) for(int i=a;i<=n;i++)
    #define per(i,n,a) for(int i=n;i>=a;i--)
    #define endl '
    '
    #define mem(a,b) memset(a,b,sizeof(a))
    #define IO ios::sync_with_stdio(false);cin.tie(0);
    using namespace std;
    const int INF=0x3f3f3f3f;
    const ll inf=0x3f3f3f3f3f3f3f3f;
    const int mod=1e9+7;
    const int maxn=1e5+5;
    const int maxk=1e7+5;
    const int limit=1e7;
    int tot,head[maxn];
    struct E{
        int to,next,w;
    }edge[maxn<<1];
    void add(int u,int v,int w){
        edge[tot].to=v;
        edge[tot].w=w;
        edge[tot].next=head[u];
        head[u]=tot++;
    }
    int n,m,rt,sum,cnt,q[maxn];
    int tmp[maxn],siz[maxn],dis[maxn],maxp[maxn];
    bool judge[maxk],ans[maxn],vis[maxn];
    void getrt(int u,int f){
        siz[u]=1,maxp[u]=0;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(v==f||vis[v]) continue;
            getrt(v,u);
            siz[u]+=siz[v];
            if(siz[v]>maxp[u]) maxp[u]=siz[v];
        }
        maxp[u]=max(maxp[u],sum-siz[u]);
        if(maxp[u]<maxp[rt]) rt=u;
    }
    void getdis(int u,int f){
        if(dis[u]<=limit) tmp[cnt++]=dis[u];
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(v==f||vis[v]) continue;
            dis[v]=dis[u]+edge[i].w;
            getdis(v,u);
        }
    }
    void solve(int u){
        queue<int> que;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            cnt=0;
            dis[v]=edge[i].w;
            getdis(v,u);
            for(int j=0;j<cnt;j++)
                for(int k=0;k<m;k++)
                    if(q[k]>=tmp[j])
                        ans[k]|=judge[q[k]-tmp[j]];
            for(int j=0;j<cnt;j++){
                que.push(tmp[j]);
                judge[tmp[j]]=true;
            }
        }
        while(!que.empty()){
            judge[que.front()]=false;
            que.pop();
        }
    }
    void divide(int u){
        vis[u]=judge[0]=true;
        solve(u);
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            maxp[rt=0]=sum=siz[v];
            getrt(v,0);
            getrt(rt,0);
            divide(rt);
        }
    }
    int main(){
        scanf("%d%d",&n,&m);mem(head,-1);
        for(int i=1;i<n;i++){
            int u,v,w;scanf("%d%d%d",&u,&v,&w);
            add(u,v,w);add(v,u,w);
        }
        for(int i=0;i<m;i++) scanf("%d",&q[i]);
        maxp[0]=sum=n;
        getrt(1,0);
        getrt(rt,0);
        divide(rt);
        for(int i=0;i<m;i++){
            if(ans[i]) puts("AYE");
            else puts("NAY");
        }
    }
    View Code

    2.P4178 Tree

    遵循点分治思想,在solve函数中用双指针维护和<=k,另外在求距离的getdis函数中,如果当前距离>k了可以直接return;因为距离是累加的,大于了就无意义.暴力会TLE

    #include<bits/stdc++.h>
    #define ll long long
    #define rep(i,a,n) for(int i=a;i<=n;i++)
    #define per(i,n,a) for(int i=n;i>=a;i--)
    #define endl '
    '
    #define mem(a,b) memset(a,b,sizeof(a))
    #define IO ios::sync_with_stdio(false);cin.tie(0);
    using namespace std;
    const int INF=0x3f3f3f3f;
    const ll inf=0x3f3f3f3f3f3f3f3f;
    const int mod=1e9+7;
    const int maxn=4e4+5;
    int tot,head[maxn];
    struct E{
        int to,next,w;
    }edge[maxn<<1];
    void add(int u,int v,int w){
        edge[tot].to=v;
        edge[tot].w=w;
        edge[tot].next=head[u];
        head[u]=tot++;
    }
    int n,rt,ans=0,cnt,sum,lim,dis[maxn],tmp[maxn],siz[maxn],maxp[maxn];
    bool vis[maxn];
    void getrt(int u,int f){
        siz[u]=1,maxp[u]=0;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(v==f||vis[v]) continue;
            getrt(v,u);
            siz[u]+=siz[v];
            if(siz[v]>maxp[u]) maxp[u]=siz[v];
        }
        maxp[u]=max(maxp[u],sum-siz[u]);
        if(maxp[u]<maxp[rt]) rt=u;
    }
    void getdis(int u,int f){
        if(dis[u]>lim) return ;
        tmp[cnt++]=dis[u];
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(v==f||vis[v]) continue;
            dis[v]=dis[u]+edge[i].w;
            getdis(v,u);
        }
    }
    void solve(int u){
        vector<int> vec;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            cnt=0;
            dis[v]=edge[i].w;
            getdis(v,u);
            sort(tmp,tmp+cnt);
            sort(vec.begin(),vec.end());
            int r=vec.size()-1;
            for(int j=0;j<cnt&&r>=0;j++){
                while(tmp[j]+vec[r]>lim){
                    --r;
                    if(r<=-1) break;
                }
                if(r>=0) ans+=(r+1);
                else break;
            }
            for(int j=0;j<cnt;j++){
                if(tmp[j]<=lim) ++ans;
                vec.push_back(tmp[j]);
            }            
        }
        vec.clear();
    }
    void divide(int u){
        vis[u]=1;
        solve(u);
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            maxp[rt=0]=sum=siz[v];
            getrt(v,0);
            getrt(rt,0);
            divide(rt);
        }
    }
    int main(){
        scanf("%d",&n);mem(head,-1);
        rep(i,1,n-1){
            int u,v,w;scanf("%d%d%d",&u,&v,&w);
            add(u,v,w);add(v,u,w);
        }
        scanf("%d",&lim);
        maxp[0]=sum=n;
        getrt(1,0);
        getrt(rt,0);
        divide(rt);
        printf("%d
    ",ans);
    }    
    View Code

    3.P2634 [国家集训队]聪聪可可

    统计一共有几条边,并对边分析,看%3是否==0

    #include<bits/stdc++.h>
    #define ll long long
    #define rep(i,a,n) for(int i=a;i<=n;i++)
    #define per(i,n,a) for(int i=n;i>=a;i--)
    #define endl '
    '
    #define mem(a,b) memset(a,b,sizeof(a))
    #define IO ios::sync_with_stdio(false);cin.tie(0);
    using namespace std;
    const int INF=0x3f3f3f3f;
    const ll inf=0x3f3f3f3f3f3f3f3f;
    const int mod=1e9+7;
    const int maxn=1e5+5;
    int tot,head[maxn];
    struct E{
        int to,next,w;
    }edge[maxn<<1];
    void add(int u,int v,int w){
        edge[tot].to=v;
        edge[tot].w=w;
        edge[tot].next=head[u];
        head[u]=tot++;
    }
    int n,rt,cnt,sum,cnt1=0,cnt2=0,dis[maxn],tmp[maxn],siz[maxn],maxp[maxn];
    bool vis[maxn];
    void getrt(int u,int f){
        siz[u]=1,maxp[u]=0;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(v==f||vis[v]) continue;
            getrt(v,u);
            siz[u]+=siz[v];
            if(siz[v]>maxp[u]) maxp[u]=siz[v];
        }
        maxp[u]=max(maxp[u],sum-siz[u]);
        if(maxp[u]<maxp[rt]) rt=u;
    }
    void getdis(int u,int f){
        tmp[cnt++]=dis[u];
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(v==f||vis[v]) continue;
            dis[v]=dis[u]+edge[i].w;
            getdis(v,u);
        }
    }
    void solve(int u){
        vector<int> vec;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            cnt=0;
            dis[v]=edge[i].w;
            getdis(v,u);
            for(int j=0;j<cnt;j++){
                for(auto it:vec){
                    ++cnt2;
                    if((it+tmp[j])%3==0) ++cnt1;
                }
            }
            for(int j=0;j<cnt;j++){
                ++cnt2;
                if(tmp[j]%3==0) ++cnt1;
                vec.push_back(tmp[j]);
            }            
        }
        vec.clear();
    }
    void divide(int u){
        vis[u]=1;
        solve(u);
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            maxp[rt=0]=sum=siz[v];
            getrt(v,0);
            getrt(rt,0);
            divide(rt);
        }
    }
    int main(){
        scanf("%d",&n);mem(head,-1);
        rep(i,1,n-1){
            int u,v,w;scanf("%d%d%d",&u,&v,&w);
            add(u,v,w);add(v,u,w);
        }
        maxp[0]=sum=n;
        getrt(1,0);
        getrt(rt,0);
        divide(rt);
        cnt1*=2;cnt2*=2;
        cnt1+=n;cnt2+=n;
        int c=__gcd(cnt1,cnt2);
        cnt1/=c;cnt2/=c;
        printf("%d/%d",cnt1,cnt2);
    }    
    View Code

     4.Codeforces.161D Distance in Tree

    思想和P3806很贴近,我一开始用了双指针,但是双指针会处理过剩,所以会wa(也有可能我写的菜吧).所以我用了judge来存每个边权出现的次数,lim-tmp[j],表示在u结点的别的子树上是否存有当前值的边权,若有,ans+其数量即可。对于每个v所在自己的子树也要处理.

    #include<bits/stdc++.h>
    #define ll long long
    #define int long long
    #define rep(i,a,n) for(int i=a;i<=n;i++)
    #define per(i,n,a) for(int i=n;i>=a;i--)
    #define endl '
    '
    #define mem(a,b) memset(a,b,sizeof(a))
    #define IO ios::sync_with_stdio(false);cin.tie(0);
    using namespace std;
    const int INF=0x3f3f3f3f;
    const ll inf=0x3f3f3f3f3f3f3f3f;
    const int mod=1e9+7;
    const int maxn=5e4+5;
    int tot,head[maxn];
    struct E{
        int to,next,w;
    }edge[maxn<<1];
    void add(int u,int v,int w){
        edge[tot].to=v;
        edge[tot].w=w;
        edge[tot].next=head[u];
        head[u]=tot++;
    }
    int n,m,rt,ans=0,cnt,sum,lim,dis[maxn],tmp[maxn],siz[maxn],maxp[maxn];
    bool vis[maxn];
    void getrt(int u,int f){
        siz[u]=1,maxp[u]=0;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(v==f||vis[v]) continue;
            getrt(v,u);
            siz[u]+=siz[v];
            if(siz[v]>maxp[u]) maxp[u]=siz[v];
        }
        maxp[u]=max(maxp[u],sum-siz[u]);
        if(maxp[u]<maxp[rt]) rt=u;
    }
    void getdis(int u,int f){
        if(dis[u]>lim) return ;
        tmp[cnt++]=dis[u];
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(v==f||vis[v]) continue;
            dis[v]=dis[u]+edge[i].w;
            getdis(v,u);
        }
    }
    void solve(int u){
        int judge[510]={0};
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            cnt=0;
            dis[v]=edge[i].w;
            getdis(v,u);
            for(int j=0;j<cnt;j++){
                ans+=judge[lim-tmp[j]];
            }
            for(int j=0;j<cnt;j++){
                if(tmp[j]==lim) ++ans;
                judge[tmp[j]]++;
            }            
        }
    }
    void divide(int u){
        vis[u]=1;
        solve(u);
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            maxp[rt=0]=sum=siz[v];
            getrt(v,0);
            getrt(rt,0);
            divide(rt);
        }
    }
    signed main(){
        scanf("%lld%lld",&n,&lim);mem(head,-1);
        rep(i,1,n-1){
            int u,v,w;scanf("%lld%lld",&u,&v);
            add(u,v,1);add(v,u,1);
        }
        maxp[0]=sum=n;
        getrt(1,0);
        getrt(rt,0);
        divide(rt);
        printf("%lld
    ",ans);
    }    
    View Code

     5.P4149 [IOI2011]Race

    我觉得这个题还是模板题的延伸,我用tmp2数组记录当前边权值和所运用的边的数量,judge[tmp[j]]表示当前tmp[j]边权下所运用的最少边数量,记得judge一开始初始化为INF,同样如果暴力会TLE,所以我们需要用队列queue来维护已经运用过的judge情况,最后更新删去队首初始化judge为INF,这样会很省时间

    #include<bits/stdc++.h>
    #define ll long long
    #define int long long
    #define rep(i,a,n) for(int i=a;i<=n;i++)
    #define per(i,n,a) for(int i=n;i>=a;i--)
    #define endl '
    '
    #define mem(a,b) memset(a,b,sizeof(a))
    #define IO ios::sync_with_stdio(false);cin.tie(0);
    using namespace std;
    const int INF=0x3f3f3f3f;
    const ll inf=0x3f3f3f3f3f3f3f3f;
    const int mod=1e9+7;
    const int maxn=2e5+5;
    const int maxk=1e6+5;
    int tot,head[maxn];
    struct E{
        int to,next,w;
    }edge[maxn<<1];
    void add(int u,int v,int w){
        edge[tot].to=v;
        edge[tot].w=w;
        edge[tot].next=head[u];
        head[u]=tot++;
    }
    int n,m,rt,ans=0,cnt,sum,lim,dis[maxn],tmp[maxn],siz[maxn],maxp[maxn];
    bool vis[maxn];
    void getrt(int u,int f){
        siz[u]=1,maxp[u]=0;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(v==f||vis[v]) continue;
            getrt(v,u);
            siz[u]+=siz[v];
            if(siz[v]>maxp[u]) maxp[u]=siz[v];
        }
        maxp[u]=max(maxp[u],sum-siz[u]);
        if(maxp[u]<maxp[rt]) rt=u;
    }
    int tmp2[maxn],ct,dep[maxn];
    void getdis(int u,int f){
        if(dis[u]>lim) return ;
        tmp2[cnt]=dep[u];    
        tmp[cnt++]=dis[u];
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(v==f||vis[v]) continue;
            dis[v]=dis[u]+edge[i].w;
            dep[v]=dep[u]+1;
            getdis(v,u);
        }
    }
    int judge[maxk];
    void solve(int u){
        queue<int> que;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            cnt=0;
            dis[v]=edge[i].w;
            dep[v]=1;
            getdis(v,u);
            for(int j=0;j<cnt;j++){
                if(judge[lim-tmp[j]]!=INF){
                    ans=min(ans,(tmp2[j]+judge[lim-tmp[j]]));
                }
            }
            for(int j=0;j<cnt;j++){
                if(tmp[j]==lim) ans=min(ans,tmp2[j]);
                que.push(tmp[j]);
                judge[tmp[j]]=min(judge[tmp[j]],tmp2[j]);
            }            
        }
        while(!que.empty()){
            judge[que.front()]=INF;
            que.pop();
        }
    }
    void divide(int u){
        vis[u]=1;
        solve(u);
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            maxp[rt=0]=sum=siz[v];
            getrt(v,0);
            getrt(rt,0);
            divide(rt);
        }
    }
    signed main(){
        scanf("%lld%lld",&n,&lim);
        mem(judge,INF);mem(head,-1);ans=INF;
        rep(i,1,n-1){
            int u,v,w;scanf("%lld%lld%lld",&u,&v,&w);
            u+=1,v+=1;
            add(u,v,w);add(v,u,w);
        }
        maxp[0]=sum=n;
        getrt(1,0);
        getrt(rt,0);
        divide(rt);
        if(ans!=INF) printf("%lld
    ",ans);
        else puts("-1");
    }    
    View Code

     6.hdu-4812 D-Tree

    这个题出的挺好,但我不知道为什么我一直RE(栈溢出),看了Hzwer的代码感觉差不多但是他能过。。。不过这题想法很好,因为要求两点val的积取模之后等于k,那么就用线性预处理逆元(这里我才学会,不会数论),然后点分治,一开始solve的时候,不带根节点u,每次都是子树上的,用mp查看其逆元是否存在,若存在则更新2个ans,然后再用一次getdis,把自己的子树的dis乘上根节点u的val,这样继续更新。最后的形态是所有子树的距离都带上根节点了,这个时候我们再来一次getdis来把所有情况给去掉,这题很有想法啊

    #include<algorithm>
    #include<cstdio>
    #include<map>
    #include<cmath>
    #include<cstring>
    #pragma comment(linker,"/STACK:102400000,102400000")
    #define ll long long
    #define rep(i,a,n) for(int i=a;i<=n;i++)
    #define per(i,n,a) for(int i=n;i>=a;i--)
    #define endl '
    '
    #define mod 1000003
    #define INF 1e9
    #define mem(a,b) memset(a,b,sizeof(a))
    using namespace std;
    int tot,head[100005];
    struct E{
        int to,next;
    }edge[200005];
    void add(int u,int v){
        edge[tot].to=v;
        edge[tot].next=head[u];
        head[u]=tot++;
    }
    int n,k,sum,cnt,rt,id[100005],siz[100005],maxp[100005];
    ll val[100005],dis[100005],tmp[100005];
    ll mp[1000005],ine[1000005];
    int ans1,ans2;
    bool vis[100005];
    void getrt(int u,int f){
        siz[u]=1,maxp[u]=0;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(v==f||vis[v]) continue;
            getrt(v,u);
            siz[u]+=siz[v];
            if(siz[v]>maxp[u]) maxp[u]=siz[v];
        }
        maxp[u]=max(maxp[u],sum-siz[u]);
        if(maxp[u]<maxp[rt]) rt=u;
    }
    void getdis(int u,int f){
        tmp[++cnt]=dis[u];
        id[cnt]=u;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(v==f||vis[v]) continue;
            dis[v]=(dis[u]*val[v])%mod;
            getdis(v,u);
        }
    }
    void query(int x,int id){
        x=ine[x]*k%mod;
        int y=mp[x];
        if(y==0)return;
        if(y>id)swap(y,id);
        if(y<ans1||(y==ans1&&id<ans2))
            ans1=y,ans2=id;
    }
    void divide(int u){
        vis[u]=1;
        mp[val[u]]=u;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            cnt=0;
            dis[v]=val[v];
            getdis(v,u);
            for(int j=1;j<=cnt;j++){
                query(tmp[j],id[j]);
            }
            cnt=0;
            dis[v]=(val[u]*val[v])%mod;
            getdis(v,u);
            for(int j=1;j<=cnt;j++){
                int now=mp[tmp[j]];
                if(!now||id[j]<now) mp[tmp[j]]=id[j];
            }        
        }
        mp[val[u]]=0;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            cnt=0;
            dis[v]=(val[u]*val[v])%mod;
            getdis(v,u);
            for(int j=1;j<=cnt;j++){
                mp[tmp[j]]=0;
            }
        }
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            rt=0;sum=siz[v];
            getrt(v,0);
            getrt(rt,0);
            divide(rt);
        }
    }
    int main(){
        ine[1]=1;
        for(int i=2;i<mod;i++){
            int a=mod/i,b=mod%i;
            ine[i]=(ine[b]*(-a)%mod+mod)%mod;    
        }
        while(~scanf("%d%d",&n,&k)){
            mem(head,-1);mem(vis,0);
            cnt=0;ans1=ans2=INF;
            rep(i,1,n) scanf("%d",&val[i]);
            rep(i,1,n-1){
                int u,v;scanf("%d%d",&u,&v);
                add(u,v);add(v,u);
            }
            rt=0;maxp[0]=n+1;sum=n;
            getrt(1,0);
            getrt(rt,0);
            divide(rt);
            if(ans1==INF) puts("No solution");
            else printf("%d %d
    ",ans1,ans2);
        }
        return 0;
    }
    View Code

     7.P2664 树上游戏※

    这个难度还是和之前6个题来说有很大提升。做了很久没有做出来。思路是:对于一个p作为根节点的子树中某个节点u而言,如果颜色c[u]在u到根节点p的路径上是第一次出现,那么对于根节点p及其不在u所在子树上的任一节点,这些节点均会产生siz[u]大小的新的贡献。同理对于根节点p也一样,所以一开始我们dfs1把根节点u的贡献全部算出来,并统计以p为点分治树根的情况下所有颜色col[c[u]]的贡献情况。

    接下来就要讨论一些经过p节点跨根的贡献情况,一个节点u在p的子树v为根的子树上,那么u到p这段链中产生的贡献就是每一个在u之前有几个不同颜色num,然后乘(siz[u]-siz[v]),记得一开始的时候要把v树置0噢,这个是跨根贡献,显然这些还是不够的,因为非v子树上还有不同的点,那怎么办?一开始的sum就有用了,sum说是p产生的贡献(其中包括v树,所以要减掉v树的情况),然后是对已用颜色的去重,若在该链上且在u之前被用过了,那么把sum-col 这样就能去重了。说不清楚太难了

    #include<bits/stdc++.h>
    #define ll long long
    #define rep(i,a,n) for(int i=a;i<=n;i++)
    #define per(i,n,a) for(int i=n;i>=a;i--)
    #define endl '
    '
    #define mem(a,b) memset(a,b,sizeof(a))
    #define IO ios::sync_with_stdio(false);cin.tie(0);
    using namespace std;
    const int INF=0x3f3f3f3f;
    const ll inf=0x3f3f3f3f3f3f3f3f;
    const int mod=1e9+7;
    const int maxn=1e5+5;
    int tot,head[maxn];
    struct E{
        int to,next;
    }edge[maxn<<1];
    void add(int u,int v){
        edge[tot].to=v;
        edge[tot].next=head[u];
        head[u]=tot++;
    }
    int n,rt,SIZE;
    int siz[maxn],c[maxn],maxp[maxn],tmp[maxn],dis[maxn];
    bool vis[maxn];
    void getrt(int u,int f){
        siz[u]=1,maxp[u]=0;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(v==f||vis[v]) continue;
            getrt(v,u);
            siz[u]+=siz[v];
            if(siz[v]>maxp[u]) maxp[u]=siz[v];
        }
        maxp[u]=max(maxp[u],SIZE-siz[u]);
        if(maxp[u]<maxp[rt]) rt=u;
    }
    ll ans[maxn],cnt[maxn],col[maxn],sum,num,S;
    void dfs1(int u,int f){
        siz[u]=1;cnt[c[u]]++;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]||v==f) continue;
            dfs1(v,u);
            siz[u]+=siz[v];
        }
        if(cnt[c[u]]==1){
            sum+=siz[u];
            col[c[u]]+=siz[u];
        }
        cnt[c[u]]--;
    }
    void change(int u,int f,int k){
        cnt[c[u]]++;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]||v==f) continue;
            change(v,u,k);
        }
        if(cnt[c[u]]==1){
            sum+=k*siz[u];
            col[c[u]]+=k*siz[u];
        }
        cnt[c[u]]--;
    }
    void dfs2(int u,int f){
        cnt[c[u]]++;
        if(cnt[c[u]]==1){
            sum-=col[c[u]];num++;
        }
        ans[u]+=sum+num*S;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]||v==f) continue;
            dfs2(v,u);
        }
        if(cnt[c[u]]==1){
            sum+=col[c[u]];num--;
        }
        cnt[c[u]]--;
    }
    void clear(int u,int f){
        cnt[c[u]]=col[c[u]]=0;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]||v==f) continue;
            clear(v,u);
        }
    }
    void solve(int u){
        dfs1(u,0);ans[u]+=sum;
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            cnt[c[u]]++;sum-=siz[v];col[c[u]]-=siz[v];
            change(v,u,-1);cnt[c[u]]--;
            S=siz[u]-siz[v];dfs2(v,u);
            cnt[c[u]]++;sum+=siz[v];col[c[u]]+=siz[v];
            change(v,u,1);cnt[c[u]]--;        
        }    
        sum=0,num=0,clear(u,0);
    }
    void divide(int u){
        vis[u]=true;
        solve(u);
        for(int i=head[u];i!=-1;i=edge[i].next){
            int v=edge[i].to;
            if(vis[v]) continue;
            maxp[rt=0]=SIZE=siz[v];
            getrt(v,0);
            divide(rt);
        }
    }
    int main(){
        scanf("%d",&n);mem(head,-1);
        rep(i,1,n) scanf("%d",&c[i]);
        rep(i,1,n-1){
            int u,v;scanf("%d%d",&u,&v);
            add(u,v);add(v,u);
        }
        maxp[0]=SIZE=n;rt=0;
        getrt(1,0);
        divide(rt);    
        rep(i,1,n) cout<<ans[i]<<endl;
    }
    View Code
  • 相关阅读:
    io流
    JDBC-java数据库连接
    list接口、set接口、map接口、异常
    集合、迭代器、增强for
    math类和biginteger类
    基本包装类和System类
    正则表达式
    API-Object-equals方法和toString方法 Strinig字符串和StingBuffer类
    匿名对象 内部类 包 访问修饰符 代码块
    final 和 static 关键词
  • 原文地址:https://www.cnblogs.com/Anonytt/p/13053962.html
Copyright © 2020-2023  润新知