• HDU 5877 2016大连网络赛 Weak Pair(树状数组,线段树,动态开点,启发式合并,可持久化线段树)


    Weak Pair

    Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others)
    Total Submission(s): 1468    Accepted Submission(s): 472


    Problem Description
    You are given a rooted tree of $N$ nodes, labeled from $1$ to $N$. To the $i$-th node a non-negative value $a_i$ is assigned. An ordered pair of nodes $(u, v)$ is said to be weak if
      (1) u is an ancestor of v (Note: In this problem a node u is not considered an ancestor of itself);
      (2) $a_u \times a_v \le k$.

    Can you find the number of weak pairs in the tree?
     

    Input
    There are multiple cases in the data set.
      The first line of input contains an integer T denoting number of test cases.
      For each case, the first line contains two space-separated integers, N and k, respectively.
      The second line contains N space-separated integers, denoting a1 to aN.
      Each of the subsequent $N-1$ lines contains two space-separated integers defining an edge connecting nodes $u$ and $v$, where node $u$ is the parent of node $v$.

      Constraints:
      
      $1 \le N \le 10^5$
      
      $0 \le a_i \le 10^9$
      
      $0 \le k \le 10^{18}$
     

    Output
    For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
     

    Sample Input
    1
    2 3
    1 2
    1 2
     

    Sample Output
    1
    这是一道很好的数据结构的题目:
    可以用很多方法写
    首先思路是:dfs这棵树,每到一个节点,都计算这个节点的祖先中满足条件的有几个。
    而计算这个就需要维护从根到当前节点路径上的祖先值序列,并且高效地得出多少个祖先满足条件。
    因为 a[u]*a[i]<=k 等价于 a[u]<=k/a[i](整除,a[i]>0),即在序列中找到小于或等于k/a[i]的数有多少个(a[i]=0 时所有祖先都满足),很容易想到用树状数组和线段树。
    权值1e9需要离散化。
    树状数组:
    #include <iostream>
    #include <string.h>
    #include <stdlib.h>
    #include <algorithm>
    #include <math.h>
    #include <stdio.h>
    #include <map>
    
    using namespace std;
    typedef long long int LL;
    const int maxn = 100000;               // upper bound on N

    // ---- per-case input ----
    int n;                                 // number of nodes
    LL k;                                  // bound in the condition a_u * a_v <= k
    LL a[maxn + 5];                        // a[i]: value of node i (1-based)
    LL b[maxn + 5];                        // b[i]: query bound for node i (filled in main)

    // ---- adjacency list, edges stored parent -> child ----
    struct Node
    {
        int value;                         // target node of the edge
        int next;                          // next edge from the same node, -1 ends the list
    } edge[maxn * 2 + 5];
    int head[maxn + 5];
    int vis[maxn + 5];
    int tot;                               // edges stored so far

    // ---- Fenwick tree over compressed ranks ----
    int c[maxn * 2 + 5];
    LL e[maxn * 2 + 5];                    // scratch: every value to be compressed
    map<LL, int> m;                        // value -> 1-based rank
    void add(int x,int y)
    {
        edge[tot].value=y;
        edge[tot].next=head[x];
        head[x]=tot++;
    }
    // Lowest set bit of x (0 for x == 0): the Fenwick-tree step size.
    int lowbit(int x)
    {
        return x & (~x + 1);
    }
    // Add num at position x of the Fenwick tree; positions run 1..2n.
    void update(int x, int num)
    {
        for (int pos = x; pos <= n * 2; pos += lowbit(pos))
            c[pos] += num;
    }
    // Prefix sum c-positions 1..x of the Fenwick tree.
    int sum(int x)
    {
        int total = 0;
        for (; x > 0; x -= lowbit(x))
            total += c[x];
        return total;
    }
    LL ans;                                // weak pairs counted so far

    // Walk the subtree below `root`. The Fenwick tree holds exactly the ranks of
    // the values on the current root-to-node path (main seeds it with the root),
    // so for each child v the prefix count up to rank(b[v]) is the number of
    // ancestors whose compressed rank does not exceed that of b[v].
    void dfs(int root)
    {
        vis[root] = 1;
        for (int i = head[root]; i != -1; i = edge[i].next)
        {
            const int v = edge[i].value;
            if (vis[v])
                continue;
            ans += sum(m[b[v]]);
            update(m[a[v]], 1);            // v is an ancestor for its own subtree
            dfs(v);
            update(m[a[v]], -1);           // ...and no longer once we leave it
        }
    }
    
    // Per-case reset: adjacency list, visited marks and the Fenwick tree.
    void init()
    {
        tot = 0;
        memset(head, -1, sizeof(head));
        memset(vis, 0, sizeof(vis));
        memset(c, 0, sizeof(c));
    }
    int tag[maxn+5];
    int main()
    {
        int t;
        scanf("%d",&t);
        int x,y;
        while(t--)
        {
            scanf("%d%lld",&n,&k);
            init();
            int cnt=n;
            m.clear();
            for(int i=1;i<=n;i++)
            {
                scanf("%lld",&a[i]);
                e[i]=a[i];
                if(a[i]==0)
                    m[a[i]]=2*n;
                else
                {
                    b[i]=k/a[i];
                    e[++cnt]=b[i];
                }
            }
            sort(e+1,e+cnt+1);
            int tot=1;
            for(int i=1;i<=cnt;i++)
            {
                if(!m.count(e[i]))
                    m[e[i]]=tot++;
            }
            memset(tag,0,sizeof(tag));
            for(int i=1;i<=n-1;i++)
            {
                scanf("%d%d",&x,&y);
                add(x,y);
                tag[y]++;
            }
            int root;
            for(int i=1;i<=n;i++)
            {
                if(tag[i]==0)
                    root=i;
            }
            ans=0;
            update(m[a[root]],1);
            dfs(root);
            printf("%lld
    ",ans);
        }
        return 0;
    }

    线段树:
    <pre name="code" class="html">#include <iostream>
    #include <string.h>
    #include <algorithm>
    #include <stdlib.h>
    #include <math.h>
    #include <stdio.h>
    #include <string>
    #include <map>
    #include <vector>
    
    using namespace std;
    typedef long long int LL;
    const int maxn = 100000;

    int n;                         // number of vertices in the current case
    LL k;                          // bound in the condition a_u * a_v <= k
    vector<int> v[maxn + 5];       // v[x]: children of x
    LL a[maxn + 5];                // a[i]: value of vertex i
    LL b[maxn + 5];                // b[i]: query bound for vertex i (filled in main)
    int aa[maxn + 5];              // aa[i]: compressed rank of a[i]
    int bb[maxn + 6];              // bb[i]: compressed rank of b[i]
    LL e[maxn * 2 + 5];            // every a[i] and b[i], sorted for compression
    map<LL, int> m;                // value -> 1-based rank
    int sum[maxn * 8 + 5];         // array segment tree over ranks 1..2n
    
    // Recompute an internal node's count from its two children.
    void PushUp(int node)
    {
        const int lc = node << 1;
        sum[node] = sum[lc] + sum[lc | 1];
    }
    // Add num to the leaf for rank `ind` inside the subtree `node`, which covers
    // ranks [begin, end], then refresh the counts on the way back up.
    void update(int node, int begin, int end, int ind, int num)
    {
        if (begin == end)
        {
            sum[node] += num;              // a leaf spans exactly one rank
            return;
        }
        const int mid = (begin + end) >> 1;
        if (ind > mid)
            update(node << 1 | 1, mid + 1, end, ind, num);
        else
            update(node << 1, begin, mid, ind, num);
        PushUp(node);
    }
    // Number of stored values whose rank lies in [left, right], searched within
    // the subtree `node` covering ranks [begin, end]. Read-only: queries change
    // no counts, so no PushUp is needed (the previous version wrote to the tree
    // on every visited node for no effect).
    LL Query(int node, int begin, int end, int left, int right)
    {
        if (left <= begin && end <= right)
            return sum[node];
        const int mid = (begin + end) >> 1;
        LL ret = 0;
        if (left <= mid)
            ret += Query(node << 1, begin, mid, left, right);
        if (right > mid)
            ret += Query(node << 1 | 1, mid + 1, end, left, right);
        return ret;
    }
    int tag[maxn+5];
    LL ans;
    void dfs(int root)
    {
        int len=v[root].size();
        for(int i=0;i<len;i++)
        {
            int w=v[root][i];
            ans+=Query(1,1,2*n,1,bb[w]);
            update(1,1,2*n,aa[w],1);
            dfs(v[root][i]);
            update(1,1,2*n,aa[w],-1);
        }
    }
    // Clear the in-degree table and the segment tree before a new case.
    void init()
    {
        memset(tag, 0, sizeof(tag));
        memset(sum, 0, sizeof(sum));
    }
    int main()
    {
        int t;
        scanf("%d",&t);
        int x,y;
        while(t--)
        {
            scanf("%d%lld",&n,&k);
            int cnt=0;
            init();
            m.clear();
            for(int i=1;i<=n;i++)
            {
                scanf("%lld",&a[i]);
                e[++cnt]=a[i];
                b[i]=k/a[i];
                e[++cnt]=b[i];
                v[i].clear();
            }
            sort(e+1,e+cnt+1);
            int cot=1;
            for(int i=1;i<=cnt;i++)
            {
                if(!m.count(e[i]))
                    m[e[i]]=cot++;
            }
            for(int i=1;i<=n;i++)
            {
                aa[i]=m[a[i]];
                bb[i]=m[b[i]];
            }
            for(int i=1;i<=n-1;i++)
            {
                scanf("%d%d",&x,&y);
                v[x].push_back(y);
                tag[y]++;
            }
            int root;
            for(int i=1;i<=n;i++)
            {
                if(tag[i]==0)
                    root=i;
            }
            ans=0;
            update(1,1,2*n,m[a[root]],1);
            dfs(root);
            printf("%lld
    ",ans);
        }
        return 0;
    }
    
    其实用线段树也可以不离散化,方法是线段树的动态开点:动态开点就是用到了这个点再去开,不用的点不用开。
    这样在0到1e18的范围内,每次插入最多新建 O(log V) 个节点(V=1e18 时树高约60),总共需要 O(n log V) 个节点而不是8n,所以下面的代码开了 maxn*100 的节点池。

    </pre><pre code_snippet_id="1877993" snippet_file_name="blog_20160912_2_1715825" name="code" class="html"><pre name="code" class="html">#include <iostream>
    #include <string.h>
    #include <stdlib.h>
    #include <algorithm>
    #include <math.h>
    #include <string>
    #include <stdio.h>
    #include <vector>
    
    using namespace std;
    typedef long long int LL;
    const int maxn = 100000;
    const LL len = 1000000000000000000LL;  // the tree covers values [0, len] = [0, 1e18]

    int n;
    LL k;
    LL a[maxn + 5];                        // a[i]: value of vertex i
    LL b[maxn + 5];                        // b[i]: query bound for vertex i (filled in main)
    vector<int> v[maxn + 5];               // v[x]: children of x

    // Node of the dynamically allocated segment tree; lch / rch are pool
    // indices, with -1 meaning "child not created".
    struct Node
    {
        int lch, rch;
        LL sum;
        Node() {}
        Node(int l, int r, LL s) : lch(l), rch(r), sum(s) {}
    } tr[maxn * 100 + 5];
    int p;                                 // index of the most recently allocated node
    void PushUp(int node)
    {
        tr[node].sum=tr[tr[node].lch].sum+tr[tr[node].rch].sum;
    }
    
    int newnode()
    {
        tr[++p]=Node(-1,-1,0);
        return p;
    }
    void update(int node,LL begin,LL end,LL ind,int num)
    {
        if(begin==end)
        {
            tr[node].sum+=num;
            return;
        }
        LL m=(begin+end)>>1;
        if(tr[node].lch==-1) tr[node].lch=newnode();
        if(tr[node].rch==-1) tr[node].rch=newnode();
        if(ind<=m)
            update(tr[node].lch,begin,m,ind,num);
        else
            update(tr[node].rch,m+1,end,ind,num);
        PushUp(node);
    }
    LL query(int node,LL begin,LL end,LL left,LL right)
    {
        if(node==-1)
            return 0;
        if(left<=begin&&end<=right)
            return tr[node].sum;
        LL m=(begin+end)>>1;
        LL ret=0;
        if(left<=m)
            ret+=query(tr[node].lch,begin,m,left,right);
        if(right>m)
            ret+=query(tr[node].rch,m+1,end,left,right);
        PushUp(node);
        return ret;
    
    }
    int tag[maxn+5];
    LL ans;
    void dfs(int root)
    {
        int len1=v[root].size();
        for(int i=0;i<len1;i++)
        {
            int w=v[root][i];
            ans+=query(1,0,len,0,b[w]);
            update(1,0,len,a[w],1);
            dfs(w);
            update(1,0,len,a[w],-1);
        }
    }
    // Per-case reset: clear in-degrees and restart the node pool, allocating a
    // fresh tree root as node 1.
    void init()
    {
        memset(tag, 0, sizeof tag);
        p = 0;
        newnode();                         // becomes node 1, the root used by dfs/main
    }
    // Reads T test cases and prints the number of weak pairs for each.
    int main()
    {
        int t;
        scanf("%d", &t);
        int x, y;
        while (t--)
        {
            scanf("%d%lld", &n, &k);
            for (int i = 1; i <= n; i++)
            {
                scanf("%lld", &a[i]);
                // a[i] == 0 pairs with every ancestor (0 * a_u = 0 <= k), so
                // query the whole range [0, len] instead of dividing by zero.
                b[i] = (a[i] == 0) ? len : k / a[i];
                v[i].clear();
            }
            init();
            for (int i = 1; i <= n - 1; i++)
            {
                scanf("%d%d", &x, &y);
                v[x].push_back(y);
                tag[y]++;
            }
            int root = 1;
            for (int i = 1; i <= n; i++)
            {
                if (!tag[i])
                    root = i;
            }
            ans = 0;
            update(1, 0, len, a[root], 1);      // the root is everyone's ancestor
            dfs(root);
            printf("%lld\n", ans);
        }
        return 0;
    }

    还可以自下而上,用线段树的启发式合并,计算每一个节点子树中所有子孙节点对它的贡献
    关于线段树的启发式合并,有必要再写一篇博客总结一下

    #include <iostream>
    #include <string.h>
    #include <stdlib.h>
    #include <stdio.h>
    #include <algorithm>
    #include <math.h>
    
    using namespace std;
    const int maxn=1e5;
    typedef long long int LL;
    // rt[v]: root of vertex v's segment tree. It is indexed by vertex only, so
    // maxn+5 entries suffice (the old maxn*100 size wasted about 40MB).
    int rt[maxn+5];
    // Node pool. Each vertex's initial tree is a single root-to-leaf path of
    // about 31 nodes (values <= 1e9), and merging never allocates, so maxn*100
    // slots are ample. Index 0 is the shared empty node.
    int ls[maxn*100+5];
    int rs[maxn*100+5];
    LL sum[maxn*100+5];
    int a[maxn+5];
    LL k;
    int n;
    int p;                     // next free node index
    int l,r;                   // value range [min a, max a] of the current case
    // Hand out the next pool slot with no children and a zero count.
    int newnode()
    {
        const int id = p++;
        ls[id] = 0;
        rs[id] = 0;
        sum[id] = 0;
        return id;
    }
    void build(int &node,int begin,int end,LL val)
    {
        if(!node) node=newnode();
        sum[node]=1;
        if(begin==end) return;
        LL mid=(begin+end)>>1;
        if(val<=mid) build(ls[node],begin,mid,val);
        else build(rs[node],mid+1,end,val);
    }
    // Number of stored values <= val in the subtree `node` covering
    // [begin, end]. Walks one root-to-leaf path, adding the whole left child's
    // count whenever it steps right.
    LL Query(int node, int begin, int end, LL val)
    {
        LL acc = 0;
        while (node && val >= begin)
        {
            if (begin == end)
                return acc + sum[node];
            const LL mid = (begin + end) >> 1;
            if (val <= mid)
            {
                node = ls[node];
                end = (int)mid;
            }
            else
            {
                acc += sum[ls[node]];
                node = rs[node];
                begin = (int)mid + 1;
            }
        }
        return acc;
    }
    void mergge(int &x,int y, int begin,int end)
    {
        if(!x||!y) {x=x^y;return;}
        sum[x]+=sum[y];
        if(begin==end) return;
        LL mid=(begin+end)>>1;
        mergge(ls[x],ls[y],begin,mid);
        mergge(rs[x],rs[y],mid+1,end);
    }
    // Adjacency list: head[x] is x's first edge, edge[i].next chains the rest,
    // and -1 terminates a list.
    struct Node
    {
        int value;     // target vertex
        int next;      // next edge from the same source, or -1
    } edge[maxn * 2 + 5];
    int head[maxn + 5];
    int tot;           // edges stored so far

    // Prepend the directed edge x -> y to x's list.
    void add(int x, int y)
    {
        const int id = tot++;
        edge[id].value = y;
        edge[id].next = head[x];
        head[x] = id;
    }
    LL ans;
    void dfs(int root)
    {
        for(int i=head[root];i!=-1;i=edge[i].next)
        {
            int w=edge[i].value;
               dfs(w);
               mergge(rt[root],rt[w],l,r);
        }
        ans+=Query(rt[root],l,r,k/a[root]);
        if(k>=1ll*a[root]*a[root])
            ans--;
    }
    int tag[maxn+5];                       // in-degree per vertex; the root has 0

    // Reads T test cases and prints the number of weak pairs for each.
    int main()
    {
        int t;
        scanf("%d", &t);
        int x, y;
        while (t--)
        {
            scanf("%d%lld", &n, &k);
            p = 1;                         // node 0 stays the shared empty node
            memset(tag, 0, sizeof(tag));
            memset(head, -1, sizeof(head));
            tot = 0;
            // Every tree in this case covers the value range [l, r].
            l = 1000000000;
            r = 0;
            for (int i = 1; i <= n; i++)
            {
                scanf("%d", &a[i]);
                l = min(l, a[i]);
                r = max(r, a[i]);
            }
            // Each vertex starts with a one-element tree holding its own value.
            for (int i = 1; i <= n; i++)
            {
                rt[i] = 0;
                build(rt[i], l, r, a[i]);
            }
            for (int i = 1; i <= n - 1; i++)
            {
                scanf("%d%d", &x, &y);
                add(x, y);
                tag[y]++;
            }
            int root = 1;
            for (int i = 1; i <= n; i++)
                if (tag[i] == 0) root = i;
            ans = 0;
            dfs(root);
            printf("%lld\n", ans);
        }
        return 0;
    }

    也可以用拓扑排序,自下而上进行启发式合并,
    #include <iostream>
    #include <string.h>
    #include <stdlib.h>
    #include <stdio.h>
    #include <algorithm>
    #include <math.h>
    #include <queue>
    
    using namespace std;
    const int maxn=1e5;
    typedef long long int LL;
    // rt[v]: root of vertex v's segment tree. It is indexed by vertex only, so
    // maxn+5 entries suffice (the old maxn*100 size wasted about 40MB).
    int rt[maxn+5];
    // Node pool: about 31 nodes per vertex from build (values <= 1e9), and
    // merging never allocates. Index 0 is the shared empty node.
    int ls[maxn*100+5];
    int rs[maxn*100+5];
    LL sum[maxn*100+5];
    int a[maxn+5];
    int f[maxn+5];             // f[v]: parent of v (0 for the root)
    LL k;
    int n;
    int p;                     // next free node index
    int l,r;                   // value range [min a, max a] of the current case
    queue<int> q;              // vertices whose children have all been merged
    // Take the next free pool slot, reset it to an empty leaf, and return it.
    int newnode()
    {
        ls[p] = 0;
        rs[p] = 0;
        sum[p] = 0;
        return p++;
    }
    void build(int &node,int begin,int end,LL val)
    {
        if(!node) node=newnode();
        sum[node]=1;
        if(begin==end) return;
        LL mid=(begin+end)>>1;
        if(val<=mid) build(ls[node],begin,mid,val);
        else build(rs[node],mid+1,end,val);
    }
    // Count of stored values <= val in subtree `node` covering [begin, end],
    // computed along a single root-to-leaf walk.
    LL Query(int node, int begin, int end, LL val)
    {
        LL total = 0;
        while (node != 0 && val >= begin)
        {
            if (begin == end)
            {
                total += sum[node];
                break;
            }
            const LL mid = (begin + end) >> 1;
            if (val > mid)
            {
                total += sum[ls[node]];    // the whole left half is <= val
                node = rs[node];
                begin = (int)mid + 1;
            }
            else
            {
                node = ls[node];
                end = (int)mid;
            }
        }
        return total;
    }
    void mergge(int &x,int y, int begin,int end)
    {
        if(!x||!y) {x=x^y;return;}
        sum[x]+=sum[y];
        if(begin==end) return;
        LL mid=(begin+end)>>1;
        mergge(ls[x],ls[y],begin,mid);
        mergge(rs[x],rs[y],mid+1,end);
    }
    LL ans;
    int tag[maxn+5];
    int main()
    {
        int t;
        scanf("%d",&t);
        int x,y;
        while(t--)
        {
            scanf("%d%lld",&n,&k);
            p=1;
            memset(tag,0,sizeof(tag));
            l=1e9;r=0;
            for(int i=1;i<=n;i++)
            {
                 scanf("%d",&a[i]);
                 l=min(l,a[i]);r=max(r,a[i]);
            }
            for(int i=1;i<=n;i++)
                build(rt[i]=0,l,r,a[i]);
    
    
            for(int i=1;i<=n-1;i++)
            {
                scanf("%d%d",&x,&y);
                tag[x]++;
    			f[y]=x;
            }
            for(int i=1;i<=n;i++)
    		{
               if(tag[i]==0)
    			   q.push(i);
    		}
            ans=0;
    		while(!q.empty())
    		{
    			int x=q.front();q.pop();
    			if(1LL*a[x]*a[x]<=k) ans--;
    			ans+=Query(rt[x],l,r,k/a[x]);
    			mergge(rt[f[x]],rt[x],l,r);
    			if(!--tag[f[x]]) q.push(f[x]);
    			
    		}
            printf("%lld
    ",ans);
        }
        return 0;
    }

    最后写一种可持久化线段树的解法。首先将树形转成线形(DFS 序),然后逐个点插入,求一个根节点的子树对根节点的贡献,就是求DFS序列一段区间中
    小于或等于k/a[i]的数有多少个;可持久化线段树利用类似前缀和的原理,tree[r]-tree[l-1]就是l到r这一段区间所有点的线段树

    #include <iostream>
    #include <string.h>
    #include <stdlib.h>
    #include <stdio.h>
    #include <algorithm>
    #include <math.h>
    #include <stack>
    
    using namespace std;
    const int maxn=1e5;
    typedef long long int LL;
    // rt[v]: root of the persistent version created when vertex v was inserted.
    // It is indexed by vertex only, so maxn+5 entries suffice (the old maxn*100
    // size wasted about 40MB).
    int rt[maxn+5];
    // Node pool: each insertion copies one root-to-leaf path (about 31 nodes for
    // values <= 1e9). Index 0 is the shared empty node.
    int ls[maxn*100+5];
    int rs[maxn*100+5];
    LL sum[maxn*100+5];
    int p;                     // next free node index
    int n;
    LL k;
    int l,r;                   // value range [min a, max a] of the current case
    // Persistent insert of val over [l, r]: copy `node` into a fresh slot with
    // its count increased by one, redirect `node` to the copy, then continue
    // into the child containing val. Older versions are never modified.
    void update(int &node, int l, int r, int val)
    {
        const int fresh = p++;
        ls[fresh] = ls[node];
        rs[fresh] = rs[node];
        sum[fresh] = sum[node] + 1;
        node = fresh;
        if (l == r)
            return;
        const int mid = (l + r) >> 1;
        if (val <= mid)
            update(ls[node], l, mid, val);
        else
            update(rs[node], mid + 1, r, val);
    }
    // Number of values <= val stored in version `node` over [l, r], found by a
    // single root-to-leaf walk.
    LL query(int node, int l, int r, LL val)
    {
        LL acc = 0;
        while (val >= l && node)
        {
            if (l == r)
                return acc + sum[node];
            const LL mid = (l + r) >> 1;
            if (val <= mid)
            {
                node = ls[node];
                r = (int)mid;
            }
            else
            {
                acc += sum[ls[node]];
                node = rs[node];
                l = (int)mid + 1;
            }
        }
        return acc;
    }
    // Adjacency list: head[x] is x's first edge, edge[i].next chains the rest,
    // and -1 terminates a list.
    struct Node
    {
        int value;     // target vertex
        int next;      // next edge from the same source, or -1
    } edge[maxn * 2 + 5];
    int head[maxn + 5];
    int tot;           // edges stored so far

    // Prepend the directed edge x -> y to x's list.
    void add(int x, int y)
    {
        const int id = tot++;
        edge[id].value = y;
        edge[id].next = head[x];
        head[x] = id;
    }
    int res[maxn * 2];     // Euler tour: each vertex is recorded on entry and on exit
    int a[maxn + 5];
    int cot;               // entries written to res so far

    // Record root on entry, walk its children, then record it again on exit.
    void dfs(int root)
    {
        res[cot] = root;
        ++cot;
        for (int j = head[root]; j != -1; j = edge[j].next)
            dfs(edge[j].value);
        res[cot] = root;
        ++cot;
    }
    int tag[maxn+5];
    int flag[maxn+5];
    int main()
    {
        int t;
        scanf("%d",&t);
        int x,y;
        while(t--)
        {
            scanf("%d%lld",&n,&k);
            l=1e9;r=0;
            for(int i=1;i<=n;i++)
            {
                 scanf("%d",&a[i]);
                 l=min(l,a[i]);r=max(r,a[i]);
            }
               
            memset(head,-1,sizeof(head));
            memset(tag,0,sizeof(tag));
            tot=0;
            p=1;
            for(int i=1;i<=n-1;i++)
            {
                scanf("%d%d",&x,&y);
                add(x,y);
                tag[y]++;
            }
            int root;
            for(int i=1;i<=n;i++)
            {
                if(!tag[i])
                    root=i;
            }
            cot=0;
            dfs(root);
            memset(flag,0,sizeof(flag));
            update(rt[res[0]],l,r,a[res[0]]);
            flag[res[0]]=1;
            LL ans=0;
    		int now=0;
            for(int i=1;i<cot;i++)
            {
                if(flag[res[i]]==1)
                {
    				LL ans1=query(rt[res[now]],l,r,k/a[res[i]]);
    				LL ans2=query(rt[res[i]],l,r,k/a[res[i]]);
    				//cout<<ans1<<" "<<ans2<<endl;
                    ans+=ans1-ans2;
                    continue;
                }
                flag[res[i]]=1;
                update(rt[res[i]]=rt[res[now]],l,r,a[res[i]]);
    			now=i;
            }
            printf("%lld
    ",ans);
        }
        return 0;
    }
    



    
    
    
    
  • 相关阅读:
    Python 正则表达式匹配两个指定字符串中间的内容
    Switch Case 和 If Else
    MYSQL LETT/RIGHT/INNER/FUll JOIN 注意事项
    MYSQL 批处理 Packet for query is too large
    Junit单元测试不支持多线程
    postman中 form-data、x-www-form-urlencoded、raw、binary的区别
    一个项目中:只能存在一个 WebMvcConfigurationSupport (添加swagger坑)
    Nginx 转发特点URL到指定服务
    基于UDP协议的程序设计
    TcpClient类与TcpListener类
  • 原文地址:https://www.cnblogs.com/dacc123/p/8228584.html
Copyright © 2020-2023  润新知