• 【HDU6035】 Colorful Tree


    题目的意思是:给定一个点带颜色的树,两点之间的距离定义为路径上不同颜色的个数。求所有点对间的距离和。

    做法有点分治,还有传说中的虚树DP,树上差分。

    点分治法:

      考虑每个点的贡献,可以发现一个点的子树大小就是这个点的贡献。那么,对于同一个根的另一个子树的一个点x,去掉x到根结点对应颜色的贡献,再加上x到根结点上的颜色的种类数目,就是这个x点的答案。我们具体做的时候,是先不考虑根结点的,根结点对x点的贡献单独算。

      

    #include <algorithm>
    #include  <iterator>
    #include  <iostream>
    #include   <cstring>
    #include   <cstdlib>
    #include   <iomanip>
    #include    <bitset>
    #include    <cctype>
    #include    <cstdio>
    #include    <string>
    #include    <vector>
    #include     <stack>
    #include     <cmath>
    #include     <queue>
    #include      <list>
    #include       <map>
    #include       <set>
    #include   <cassert>
    
    /*
     
     &#8834;_ヽ
       \\ Λ_Λ  来了老弟
        \('&#12613;')
         > ⌒ヽ
        /   へ\
        /  / \\
        &#65434; ノ   ヽ_つ
       / /
       / /|
      ( (ヽ
      | |、\
      | 丿 \ ⌒)
      | |  ) /
     'ノ )  L&#65417;
     
     */
    
    using namespace std;
    #define lson (l , mid , rt << 1)
    #define rson (mid + 1 , r , rt << 1 | 1)
    #define debug(x) cerr << #x << " = " << x << "
    ";
    #define pb push_back
    #define pq priority_queue
    
    typedef long long ll;
    typedef unsigned long long ull;
    //typedef __int128 bll;
    typedef pair<ll ,ll > pll;
    typedef pair<int ,int > pii;
    typedef pair<int,pii> p3;
    
    //priority_queue<int> q;//这是一个大根堆q
    //priority_queue<int,vector<int>,greater<int> >q;//这是一个小根堆q
    #define fi first
    #define se second
    //#define endl '
    '
    
    #define boost ios::sync_with_stdio(false);cin.tie(0)
    #define rep(a, b, c) for(int a = (b); a <= (c); ++ a)
    #define max3(a,b,c) max(max(a,b), c);
    #define min3(a,b,c) min(min(a,b), c);
    
    const ll oo = 1ll<<17;
    const ll mos = 0x7FFFFFFF;  //2147483647
    const ll nmos = 0x80000000;  //-2147483648
    const int inf = 0x3f3f3f3f;
    const ll inff = 0x3f3f3f3f3f3f3f3f; //18
    const int mod = 1e9+7;
    const double esp = 1e-8;
    const double PI=acos(-1.0);
    const double PHI=0.61803399;    //黄金分割点
    const double tPHI=0.38196601;
    
    template<typename T>
    inline T read(T&x){
        x=0;int f=0;char ch=getchar();
        while (ch<'0'||ch>'9') f|=(ch=='-'),ch=getchar();
        while (ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
        return x=f?-x:x;
    }
    
    inline void cmax(int &x,int y){if(x<y)x=y;}
    inline void cmax(ll &x,ll y){if(x<y)x=y;}
    inline void cmin(int &x,int y){if(x>y)x=y;}
    inline void cmin(ll &x,ll y){if(x>y)x=y;}
    
    /*-----------------------showtime----------------------*/
    const int maxn = 2e5+9;
    int col[maxn];
    vector<int>mp[maxn];
    ll ans = 0, sumcol = 0;
    int sz[maxn],wt[maxn], root, curn;
    int vis[maxn];
    void findRoot(int u, int fa) {
        sz[u] = 1;wt[u] = 0;
        for(int i=0; i<mp[u].size(); i++) {
            int v = mp[u][i];
            if(v == fa || vis[v]) continue;
            findRoot(v, u);
            sz[u] += sz[v];
            wt[u] = max(sz[v], wt[u]);
        }
        wt[u] = max(wt[u], curn - sz[u]);
        if(wt[u] <= wt[root]) root = u;
    }
    // map<int, int> pp;
    ll pp[maxn];
    int youmeiyou[maxn];
    int ss;
    void gao(int u, int fa, vector<pii>& vv, int cnt, ll sumfa, ll sum) {
        ll res = 0;
        if(youmeiyou[col[u]] == 0)
            vv.pb(pii(col[u], sz[u])), cnt++, res += pp[col[u]];
        
        youmeiyou[col[u]]++;
        ans += sumcol - sumfa - res + 1ll * cnt * sum;
        //sum-color[根的颜色]+size[root]
        if(youmeiyou[col[ss]] == 0) ans += sum - pp[col[ss]];
        for(int i=0; i<mp[u].size(); i++) {
            int v = mp[u][i];
            if(fa == v || vis[v]) continue;
            gao(v, u, vv, cnt, sumfa + res, sum);
        }
        youmeiyou[col[u]] --;
    }
    
    void solve(int u) {
        vis[u] = 1;
        findRoot(u, -1);
        ll sum = 1;
        sumcol = 0;
        queue<int>needclear;
        needclear.push(col[u]);
        for(int i=0; i<mp[u].size(); i++) {
            int v = mp[u][i];
            if(vis[v]) continue;
            vector<pii>vv;
            ss = u;
            gao(v, -1, vv, 0, 0, sum);
            
            for(int j=0; j<vv.size(); j++){
                int c = vv[j].fi;
                if(pp[c])pp[c] += vv[j].se;
                else {
                    pp[c] = vv[j].se;
                    needclear.push(c);
                }
                sumcol += vv[j].se;
            }
            sum += sz[v];
        }
        
        while(!needclear.empty()) {
            pp[needclear.front()] = 0;
            needclear.pop();
        }
        for(int i=0; i<mp[u].size(); i++) {
            int v = mp[u][i];
            if(!vis[v]) {
                root = 0;   wt[0] = inf; curn = sz[v];
                findRoot(v, -1);
                solve(root);
            }
        }
    }
    int main(){
        int n, cas = 0;
        while(~scanf("%d", &n)) {
            memset(vis, 0, sizeof(vis));
            for(int i=1; i<=n; i++) scanf("%d", &col[i]);
            for(int i=1; i<=n; i++) mp[i].clear();
            for(int i=1; i<n; i++) {
                int u,v;
                scanf("%d%d", &u, &v);
                mp[u].pb(v);
                mp[v].pb(u);
            }
            
            ans = 0;
            root = 0; wt[0] = inf;
            curn = n;
            findRoot(1, -1);
            solve(root);
            printf("Case #%d: %lld
    ", ++cas, ans);
        }
        return 0;
    }
    /*
     6
     1 2 3 1 2 3
     1 2
     1 3
     3 4
     3 5
     4 6
     */
    View Code

     虚树 + 树上差分法:

      对于一种颜色,可以把树分割成许多联通块,同一个联通块内,这种颜色不会产生影响,所以某个点上,某个颜色的影响就是n - size,size是包含这个点的联通块的大小。

      由于有多种颜色,我们可以对每种颜色构建对应的虚树,选择这种颜色的点和这些点的直接儿子作为关键点。类似树上差分的思想,先把答案保存在每个联通块最上面的点。

      

    #include <bits/stdc++.h>
    
    using namespace std;
    #define pb push_back
    #define fi first
    #define se second
    #define debug(x) cerr<<#x << " := " << x << endl;
    #define bug cerr<<"-----------------------"<<endl;
    #define FOR(a, b, c) for(int a = b; a <= c; ++ a)
    
    typedef long long ll;
    typedef long double ld;
    typedef pair<int, int> pii;
    typedef pair<ll, ll> pll;
    typedef pair<pii, int>PII;
    
    template<class T> void _R(T &x) { cin >> x; }
    void _R(int &x) { scanf("%d", &x); }
    void _R(ll &x) { scanf("%lld", &x); }
    void _R(double &x) { scanf("%lf", &x); }
    void _R(char &x) { scanf(" %c", &x); }
    void _R(char *x) { scanf("%s", x); }
    void R() {}
    template<class T, class... U> void R(T &head, U &... tail) { _R(head); R(tail...); }
    
    
    template<typename T>
    inline T read(T&x){
        x=0;int f=0;char ch=getchar();
        while (ch<'0'||ch>'9') f|=(ch=='-'),ch=getchar();
        while (ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
        return x=f?-x:x;
    }
    
    const ll inf = 0x3f3f3f3f3f3f3f3f;
    
    const int mod = 1e9+7;
    
    /**********showtime************/
                const int maxn = 2e5+9;
                int col[maxn],vis[maxn];
                vector<int>mp[maxn],xu_mp[maxn];
                vector<int>node[maxn],xu;
                int sz[maxn], dfn[maxn], dp[maxn], tim;
                int fa[maxn][20];
                ll  fen[maxn],ans;
    
                void dfs(int u, int o) {
                    sz[u] = 1;  dfn[u] = ++tim;
                    fa[u][0] = o;
                    dp[u] = dp[o]  + 1;
                    for(int i=1; i<20; i++)
                        fa[u][i] = fa[fa[u][i-1]][i-1];
                    for(int v : mp[u]) {
                        if(v == o) continue;
                        dfs(v, u);
                        sz[u] += sz[v];
                    }
                }
    
                int lca(int u, int v) {
                    if(dp[u] < dp[v]) swap(u, v);
    
                    for(int i=19; i>=0; i--) {
                        if(dp[fa[u][i]] >= dp[v])
                            u = fa[u][i];
                    }
                    if(u == v) return u;
    
                    for(int i=19; i>=0; i--) {
                        if(fa[u][i] != fa[v][i])
                            u = fa[u][i], v = fa[v][i];
                    }
                    return fa[u][0];
                }
    
                bool cmp(int x, int y) {
                    return dfn[x] < dfn[y];
                }
                int used[maxn];
                int nsz[maxn];
                int curcol;
                int n;
                int cdp[maxn];
                //求虚树上每个联通块的大小
                void gaoNewSz(int u, int o) {
                    ll s = 0;
                    cdp[u] = 0;
                    for(int v : xu_mp[u]) {
                        if(v == o) continue;
                        gaoNewSz(v, u);
                        if(col[v] == curcol)
                            cdp[u] += sz[v];
                        else cdp[u] += cdp[v];
                    }
                    nsz[u] = n - (sz[u] - cdp[u]);
                }
                //建立树上的差分
                void gaoSub(int u, int fa, int val) {
                    int w = val;
                    if(col[u] == curcol) {
                        fen[u] -= val;
                    }
                    else if(col[fa] == curcol || u == 1)
                    {
                        fen[u] += nsz[u];
                        w = nsz[u];
                    }
    
                    for(int v : xu_mp[u]) {
                        if(v == fa) continue;
                        if(col[u] == curcol)gaoSub(v, u, 0);
                        else gaoSub(v, u, w);
                    }
                }
    
                //建立虚树
                void build(vector <int> & xu) {
                    sort(xu.begin(), xu.end(), cmp);
                    stack<int>st;
                    queue<int>que;
    
                    for(int i=0; i<xu.size(); i++) {
                        int u = xu[i];
                        if(st.size() <= 1) st.push(u);
                        else {
                            int x = st.top(); st.pop();
                            int o = lca(x, u);
                            if(o == x) {
                                st.push(x);
                                st.push(u);
                                continue;
                            }
                            while(!st.empty()) {
                                int y = st.top(); st.pop();
    
                                if(dfn[y] > dfn[o]) {
                                    xu_mp[y].pb(x);
                                    if(used[y] == 0) used[y] = 1, que.push(y);
                                    x = y;
                                }
                                else if(dfn[y] == dfn[o]) {
                                    xu_mp[y].pb(x);
                                    st.push(y);
                                    if(used[y] == 0) used[y] = 1, que.push(y);
                                    break;
                                }
                                else {
                                    xu_mp[o].pb(x);
                                    st.push(y);
                                    st.push(o);
                                    if(used[o] == 0) used[o] = 1, que.push(o);
                                    break;
                                }
                            }
                            st.push(u);
                        }
                    }
                    while(st.size() > 1) {
                        int u = st.top(); st.pop();
                        int v = st.top();
                        xu_mp[v].pb(u);
                     //xu_mp[u].pb(v);
                     //   if(used[u] == 0) used[u] = 1, que.push(u);
                        if(used[v] == 0) used[v] = 1, que.push(v);
                    }
                    while(!st.empty())st.pop();
    
                    gaoNewSz(1, 1);
                    gaoSub(1, 1, 0);
    
                    while(!que.empty()) {
                        int u = que.front();
                        xu_mp[u].clear();
                        used[u] = 0;
                        que.pop();
                    }
                }
    
                //树上差分,最后的更新
                void pushdown(int u, int fa, ll val) {
                    ans  += fen[u] + val + n;
                    val += fen[u];
                    for(int v : mp[u]) {
                        if(v == fa) continue;
                        pushdown(v, u, val);
                    }
                }
    
    int main(){
                int cas = 0;
                while(~scanf("%d", &n)){
                    ans = 0;tim = 0;
                    for(int i=1; i<=n; i++){
                        mp[i].clear();
                        fen[i] = 0;
                        vis[i] = 0;
                        dp[i] = 0;
                        node[i].clear();
                    }
                    for(int i=1; i<=n; i++) {
                        read(col[i]);
                        vis[col[i]] = 1;
                        node[col[i]].pb(i);
                    }
                    for(int i=1; i<n; i++) {
                        int u,v;
                        read(u); read(v);
                        mp[u].pb(v);
                        mp[v].pb(u);
                    }
    
                    dfs(1, 1);
    
                    for(int i=1; i<maxn; i++) {
                        if(vis[i]) {
                            xu.clear();
                            if(col[1] != i) xu.pb(1);
                            for(int v : node[i]) {
                                xu.pb(v);
                                for(int k : mp[v]) {
                                    if(col[k] != i && dp[k] > dp[v])
                                        xu.pb(k);
                                }
                            }
                            curcol = i;
                            build(xu);
                        }
                    }
                    pushdown(1, 1, 0);
                    printf("Case #%d: %lld
    ", ++cas, (ans - n )/ 2);
                }
                return 0;
    }
    View Code

    附上虚树建立的网上流行模板

    void insert(int x) {
                    if(top == 1) {s[++top] = x; return ;}
                    int lca = LCA(x, s[top]);
                    if(lca == s[top]){ s[++top] = x;return ;}
                    while(top > 1 && dfn[s[top - 1]] >= dfn[lca]) add_edge(s[top - 1], s[top]), top--;
                    if(lca != s[top]) add_edge(lca, s[top]), s[top] = lca;//
                    s[++top] = x;
                }
  • 相关阅读:
    Java并发问题--乐观锁与悲观锁以及乐观锁的一种实现方式-CAS
    什么情况下Java程序会产生死锁?
    正确重写hashCode的办法
    浅析JVM类加载机制
    JVM中的新生代和老年代(Eden空间、两个Survior空间)
    解释循环中的递归调用
    get和post方法功能类似的,使用建议
    微信开发(五)微信消息加解密 (EncodingAESKey)
    PostgreSQL远程连接,发生致命错误:没有用于主机“…”,用户“…”,数据库“…”,SSL关闭的pg_hba.conf记录
    struts原理
  • 原文地址:https://www.cnblogs.com/ckxkexing/p/11129428.html
Copyright © 2020-2023  润新知