题目描述
设T 为一棵有根树,我们做如下的定义:
? 设a和b为T 中的两个不同节点。如果a是b的祖先,那么称“a比b不知道
高明到哪里去了”。
? 设a 和 b 为 T 中的两个不同节点。如果 a 与 b 在树上的距离不超过某个给定
常数x,那么称“a 与b 谈笑风生”。
给定一棵n个节点的有根树T,节点的编号为1 到 n,根节点为1号节点。你需
要回答q 个询问,询问给定两个整数p和k,问有多少个有序三元组(a;b;c)满足:
1. a、b和 c为 T 中三个不同的点,且 a为p 号节点;
2. a和b 都比 c不知道高明到哪里去了;
3. a和b 谈笑风生。这里谈笑风生中的常数为给定的 k。
输入
第一行含有两个正整数n和q,分别代表有根树的点数与询问的个数。
接下来n - 1行,每行描述一条树上的边。每行含有两个整数u和v,代表在节点u和v之间有一条边。
接下来q行,每行描述一个操作。第i行含有两个整数,分别表示第i个询问的p和k。
1<=P<=N
1<=K<=N
N<=300000
Q<=300000
输出
输出 q 行,每行对应一个询问,代表询问的答案。
样例输入
5 3
1 2
1 3
2 4
4 5
2 2
4 1
2 3
1 2
1 3
2 4
4 5
2 2
4 1
2 3
样例输出
3
1
3
1
3
提示
Hint:边要加双向
题目大意:给定一个n个节点的有根树,q次询问,每次询问两个数p,k,问满足1、a,b都是c的祖先。2、a编号为p。3、a,b距离<=k。的三元组(a,b,c)有多少个。
因为a,b都是c的祖先,所以它们其中一个一定是另一个的祖先。a的位置确定了,那就讨论b的位置:当b是a祖先时,直接将a子树大小乘上k和a的深度中小的那个就好了;当a是b祖先时,a子树中与a深度差<=k的点都可以是b,统计这些点的子树和就是答案,也就相当于将每个点的子树大小作为这个点的点权(要将自己刨去)求点权和。如果没有层数限制直接将dfs序架在线段树上区间求和就好了。但有了限制就要用主席树按深度建树,每个深度建一棵线段树,维护这一深度所有点的信息,查询时依旧是查a点子树区间,但因为深度>dep[a]+k的点在主席树中还没有维护信息所以并不影响。
#include<set> #include<map> #include<queue> #include<cmath> #include<stack> #include<vector> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> typedef long long ll; using namespace std; ll ans; int mx; int n,m; int p,k; int x,y; int num; int tot; int cnt; int s[300010]; int t[300010]; int d[300010]; int to[600010]; int ls[6000010]; int rs[6000010]; ll sum[6000010]; int head[300010]; int next[600010]; int size[300010]; int root[300010]; struct node { int dep; int id; }a[300010]; bool cmp(node a,node b) { return a.dep<b.dep; } void add(int x,int y) { tot++; next[tot]=head[x]; head[x]=tot; to[tot]=y; } void dfs(int x,int fa) { num++; s[x]=num; a[x].dep=a[fa].dep+1; d[x]=d[fa]+1; size[x]=1; for(int i=head[x];i;i=next[i]) { if(to[i]!=fa) { dfs(to[i],x); size[x]+=size[to[i]]; } } t[x]=num; } int updata(int pre,int l,int r,int k,int v) { int rt=++cnt; if(l==r) { sum[rt]=sum[pre]+v; return rt; } ls[rt]=ls[pre]; rs[rt]=rs[pre]; sum[rt]=sum[pre]+v; int mid=(l+r)>>1; if(k<=mid) { ls[rt]=updata(ls[pre],l,mid,k,v); } else { rs[rt]=updata(rs[pre],mid+1,r,k,v); } return rt; } ll query(int x,int y,int l,int r,int L,int R) { if(L<=l&&r<=R) { return sum[y]-sum[x]; } int mid=(l+r)>>1; if(L>mid) { return query(rs[x],rs[y],mid+1,r,L,R); } else if(R<=mid) { return query(ls[x],ls[y],l,mid,L,R); } return query(ls[x],ls[y],l,mid,L,R)+query(rs[x],rs[y],mid+1,r,L,R); } int main() { scanf("%d%d",&n,&m); for(int i=1;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); a[i].id=i; } a[n].id=n; dfs(1,0); sort(a+1,a+1+n,cmp); for(int i=1;i<=n;i++) { mx=max(mx,d[i]); if(a[i].dep>a[i-1].dep) { root[a[i].dep]=root[a[i-1].dep]; } root[a[i].dep]=updata(root[a[i].dep],1,n,s[a[i].id],size[a[i].id]-1); } for(int i=1;i<=m;i++) { scanf("%d%d",&p,&k); ans=0; ans+=1ll*min(k,d[p]-1)*(size[p]-1); ans+=query(root[d[p]],root[min(d[p]+k,mx)],1,n,s[p],t[p]); printf("%lld ",ans); } }