• HDU6035:Colorful Tree(树形DP)


    传送门

    题意

    给出一棵最小生成树及每个节点的颜色,询问(frac{n(n-1)}2)条路径的权值和,一条路径的权值为该路径的颜色种数

    分析

    勉强理解了ftae的做法,但是代码还是不太会,还是太弱了(⊙﹏⊙)。
    基本思想:求出每种颜色经过的路径数。
    做一定转化:总路径数-每种颜色未经过的路径数
    那么未经过的路径数如何求呢?借用虚树的思想,对于颜色c[i],将颜色为c[i]的节点从树上删去,这样树就变成了一个个联通块/节点,未经过的路径数为联通块自身路径数之和,比如一个4个节点的联通块路径数即为12.
    那么难度在于如何实现统计操作。我们设置sons[i]代表第i个节点的子树大小,sum[i]代表颜色i已经合并的节点数

    那么ct=sons[v]-sum[c[u]]+pre是什么?
    它是从当前点 v 到下面每条链下最近的颜色为 c[u] 的节点之间的节点的数量
    (请仔细思考)

    而这些节点未被合并到sum[c[u]]中,故它为一个联通块,统计该联通块的路径数。

    注意到 pre 的初始值为 sum[c[u]] ,也就是说 −sum[c[u]]+pre 保证我们算的是当前这颗子树下的未被合并的节点数量(因为我们前面可能先遍历了其它子树)(请仔细思考)

    感谢ftae的题解

    trick

    代码

    #include <bits/stdc++.h>
    using namespace std;
    
    #define ll long long
    #define F(i,a,b) for(int i=a;i<=b;++i)
    #define R(i,a,b) for(int i=a;i<b;++i)
    #define mem(a,b) memset(a,b,sizeof(a))
    
    const int N = 200200;
    int cas;
    int n,c[N];
    int vis[N];//第i种颜色是否出现
    vector<int>vec[N];
    ll s;//记录在向上合并的过程中某个颜色在其未出现的块里能形成多少条路径
    ll sum[N];//第i中颜色已合并的节点数
    int sons[N];//节点i的子树大小
    
    void dfs(int fa,int u)
    {
        sons[u]=1;
        sum[c[u]]++;//合并u
        ll pre=sum[c[u]];//已合并的节点数
        int sz=vec[u].size();
        R(i,0,sz)
        {
            int v=vec[u][i];
            if(v!=fa)
            {
                dfs(u,v);
                sons[u]+=sons[v];//将v子树节点数加入到u子树节点统计中
                ll ct=sons[v]-sum[c[u]]+pre;
                s+=1LL*ct*(ct-1)/2;
                sum[c[u]]+=ct;//合并v子树下的节点
                pre=sum[c[u]];
            }
        }
    }
    
    int main()
    {
        while(~scanf("%d",&n))
        {
            s=0;
            mem(sum,0);mem(vis,0);
            int num=0;
            F(i,1,n)
            {
                scanf("%d",c+i);
                if(!vis[c[i]]) { num++;vis[c[i]]=1; }
                vec[i].clear();
            }
            int u,v;
            R(i,1,n)//建树
            {
                scanf("%d %d",&u,&v);
                vec[u].push_back(v);
                vec[v].push_back(u);
            }
            dfs(0,1);
            ll ans=1LL*num*n*(n-1)/2-s;
            F(i,1,n) if(vis[i])
            {
                ll ct=n-sum[i];
                ans-=ct*(ct-1)/2;
            }
            printf("Case #%d: %I64d
    ", ++cas,ans);
        }
        return 0;
    }
    
  • 相关阅读:
    简明Python3教程 12.问题解决
    简明Python3教程 11.数据结构
    【SPOJ 694】Distinct Substrings
    【codeforces Manthan, Codefest 17 C】Helga Hufflepuff's Cup
    【CF Manthan, Codefest 17 B】Marvolo Gaunt's Ring
    【CF Manthan, Codefest 17 A】Tom Riddle's Diary
    【SPOJ 220】 PHRASES
    【POJ 3261】Milk Patterns
    【POJ 3294】Life Forms
    【POJ 1226】Substrings
  • 原文地址:https://www.cnblogs.com/chendl111/p/7242319.html
Copyright © 2020-2023  润新知