Weak Pair
Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others)
Total Submission(s): 1468    Accepted Submission(s): 472

Problem Description
You are given a rooted tree of $N$ nodes, labeled from $1$ to $N$. To the $i$-th node a non-negative value $a_i$ is assigned. An ordered pair of nodes $(u, v)$ is said to be weak if
(1) $u$ is an ancestor of $v$ (Note: in this problem a node $u$ is not considered an ancestor of itself);
(2) $a_u \times a_v \le k$.
Can you find the number of weak pairs in the tree?

Input
There are multiple cases in the data set.
The first line of input contains an integer $T$ denoting the number of test cases.
For each case, the first line contains two space-separated integers, $N$ and $k$, respectively.
The second line contains $N$ space-separated integers, denoting $a_1$ to $a_N$.
Each of the subsequent $N-1$ lines contains two space-separated integers defining an edge connecting nodes $u$ and $v$, where node $u$ is the parent of node $v$.

Constraints:
$1 \le N \le 10^5$
$0 \le a_i \le 10^9$
$0 \le k \le 10^{18}$

Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.

Sample Input
1
2 3
1 2
1 2

Sample Output
1这是一道很好的数据结构的题目:可以用很多方法写首先思路是:dfs这颗树,每到一个节点,都计算这个节点的祖先中满足条件的有几个而计算这个就需要维护一个序列,并且高效的得出多少个祖先满足条件。即在序列中找到小于k/a[i]的数有多少个,很容易想到用树状数组和线段树。权值1e9需要离散化。树状数组:#include <iostream> #include <string.h> #include <stdlib.h> #include <algorithm> #include <math.h> #include <stdio.h> #include <map> using namespace std; const int maxn=1e5; typedef long long int LL; int n; LL k; LL a[maxn+5]; struct Node { int value; int next; }edge[maxn*2+5]; int head[maxn+5]; int vis[maxn+5]; int tot; int c[maxn*2+5]; LL b[maxn+5]; LL e[maxn*2+5]; map<LL,int> m; void add(int x,int y) { edge[tot].value=y; edge[tot].next=head[x]; head[x]=tot++; } int lowbit(int x) { return x&(-x); } void update(int x,int num) { while(x<=n*2) { c[x]+=num; x+=lowbit(x); } } int sum(int x) { int _sum=0; while(x>0) { _sum+=c[x]; x-=lowbit(x); } return _sum; } LL ans; void dfs(int root) { vis[root]=1; for(int i=head[root];i!=-1;i=edge[i].next) { int v=edge[i].value; if(!vis[v]) { ans+=sum(m[b[v]]); update(m[a[v]],1); dfs(v); update(m[a[v]],-1); } } } void init() { memset(c,0,sizeof(c)); memset(vis,0,sizeof(vis)); memset(head,-1,sizeof(head)); tot=0; } int tag[maxn+5]; int main() { int t; scanf("%d",&t); int x,y; while(t--) { scanf("%d%lld",&n,&k); init(); int cnt=n; m.clear(); for(int i=1;i<=n;i++) { scanf("%lld",&a[i]); e[i]=a[i]; if(a[i]==0) m[a[i]]=2*n; else { b[i]=k/a[i]; e[++cnt]=b[i]; } } sort(e+1,e+cnt+1); int tot=1; for(int i=1;i<=cnt;i++) { if(!m.count(e[i])) m[e[i]]=tot++; } memset(tag,0,sizeof(tag)); for(int i=1;i<=n-1;i++) { scanf("%d%d",&x,&y); add(x,y); tag[y]++; } int root; for(int i=1;i<=n;i++) { if(tag[i]==0) root=i; } ans=0; update(m[a[root]],1); dfs(root); printf("%lld ",ans); } return 0; }线段树:<pre name="code" class="html">#include <iostream> #include <string.h> #include <algorithm> #include <stdlib.h> #include <math.h> #include <stdio.h> #include <string> #include <map> #include <vector> using namespace std; typedef long long int LL; const int maxn=1e5; vector<int> v[maxn+5]; int sum[maxn*8+5]; int n; LL 
k; LL a[maxn+5]; LL b[maxn+5]; LL e[maxn*2+5]; int aa[maxn+5]; int bb[maxn+6]; map<LL,int> m; void PushUp(int node) { sum[node]=sum[node<<1]+sum[node<<1|1]; } void update(int node,int begin,int end,int ind,int num) { if(begin==end) { sum[node]+=num*(end-begin+1); return; } int m=(begin+end)>>1; if(ind<=m) update(node<<1,begin,m,ind,num); else update(node<<1|1,m+1,end,ind,num); PushUp(node); } LL Query(int node,int begin,int end,int left,int right) { if(left<=begin&&end<=right) return sum[node]; int m=(begin+end)>>1; LL ret=0; if(left<=m) ret+=Query(node<<1,begin,m,left,right); if(right>m) ret+=Query(node<<1|1,m+1,end,left,right); PushUp(node); return ret; } int tag[maxn+5]; LL ans; void dfs(int root) { int len=v[root].size(); for(int i=0;i<len;i++) { int w=v[root][i]; ans+=Query(1,1,2*n,1,bb[w]); update(1,1,2*n,aa[w],1); dfs(v[root][i]); update(1,1,2*n,aa[w],-1); } } void init() { memset(sum,0,sizeof(sum)); memset(tag,0,sizeof(tag)); } int main() { int t; scanf("%d",&t); int x,y; while(t--) { scanf("%d%lld",&n,&k); int cnt=0; init(); m.clear(); for(int i=1;i<=n;i++) { scanf("%lld",&a[i]); e[++cnt]=a[i]; b[i]=k/a[i]; e[++cnt]=b[i]; v[i].clear(); } sort(e+1,e+cnt+1); int cot=1; for(int i=1;i<=cnt;i++) { if(!m.count(e[i])) m[e[i]]=cot++; } for(int i=1;i<=n;i++) { aa[i]=m[a[i]]; bb[i]=m[b[i]]; } for(int i=1;i<=n-1;i++) { scanf("%d%d",&x,&y); v[x].push_back(y); tag[y]++; } int root; for(int i=1;i<=n;i++) { if(tag[i]==0) root=i; } ans=0; update(1,1,2*n,m[a[root]],1); dfs(root); printf("%lld ",ans); } return 0; }其实用线段树也可以不离散的方法做,是线段树的动态开点,动态开点就是用到了这个点再去开,不用的点不用开这样在0到1e18的范围内,最多储存的点也就n个叶子节点,开个8*n的空间就足够了
// Dynamic-node segment tree over the raw value range [0, 1e18]; no coordinate
// compression is needed because nodes are created only on demand.
#include <stdio.h>
#include <string.h>
#include <vector>
using namespace std;

const int maxn = 1e5;
const long long int len = 1e18;  // values (a[i] and b[i]) lie in [0, len]
typedef long long int LL;

LL a[maxn + 5];
LL b[maxn + 5];         // b[i]: largest ancestor value that pairs with node i
int n;
LL k;
vector<int> v[maxn + 5];  // v[x]: children of node x
int tag[maxn + 5];        // number of parents of each node (0 only for the root)
LL ans;

struct Node {
    int lch, rch;  // child indices into tr, -1 while not yet created
    LL sum;        // number of stored values in this node's range
    Node() {}
    Node(int lch, int rch, LL sum) : lch(lch), rch(rch), sum(sum) {}
} tr[maxn * 100 + 5];
int p;  // index of the last allocated node; node 1 is the root

int newnode() {
    tr[++p] = Node(-1, -1, 0);
    return p;
}

// Count stored under node, treating a missing child (-1) as empty.
LL sumOf(int node) { return node == -1 ? 0 : tr[node].sum; }

// Adds num at position ind; creates only the child on the update path.
void update(int node, LL begin, LL end, LL ind, int num) {
    if (begin == end) {
        tr[node].sum += num;
        return;
    }
    LL mid = begin + (end - begin) / 2;
    if (ind <= mid) {
        if (tr[node].lch == -1) tr[node].lch = newnode();
        update(tr[node].lch, begin, mid, ind, num);
    } else {
        if (tr[node].rch == -1) tr[node].rch = newnode();
        update(tr[node].rch, mid + 1, end, ind, num);
    }
    tr[node].sum = sumOf(tr[node].lch) + sumOf(tr[node].rch);
}

// Number of stored values in [left, right]. Read-only: the original called
// PushUp here, which reads tr[-1] when a child has not been created.
LL query(int node, LL begin, LL end, LL left, LL right) {
    if (node == -1) return 0;
    if (left <= begin && end <= right) return tr[node].sum;
    LL mid = begin + (end - begin) / 2;
    LL ret = 0;
    if (left <= mid) ret += query(tr[node].lch, begin, mid, left, right);
    if (right > mid) ret += query(tr[node].rch, mid + 1, end, left, right);
    return ret;
}

// The tree holds the values of the current node's ancestors.
void dfs(int root) {
    for (size_t i = 0; i < v[root].size(); i++) {
        int w = v[root][i];
        ans += query(1, 0, len, 0, b[w]);
        update(1, 0, len, a[w], 1);
        dfs(w);
        update(1, 0, len, a[w], -1);
    }
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        scanf("%d%lld", &n, &k);
        for (int i = 1; i <= n; i++) {
            scanf("%lld", &a[i]);
            // Avoid k / 0: a zero value pairs with every ancestor, so query the
            // whole range. k <= len, so k / a[i] also stays inside the range.
            b[i] = (a[i] == 0) ? len : k / a[i];
            v[i].clear();
        }
        memset(tag, 0, sizeof(tag));
        p = 0;
        newnode();  // node 1: root of the tree for this test case
        for (int i = 1; i <= n - 1; i++) {
            int x, y;
            scanf("%d%d", &x, &y);
            v[x].push_back(y);
            tag[y]++;
        }
        int root = 1;
        for (int i = 1; i <= n; i++)
            if (!tag[i]) root = i;
        ans = 0;
        update(1, 0, len, a[root], 1);
        dfs(root);
        printf("%lld\n", ans);
    }
    return 0;
}
还可以用线段树合并:为每个节点建一棵只含自身权值的动态开点线段树,DFS 回溯时把子节点的线段树合并到父节点上,然后查询子树中不大于 k/a[i] 的值有多少个(再减去节点自身)。关于线段树合并,有必要再写一篇博客专门总结一下。
// Segment-tree merging: every node starts with a tree holding only its own
// value; after the DFS merges all children into it, the tree holds the values
// of the whole subtree, and we count those <= k / a[u] (minus u itself).
#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;

const int maxn = 1e5;
typedef long long int LL;

// Node pool shared by all per-node trees; index 0 is the empty tree.
int rt[maxn * 100 + 5];  // rt[i]: root of node i's tree
int ls[maxn * 100 + 5];
int rs[maxn * 100 + 5];
LL sum[maxn * 100 + 5];
int a[maxn + 5];
LL k;
int n;
int p;     // next free pool index (starts at 1 so index 0 stays empty)
int l, r;  // value range [min a, max a] of the current test case

struct Edge {
    int value;
    int next;
} edge[maxn * 2 + 5];
int head[maxn + 5];
int tot;
int tag[maxn + 5];  // number of parents of each node (0 only for the root)
LL ans;

int newnode() {
    sum[p] = ls[p] = rs[p] = 0;
    return p++;
}

// Inserts val into the tree rooted at node (creating nodes as needed).
void build(int &node, int begin, int end, LL val) {
    if (!node) node = newnode();
    sum[node] = 1;  // each tree is built holding exactly one value
    if (begin == end) return;
    int mid = begin + (end - begin) / 2;
    if (val <= mid) build(ls[node], begin, mid, val);
    else build(rs[node], mid + 1, end, val);
}

// Number of values <= val stored in the tree rooted at node.
LL Query(int node, int begin, int end, LL val) {
    if (!node || val < begin) return 0;
    if (begin == end) return sum[node];
    int mid = begin + (end - begin) / 2;
    if (val <= mid) return Query(ls[node], begin, mid, val);
    return sum[ls[node]] + Query(rs[node], mid + 1, end, val);
}

// Merges tree y into tree x; x may end up sharing nodes with y.
void mergeTree(int &x, int y, int begin, int end) {
    if (!x || !y) {
        x = x ? x : y;
        return;
    }
    sum[x] += sum[y];
    if (begin == end) return;
    int mid = begin + (end - begin) / 2;
    mergeTree(ls[x], ls[y], begin, mid);
    mergeTree(rs[x], rs[y], mid + 1, end);
}

// Adds the directed edge parent x -> child y.
void add(int x, int y) {
    edge[tot].value = y;
    edge[tot].next = head[x];
    head[x] = tot++;
}

void dfs(int root) {
    for (int i = head[root]; i != -1; i = edge[i].next) {
        int w = edge[i].value;
        dfs(w);
        mergeTree(rt[root], rt[w], l, r);
    }
    // a[root] == 0 pairs with every descendant; avoid dividing by zero.
    LL lim = (a[root] == 0) ? (LL)r : k / a[root];
    ans += Query(rt[root], l, r, lim);
    if (k >= 1LL * a[root] * a[root]) ans--;  // root itself was counted
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        scanf("%d%lld", &n, &k);
        p = 1;
        memset(tag, 0, sizeof(tag));
        memset(head, -1, sizeof(head));
        tot = 0;
        l = 1000000000;
        r = 0;
        for (int i = 1; i <= n; i++) {
            scanf("%d", &a[i]);
            l = min(l, a[i]);
            r = max(r, a[i]);
        }
        for (int i = 1; i <= n; i++) {
            rt[i] = 0;
            build(rt[i], l, r, a[i]);
        }
        for (int i = 1; i <= n - 1; i++) {
            int x, y;
            scanf("%d%d", &x, &y);
            add(x, y);
            tag[y]++;
        }
        int root = 1;
        for (int i = 1; i <= n; i++)
            if (tag[i] == 0) root = i;
        ans = 0;
        dfs(root);
        printf("%lld\n", ans);
    }
    return 0;
}
也可以用拓扑排序(从叶子开始、子节点全部处理完后父节点才入队)代替递归 DFS,自下而上地进行线段树合并:
// Segment-tree merging in bottom-up (topological) order instead of recursion:
// a node is processed once all its children have merged their trees into it.
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <queue>
using namespace std;

const int maxn = 1e5;
typedef long long int LL;

// Node pool shared by all per-node trees; index 0 is the empty tree.
int rt[maxn * 100 + 5];  // rt[i]: root of node i's tree
int ls[maxn * 100 + 5];
int rs[maxn * 100 + 5];
LL sum[maxn * 100 + 5];
int a[maxn + 5];
int f[maxn + 5];   // f[i]: parent of node i, 0 for the root
LL k;
int n;
int p;             // next free pool index (starts at 1 so index 0 stays empty)
int l, r;          // value range [min a, max a] of the current test case
queue<int> q;
LL ans;
int tag[maxn + 5]; // number of children of each node not yet processed

int newnode() {
    sum[p] = ls[p] = rs[p] = 0;
    return p++;
}

// Inserts val into the tree rooted at node (creating nodes as needed).
void build(int &node, int begin, int end, LL val) {
    if (!node) node = newnode();
    sum[node] = 1;  // each tree is built holding exactly one value
    if (begin == end) return;
    int mid = begin + (end - begin) / 2;
    if (val <= mid) build(ls[node], begin, mid, val);
    else build(rs[node], mid + 1, end, val);
}

// Number of values <= val stored in the tree rooted at node.
LL Query(int node, int begin, int end, LL val) {
    if (!node || val < begin) return 0;
    if (begin == end) return sum[node];
    int mid = begin + (end - begin) / 2;
    if (val <= mid) return Query(ls[node], begin, mid, val);
    return sum[ls[node]] + Query(rs[node], mid + 1, end, val);
}

// Merges tree y into tree x; x may end up sharing nodes with y.
void mergeTree(int &x, int y, int begin, int end) {
    if (!x || !y) {
        x = x ? x : y;
        return;
    }
    sum[x] += sum[y];
    if (begin == end) return;
    int mid = begin + (end - begin) / 2;
    mergeTree(ls[x], ls[y], begin, mid);
    mergeTree(rs[x], rs[y], mid + 1, end);
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        scanf("%d%lld", &n, &k);
        p = 1;
        memset(tag, 0, sizeof(tag));
        l = 1000000000;
        r = 0;
        for (int i = 1; i <= n; i++) {
            scanf("%d", &a[i]);
            l = min(l, a[i]);
            r = max(r, a[i]);
            // Reset the parent link: a stale f[root] from a previous test case
            // would merge the root into an unrelated node.
            f[i] = 0;
        }
        for (int i = 1; i <= n; i++) {
            rt[i] = 0;
            build(rt[i], l, r, a[i]);
        }
        for (int i = 1; i <= n - 1; i++) {
            int x, y;
            scanf("%d%d", &x, &y);
            tag[x]++;
            f[y] = x;
        }
        for (int i = 1; i <= n; i++)
            if (tag[i] == 0) q.push(i);
        ans = 0;
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            // rt[u] now holds u's whole subtree; u itself is subtracted below.
            if (1LL * a[u] * a[u] <= k) ans--;
            LL lim = (a[u] == 0) ? (LL)r : k / a[u];  // avoid k / 0
            ans += Query(rt[u], l, r, lim);
            int par = f[u];
            if (par == 0) continue;  // u is the root
            mergeTree(rt[par], rt[u], l, r);
            if (--tag[par] == 0) q.push(par);
        }
        printf("%lld\n", ans);
    }
    return 0;
}
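还可以用可持久化线段树:按 DFS 序(每个节点进入、离开时各记录一次)建立版本,进入节点时在上一个版本的基础上插入它的值;离开节点时,用最后一个版本减去该节点自身的版本,就得到其子树中不大于 k/a[i] 的值有多少个。可持久化线段树利用类似前缀和的原理,tree[r]-tree[l-1] 就是 l 到 r 这一段区间内所有点构成的线段树。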
// Persistent segment tree over an Euler tour: versions are created in order of
// first visit, so for node u the difference between the latest version at
// u's exit and u's own version is exactly u's proper subtree.
#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;

const int maxn = 1e5;
typedef long long int LL;

// Node pool shared by all versions; index 0 is the empty tree.
int rt[maxn * 100 + 5];  // rt[i]: version created when node i was entered
int ls[maxn * 100 + 5];
int rs[maxn * 100 + 5];
LL sum[maxn * 100 + 5];
int p;     // next free pool index (starts at 1 so index 0 stays empty)
int n;
LL k;
int l, r;  // value range [min a, max a] of the current test case

struct Edge {
    int value;
    int next;
} edge[maxn * 2 + 5];
int head[maxn + 5];
int tot;
int res[maxn * 2 + 5];  // Euler tour: each node appears on entry and on exit
int a[maxn + 5];
int cot;                // length of res
int tag[maxn + 5];      // number of parents of each node (0 only for the root)
int flag[maxn + 5];     // 1 once the node's entry has been processed

// Replaces node with a new version that has val inserted.
void update(int &node, int lo, int hi, int val) {
    ls[p] = ls[node];
    rs[p] = rs[node];
    sum[p] = sum[node] + 1;
    node = p++;
    if (lo == hi) return;
    int mid = lo + (hi - lo) / 2;
    if (val <= mid) update(ls[node], lo, mid, val);
    else update(rs[node], mid + 1, hi, val);
}

// Number of values <= val stored in the version rooted at node.
LL query(int node, int lo, int hi, LL val) {
    if (val < lo || !node) return 0;
    if (lo == hi) return sum[node];
    int mid = lo + (hi - lo) / 2;
    if (val <= mid) return query(ls[node], lo, mid, val);
    return sum[ls[node]] + query(rs[node], mid + 1, hi, val);
}

// Adds the directed edge parent x -> child y.
void add(int x, int y) {
    edge[tot].value = y;
    edge[tot].next = head[x];
    head[x] = tot++;
}

// Records the Euler tour of the subtree rooted at root into res.
void dfs(int root) {
    res[cot++] = root;
    for (int i = head[root]; i != -1; i = edge[i].next) dfs(edge[i].value);
    res[cot++] = root;
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        scanf("%d%lld", &n, &k);
        l = 1000000000;
        r = 0;
        for (int i = 1; i <= n; i++) {
            scanf("%d", &a[i]);
            l = min(l, a[i]);
            r = max(r, a[i]);
        }
        memset(head, -1, sizeof(head));
        memset(tag, 0, sizeof(tag));
        tot = 0;
        p = 1;
        for (int i = 1; i <= n - 1; i++) {
            int x, y;
            scanf("%d%d", &x, &y);
            add(x, y);
            tag[y]++;
        }
        int root = 1;
        for (int i = 1; i <= n; i++)
            if (!tag[i]) root = i;
        cot = 0;
        dfs(root);
        memset(flag, 0, sizeof(flag));
        // The first version must start from the empty tree: rt[root] may still
        // point at pool nodes from the previous test case.
        rt[res[0]] = 0;
        update(rt[res[0]], l, r, a[res[0]]);
        flag[res[0]] = 1;
        LL ans = 0;
        int now = 0;  // Euler index of the most recently entered node
        for (int i = 1; i < cot; i++) {
            int u = res[i];
            if (flag[u] == 1) {
                // Exit of u: versions entered after u are exactly its subtree.
                LL lim = (a[u] == 0) ? (LL)r : k / a[u];  // avoid k / 0
                ans += query(rt[res[now]], l, r, lim) - query(rt[u], l, r, lim);
                continue;
            }
            flag[u] = 1;
            rt[u] = rt[res[now]];
            update(rt[u], l, r, a[u]);
            now = i;
        }
        printf("%lld\n", ans);
    }
    return 0;
}