[BZOJ 3653] 谈笑风生
题意
给定一棵 (n) 个点根为 (1) 单位权值的树以及 (q) 个查询, 每次查询给定 (p) 和 (k), 求满足 (a,p) 都是 (b) 的祖先且 (a,p) 之间的距离不超过 (k) 的 ((a,b)) 的数量. ((a,b,p) 不能重合)
(n,qle3 imes 10^5).
题解
首先我们一眼就能看出这个鬼东西需要分 (a) 是否是 (p) 的祖先来计算. 如果 (a) 是 (p) 的祖先那么会对应 $size_p-1 $ 个不同的 (b). 否则的话对于每个深度与 (p) 之差不超过 (k) 且在 (p) 子树中的 (a) 都会对应 (size_a-1) 个不同的 (b).
前一部分感觉地球人都知道怎么做. 问题在于后一部分要求某个子树中深度在某一区间内的点的权值和.
区间查询, 不难想到线段树. 然后我们发现事实上祖先结点可以继承子结点的权值和信息, 于是我们用线段树合并处理.
具体做法是DFS预处理出所有的 (size) 和深度, 回溯的时候根据深度插入当前点的权值顺便合并一下子结点的线段树. 然后就可以在线计算查询了.
时间复杂度是 (O((n+q)log n)).
参考代码
#include <bits/stdc++.h>
namespace rvalue{
const int MAXV=3e5+10;
const int MAXE=1e6+10;
typedef long long intEx;
struct Edge{
int from;
int to;
Edge* next;
};
Edge E[MAXE];
Edge* head[MAXV];
Edge* top=E;
struct Node{
int l;
int r;
intEx sum;
Node* lch;
Node* rch;
Node(int,int);
void Insert(int,int);
intEx Query(int,int);
};
Node* N[MAXV];
int n;
int q;
int deep[MAXV];
int size[MAXV];
intEx ans[MAXV];
int ReadInt();
void Insert(int,int);
void DFS(int,int,int);
Node* Merge(Node*,Node*);
int main(){
n=ReadInt(),q=ReadInt();
for(int i=1;i<n;i++){
int a=ReadInt(),b=ReadInt();
Insert(a,b);
Insert(b,a);
}
DFS(1,0,1);
for(int i=0;i<q;i++){
int root=ReadInt();
int k=ReadInt();
intEx ans=N[root]->Query(deep[root]+1,std::min(deep[root]+k,n));
ans+=1ll*std::min(k,deep[root]-1)*(size[root]-1);
printf("%lld
",ans);
}
return 0;
}
void DFS(int root,int prt,int deep){
rvalue::size[root]=1;
rvalue::deep[root]=deep;
N[root]=new Node(1,n);
for(Edge* i=head[root];i!=NULL;i=i->next){
if(i->to!=prt){
DFS(i->to,root,deep+1);
rvalue::size[root]+=rvalue::size[i->to];
N[root]=Merge(N[root],N[i->to]);
}
}
N[root]->Insert(rvalue::deep[root],size[root]-1);
}
Node::Node(int l,int r):l(l),r(r),sum(0),lch(NULL),rch(NULL){}
Node* Merge(Node* L,Node* R){
if(L==NULL)
return R;
if(R==NULL)
return L;
Node* cur=new Node(L->l,R->r);
cur->sum=L->sum+R->sum;
cur->lch=Merge(L->lch,R->lch);
cur->rch=Merge(L->rch,R->rch);
return cur;
}
intEx Node::Query(int l,int r){
if(l<=this->l&&this->r<=r)
return this->sum;
else{
int mid=(this->l+this->r)>>1;
intEx ans=0;
if(l<=mid&&this->lch!=NULL)
ans+=this->lch->Query(l,r);
if(mid+1<=r&&this->rch!=NULL)
ans+=this->rch->Query(l,r);
return ans;
}
}
void Node::Insert(int x,int d){
this->sum+=d;
if(this->l!=this->r){
int mid=(this->l+this->r)>>1;
if(x<=mid){
if(this->lch==NULL)
this->lch=new Node(this->l,mid);
this->lch->Insert(x,d);
}
else{
if(this->rch==NULL)
this->rch=new Node(mid+1,this->r);
this->rch->Insert(x,d);
}
}
}
inline void Insert(int from,int to){
top->from=from;
top->to=to;
top->next=head[from];
head[from]=top++;
}
inline int ReadInt(){
int x=0;
register char ch=getchar();
while(!isdigit(ch))
ch=getchar();
while(isdigit(ch)){
x=x*10+ch-'0';
ch=getchar();
}
return x;
}
}
int main(){
freopen("laugh.in","r",stdin);
freopen("laugh.out","w",stdout);
rvalue::main();
return 0;
}