• nowcoder 79F 小H和圣诞树 换根 DP + 根号分治


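    设节点个数大于 $\sqrt n$ 的颜色为关键颜色,由于所有颜色的节点数之和为 $n$,可以证明关键颜色最多有 $\sqrt n$ 个.
    对于每个关键颜色,用换根 DP 在 $O(n)$ 内求出每个节点到该颜色所有节点的距离和,从而暴力预处理出该颜色到查询中另一个颜色的距离和.
    对于不是关键颜色的询问,直接对两种颜色的节点建立虚树进行统计即可.
    由于不是关键颜色,每种颜色的节点数最多为 $\sqrt n$,单次询问的虚树规模为 $O(\sqrt n)$,询问数与 $n$ 同阶时时间复杂度是 $O(2\times n\sqrt n)$.
    总时间复杂度为 $O(n\sqrt n)$,这个就叫做根号分治.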

    #include <cstdio>
    #include <algorithm>  
    #include <vector>    
    #include <cmath>   
    #include <map>
    #define N 100003 
    #define ll long long         
    #define setIO(s) freopen(s".in", "r" , stdin)  , freopen(s".out", "w" , stdout)  
    using namespace std; 
    namespace IO
    {
        // Buffered stdin reader: buf holds the current chunk, [p1, p2) is its unread part.
        char *p1, *p2, buf[100000];

        // Returns the next input byte as a non-negative int, or EOF once stdin is exhausted.
        // (This used to be a macro whose result was stored in a plain char; EOF then either
        // compared below '0' forever on signed-char platforms, hanging the skip loops, or
        // became 255 and was parsed as a digit on unsigned-char platforms.)
        inline int nc()
        {
            if (p1 == p2)
            {
                p1 = buf;
                p2 = buf + fread(buf, 1, sizeof(buf), stdin);
                if (p1 == p2) return EOF;
            }
            return (unsigned char)*p1++;
        }

        // Reads the next unsigned decimal integer, skipping any non-digit characters before it.
        // Returns 0 if input ends before a digit is found.
        int readint()
        {
            int c = nc();
            while (c != EOF && (c < '0' || c > '9')) c = nc();
            int x = 0;
            while (c >= '0' && c <= '9') x = x * 10 + (c - '0'), c = nc();
            return x;
        }

        // Same as readint(), but accumulates into a 64-bit value.
        ll readll()
        {
            int c = nc();
            while (c != EOF && (c < '0' || c > '9')) c = nc();
            ll x = 0;
            while (c >= '0' && c <= '9') x = x * 10 + (c - '0'), c = nc();
            return x;
        }
    }
    // Virtual-tree children lists (filled by addvir, walked by pre/work/clear).
    vector<int> G[N];
    // ty[c]: the vertices whose colour is c.
    vector<int> ty[N];
    // NOTE(review): `node` is not referenced anywhere in this file.
    vector<int> node;
    // Vertex count, edge counter for addedge, DFS timer for dfn, top index of stack S.
    int n, edges, tim, toop;
    // dis[u]: sum of weighted distances from u to the vertices of the colour being processed.
    ll dis[N];
    // depth[u]: weighted distance from the root (vertex 1) to u.
    ll depth[N];
    // Colour of each vertex and number of vertices per colour.
    int col[N], tax[N];
    // id[c] is only written in this file, never read; A is scratch space for query vertices.
    int id[N], A[N];
    // size[u]: number of vertices of the current colour in u's (virtual) subtree.
    // NOTE: with `using namespace std;` under C++17 an unqualified `size` is ambiguous with
    // std::size; refer to this array as `::size`.
    int size[N];
    // Stack used while building the virtual tree.
    int S[N];
    // Adjacency list as linked lists: head per vertex, next edge, edge target.
    int hd[N], nex[N << 1], to[N << 1];
    // Heavy-light decomposition data.
    int top[N], dfn[N], fa[N], dep[N], son[N], siz[N];
    // Edge weights, indexed like `to`.
    ll val[N << 1];
    // Comparator ordering vertices by DFS entry time, as virtual-tree construction requires.
    bool cmp(int a, int b)
    {
        const int lhs = dfn[a];
        const int rhs = dfn[b];
        return lhs < rhs;
    }
    inline void addedge(int u, int v, int c)
    {
        nex[++ edges] = hd[u], hd[u] = edges, to[edges] = v, val[edges] = 1ll * c;
    }
    // First heavy-light pass from u (parent ff). Fills, for each vertex of the subtree:
    // fa (parent), dep (edge count from root), dfn (preorder time), depth (weighted
    // distance from root), siz (subtree size) and son (child with the largest subtree).
    void dfs1(int u, int ff)
    {
        fa[u] = ff;
        dep[u] = dep[ff] + 1;
        dfn[u] = ++tim;
        siz[u] = 1;
        for (int e = hd[u]; e != 0; e = nex[e])
        {
            const int child = to[e];
            if (child == ff) continue;
            depth[child] = depth[u] + val[e];
            dfs1(child, u);
            siz[u] += siz[child];
            if (siz[child] > siz[son[u]]) son[u] = child;
        }
    }
    // Second heavy-light pass: top[u] becomes the head of the heavy chain containing u.
    void dfs2(int u, int tp)
    {
        top[u] = tp;
        const int heavy = son[u];
        if (heavy != 0) dfs2(heavy, tp);  // the heavy child continues this chain
        for (int e = hd[u]; e != 0; e = nex[e])
        {
            const int child = to[e];
            if (child != fa[u] && child != heavy) dfs2(child, child);  // light child opens a new chain
        }
    }
    // Lowest common ancestor via heavy-light decomposition: keep lifting the vertex whose
    // chain head is deeper until both are on the same chain, then return the shallower one.
    inline int LCA(int x, int y)
    {
        while (top[x] != top[y])
        {
            if (dep[top[x]] > dep[top[y]])
                x = fa[top[x]];
            else
                y = fa[top[y]];
        }
        if (dep[x] < dep[y]) return x;
        return y;
    }
    // Weighted tree distance between x and y, measured through their LCA.
    inline ll Dis(int x, int y)
    {
        const int meet = LCA(x, y);
        return (depth[x] - depth[meet]) + (depth[y] - depth[meet]);
    }
    void solve1(int u, int ff, int cur)
    {
        size[u] = (col[u] == cur), dis[u] = 0;                 
        for(int i = hd[u] ; i ; i = nex[i])
        {
            int v = to[i];
            if(v == ff) continue;  
            solve1(v, u, cur),size[u] += size[v], dis[u] += (dis[v] + 1ll * size[v] * val[i]);   
        }            
    }  
    void solve(int u, int ff, int cur)
    {  
        for(int i = hd[u] ; i ; i = nex[i])
        {
            int v = to[i];
            if(v == ff) continue;     
            dis[v] += (dis[u] - dis[v] - 1ll * size[v] * val[i] + 1ll * (tax[cur] - size[v]) * val[i]);    
            solve(v, u, cur);                 
        }   
    }
    inline void addvir(int u, int v)
    {
        G[u].push_back(v);    
    }
    // Pushes vertex x onto the virtual-tree construction stack S[1..toop]. main() feeds
    // vertices in increasing dfn order. Whenever a stack entry can no longer gain further
    // descendants, it is linked to its virtual parent with addvir.
    // NOTE(review): the toop < 2 shortcut pushes x without an LCA check. This is only
    // correct because main() always makes S[1] the root (vertex 1).
    inline void insert(int x)
    {
        if(toop < 2)
        {
            S[++ toop] = x;
            return ;
        }
        int lca = LCA(x, S[toop]);  
        if(lca != S[toop])
        {
            // x is not below the current top: pop every entry at least as deep as lca,
            // linking each popped entry to the entry beneath it.
            while(toop > 1 && dep[S[toop - 1]] >= dep[lca]) addvir(S[toop - 1], S[toop]),-- toop;  
            // lca itself was not on the stack: hang the remaining top under lca and replace it.
            if(S[toop] != lca) addvir(lca, S[toop]), S[toop] = lca;    
        }
        S[++ toop] = x;      
    }  
    void pre(int u, int ff, int cur)
    {
        size[u] = (col[u] == cur), dis[u] = 0;     
        for(int i = 0; i < G[u].size(); ++ i)
        {
            int v = G[u][i];  
            pre(v, u, cur), size[u] += size[v], dis[u] += dis[v] + 1ll * size[v] * Dis(v, u);   
        }    
    }
    void work(int u, int ff, int cur)
    {
        for(int i = 0; i < G[u].size() ; ++ i)
        {
            int v = G[u][i];       
            dis[v] += (dis[u] - dis[v] - 1ll * size[v] * Dis(u, v) + 1ll * (tax[cur] - size[v]) * Dis(u, v)); 
            work(v, u, cur);     
        } 
    }
    void clear(int u)
    {
        size[u] = dis[u] = 0;
        for(int i = 0; i < G[u].size(); ++ i) clear(G[u][i]) ;
        G[u].clear();   
    }
    // One query: the two colours. main() swaps them so that tax[a] >= tax[b].
    struct Node
    {
        int a;
        int b;
    };
    Node ask[N];
    // P[c]: partner colours of the queries whose (larger) colour c is a key colour, in query order.
    vector<int> P[N];
    // answer[c][k]: precomputed distance sum for the k-th entry of P[c].
    vector<ll> answer[N];
    // point[c]: how many entries of answer[c] have been consumed while printing.
    int point[N];
    int main()
    {
        using namespace IO;
        // setIO("input");
        int i , j, idx = 0, m, Q;
        n = readint();
        m = sqrt(n);
        for(i = 1; i <= n ; ++ i) col[i] = readint(), ++tax[col[i]], ty[col[i]].push_back(i);                                  
        for(i = 1; i < n ; ++ i)
        {
            int a = readint(), b = readint(), c = readint();
            addedge(a, b, c), addedge(b, a, c); 
        }      
        dfs1(1, 0), dfs2(1, 1);  
        for(i = 1; i <= n ; ++ i)   if(tax[i] >= m) id[i] = ++idx;          
        Q = readint();
        for(i = 1; i <= Q; ++ i)
        {
            ask[i].a = readint(), ask[i].b = readint(); 
            if(tax[ask[i].a] < tax[ask[i].b]) swap(ask[i].a, ask[i].b);  
            if(tax[ask[i].a] >= m) P[ask[i].a].push_back(ask[i].b);    
        }        
        for(i = 1; i <= n ; ++ i)
        {
            if(tax[i] >= m)
            {
                solve1(1, 0, i), solve(1, 0, i);       
                for(j = 0 ; j < P[i].size() ; ++ j)
                {
                    int cur = P[i][j];     
                    ll re = 0;         
                    for(int k = 0; k < ty[cur].size(); ++ k)
                    {
                        re += dis[ty[cur][k]];   
                    }
                    answer[i].push_back(re);      
                }
            } 
        }           
        for(int cas = 1; cas <= Q; ++ cas)
        {
            int a, b;
            a = ask[cas].a, b = ask[cas].b;       
            if(tax[a] >= m) printf("%lld
    ", a == b ? answer[a][point[a] ++ ] / 2 : answer[a][point[a] ++ ]);               
            else
            {      
                int tmp = 0;          
                ll re = 0;
                for(i = 0; i < ty[a].size(); ++ i) A[++ tmp] = ty[a][i]; 
                for(i = 0; i < ty[b].size(); ++ i) A[++ tmp] = ty[b][i];   
                sort(A + 1, A + 1 + tmp, cmp);   
                tmp = unique(A + 1, A + 1 + tmp) - (A + 1);         
                toop = 0; 
                if(A[1] != 1) S[++ toop] = 1;           
                for(i = 1 ; i <= tmp ; ++ i) insert(A[i]); 
                while(toop > 1) addvir(S[toop - 1], S[toop]), --toop;        
                pre(1, 0, b), work(1, 0, b);  
                for(i = 0; i < ty[a].size(); ++ i) re += dis[ty[a][i]];  
                printf("%lld
    ", a == b ? re / 2 : re);  
            }
        }
        return 0;   
    }
    

      

  • 相关阅读:
    java静态代码分析工具infer
    Go的安装和使用/卸载/升级、安装指定版本
    ldap服务器OpenLDAP安装使用
    python2 和 python3兼容写法
    ldap客户端以及jenkins的配置
    mac下java的安装和升级以及相关环境设置
    常见高危安全漏洞
    XFS: Cross Frame Script (跨框架脚本) 攻击。
    WEB渗透测试之三大漏扫神器
    编写自己的Acunetix WVS漏洞扫描脚本详细教程
  • 原文地址:https://www.cnblogs.com/guangheli/p/11367477.html
Copyright © 2020-2023  润新知