【题意】给定n个点的树,m次询问,每次给定ki个特殊点,一个点会被最近的特殊点控制,询问每个特殊点控制多少点。n,m,Σki<=300000。
【算法】虚树+倍增
【题解】★参考:thy_asdf
对询问建立虚树,然后DFS统计虚树上每个点被哪个点控制,记为belong[x]。
统计的方法是从下往上DFS得到x来自x子树的控制点,再从上往下DFS得到x来自非x子树的控制点。
令D(x,y)表示点x和点y之间的路径长度,比较x的不同控制点y的方式是取D(x,y)的最小值(相同时比较编号)。
依此考虑虚树上的每一条边(x,y),记z为同时是x的儿子和y的祖先的点,令f[x]表示点x的答案。
如果belong[x]=belong[y],那么f[belong[x]]+=size[z]-size[y]
否则,从y倍增找到分界点mid,满足mid及以下由y控制,mid以上由x控制。
倍增过程中,记A=D(belong[x],mid),B=D(belong[x],mid),当满足A>B或A=B&&belong[x]>belong[y]时进行倍增。
找到分界点mid后,f[belong[x]]+=size[x]-size[m],f[belong[y]+=size[m]-size[y]。
最后要处理没有出现在虚树路径上的点(包括虚树节点),这些点没有在上述过程中被统计。
记rem[x]表示子树x中未被统计的节点数,初始rem[x]=size[x],特别的rem[a[1]]=n(a[1]是虚树最高节点,n是原树总点数,这样是为了保存a[1]往上延伸的节点)
每次处理完边(x,y),rem[x]-=size[z]。
全部做完后,对虚树的每个点f[belong[x]]+=rem[x]。
#include<cstdio> #include<cctype> #include<cstring> #include<algorithm> using namespace std; int read(){ int s=0,t=1;char c; while(!isdigit(c=getchar()))if(c=='-')t=-1; do{s=s*10+c-'0';}while(isdigit(c=getchar())); return s*t; } const int maxn=300010; int n,N,size[maxn],in[maxn],ou[maxn],f[maxn],a[maxn*2],b[maxn]; int deep[maxn],tot,first[maxn],belong[maxn],st[maxn]; int rem[maxn]; bool v[maxn]; namespace cyc{ int first[maxn],tot,dfsnum=0,f[maxn][30]; struct edge{int v,from;}e[maxn*2]; void insert(int u,int v){tot++;e[tot].v=v;e[tot].from=first[u];first[u]=tot;} void dfs(int x,int fa){ in[x]=++dfsnum;size[x]=1; for(int j=1;(1<<j)<=deep[x];j++)f[x][j]=f[f[x][j-1]][j-1]; for(int i=first[x];i;i=e[i].from)if(e[i].v!=fa){ deep[e[i].v]=deep[x]+1; f[e[i].v][0]=x; dfs(e[i].v,x); size[x]+=size[e[i].v]; } ou[x]=dfsnum; } int lca(int x,int y){ if(deep[x]<deep[y])swap(x,y); int d=deep[x]-deep[y]; for(int j=0;j<=20;j++)if(d&(1<<j))x=f[x][j]; if(x==y)return x; for(int j=20;j>=0;j--)if((1<<j)<=deep[x]&&f[x][j]!=f[y][j]){ x=f[x][j];y=f[y][j]; } return f[x][0]; } void build(){ n=read(); for(int i=1;i<n;i++){ int u=read(),v=read(); insert(u,v);insert(v,u); } dfs(1,-1); } } int D(int x,int y){if(!y)return 0x3f3f3f3f;else return deep[x]+deep[y]-2*deep[cyc::lca(x,y)];} bool cmp(int a,int b){return in[a]<in[b];} bool check(int x,int y){return in[x]<=in[y]&&ou[y]<=ou[x];} struct edge{int u,v,from;}e[maxn*2]; void insert(int u,int v){tot++;e[tot].u=u;e[tot].v=v;e[tot].from=first[u];first[u]=tot;} void dfs1(int x){ if(v[x])belong[x]=x; for(int i=first[x];i;i=e[i].from){ dfs1(e[i].v); int A=D(x,belong[x]),B=D(x,belong[e[i].v]); if(B<A||(B==A&&belong[e[i].v]<belong[x]))belong[x]=belong[e[i].v]; } } void dfs2(int x){ for(int i=first[x];i;i=e[i].from){ int A=D(e[i].v,belong[x]),B=D(e[i].v,belong[e[i].v]); if(A<B||(A==B&&belong[x]<belong[e[i].v]))belong[e[i].v]=belong[x]; dfs2(e[i].v); } } void solve(int x,int y){ int z=y; for(int i=20;i>=0;i--)if((1<<i)<=deep[z]&&deep[cyc::f[z][i]]>=deep[x]+1)z=cyc::f[z][i]; rem[x]-=size[z]; if(belong[x]==belong[y]){f[belong[x]]+=size[z]-size[y];return;}// int m=y; for(int i=20;i>=0;i--)if((1<<i)<=deep[m]){ if(deep[cyc::f[m][i]]<deep[x])continue; int A=D(cyc::f[m][i],belong[x]),B=D(cyc::f[m][i],belong[y]); if(B<A||(A==B&&belong[x]>belong[y]))m=cyc::f[m][i];// } f[belong[x]]+=size[z]-size[m]; f[belong[y]]+=size[m]-size[y]; } void work(){ int last=read();N=last; for(int i=1;i<=N;i++)a[i]=read(),f[b[i]=a[i]]=0,v[a[i]]=1; sort(a+1,a+N+1,cmp); for(int i=1;i<last;i++)a[++N]=cyc::lca(a[i],a[i+1]); sort(a+1,a+N+1,cmp); N=unique(a+1,a+N+1)-a-1; tot=0; for(int i=1;i<=N;i++)first[a[i]]=0,belong[a[i]]=0,rem[a[i]]=size[a[i]]; rem[a[1]]=n; int top=0; for(int i=1;i<=N;i++){ while(top&&!check(st[top],a[i]))top--; if(top)insert(st[top],a[i]); st[++top]=a[i]; } dfs1(a[1]);dfs2(a[1]); for(int i=1;i<=tot;i++)solve(e[i].u,e[i].v); for(int i=1;i<=N;i++)f[belong[a[i]]]+=rem[a[i]]; for(int i=1;i<=last;i++)printf("%d ",f[b[i]]),v[b[i]]=0; } int main(){ cyc::build(); int m=read(); while(m--)work(); return 0; }
(倍增求LCA)
注意,由于多次调用LCA,所以用倍增LCA算法有很大的常数,甚至在LOJ会导致TLE。
但是倍增求LCA会使过程更简洁。
下面这份代码是树链剖分求LCA,代码可读性不高。
#include<cstdio> #include<cctype> #include<cstring> #include<algorithm> using namespace std; int read(){ int s=0,t=1;char c; while(!isdigit(c=getchar()))if(c=='-')t=-1; do{s=s*10+c-'0';}while(isdigit(c=getchar())); return s*t; } const int maxn=300010; int n,N,size[maxn],in[maxn],ou[maxn],f[maxn],a[maxn*2],b[maxn]; int deep[maxn],tot,first[maxn],belong[maxn],st[maxn]; int rem[maxn]; bool v[maxn]; namespace cyc{ int first[maxn],tot,dfsnum=0,f[maxn][30]; struct edge{int v,from;}e[maxn*2]; void insert(int u,int v){tot++;e[tot].v=v;e[tot].from=first[u];first[u]=tot;} void dfs(int x,int fa){ in[x]=++dfsnum;size[x]=1; for(int j=1;(1<<j)<=deep[x];j++)f[x][j]=f[f[x][j-1]][j-1]; for(int i=first[x];i;i=e[i].from)if(e[i].v!=fa){ deep[e[i].v]=deep[x]+1; f[e[i].v][0]=x; dfs(e[i].v,x); size[x]+=size[e[i].v]; } ou[x]=dfsnum; } int lca(int x,int y){ if(deep[x]<deep[y])swap(x,y); int d=deep[x]-deep[y]; for(int j=0;j<=20;j++)if(d&(1<<j))x=f[x][j]; if(x==y)return x; for(int j=20;j>=0;j--)if((1<<j)<=deep[x]&&f[x][j]!=f[y][j]){ x=f[x][j];y=f[y][j]; } return f[x][0]; } void build(){ n=read(); for(int i=1;i<n;i++){ int u=read(),v=read(); insert(u,v);insert(v,u); } dfs(1,-1); } } int D(int x,int y){if(!y)return 0x3f3f3f3f;else return deep[x]+deep[y]-2*deep[cyc::lca(x,y)];} bool cmp(int a,int b){return in[a]<in[b];} bool check(int x,int y){return in[x]<=in[y]&&ou[y]<=ou[x];} struct edge{int u,v,from;}e[maxn*2]; void insert(int u,int v){tot++;e[tot].u=u;e[tot].v=v;e[tot].from=first[u];first[u]=tot;} void dfs1(int x){ if(v[x])belong[x]=x; for(int i=first[x];i;i=e[i].from){ dfs1(e[i].v); int A=D(x,belong[x]),B=D(x,belong[e[i].v]); if(B<A||(B==A&&belong[e[i].v]<belong[x]))belong[x]=belong[e[i].v]; } } void dfs2(int x){ for(int i=first[x];i;i=e[i].from){ int A=D(e[i].v,belong[x]),B=D(e[i].v,belong[e[i].v]); if(A<B||(A==B&&belong[x]<belong[e[i].v]))belong[e[i].v]=belong[x]; dfs2(e[i].v); } } void solve(int x,int y){ int z=y; for(int i=20;i>=0;i--)if((1<<i)<=deep[z]&&deep[cyc::f[z][i]]>=deep[x]+1)z=cyc::f[z][i]; rem[x]-=size[z]; if(belong[x]==belong[y]){f[belong[x]]+=size[z]-size[y];return;}// int m=y; for(int i=20;i>=0;i--)if((1<<i)<=deep[m]){ if(deep[cyc::f[m][i]]<deep[x])continue; int A=D(cyc::f[m][i],belong[x]),B=D(cyc::f[m][i],belong[y]); if(B<A||(A==B&&belong[x]>belong[y]))m=cyc::f[m][i];// } f[belong[x]]+=size[z]-size[m]; f[belong[y]]+=size[m]-size[y]; } void work(){ int last=read();N=last; for(int i=1;i<=N;i++)a[i]=read(),f[b[i]=a[i]]=0,v[a[i]]=1; sort(a+1,a+N+1,cmp); for(int i=1;i<last;i++)a[++N]=cyc::lca(a[i],a[i+1]); sort(a+1,a+N+1,cmp); N=unique(a+1,a+N+1)-a-1; tot=0; for(int i=1;i<=N;i++)first[a[i]]=0,belong[a[i]]=0,rem[a[i]]=size[a[i]]; rem[a[1]]=n; int top=0; for(int i=1;i<=N;i++){ while(top&&!check(st[top],a[i]))top--; if(top)insert(st[top],a[i]); st[++top]=a[i]; } dfs1(a[1]);dfs2(a[1]); for(int i=1;i<=tot;i++)solve(e[i].u,e[i].v); for(int i=1;i<=N;i++)f[belong[a[i]]]+=rem[a[i]]; for(int i=1;i<=last;i++)printf("%d ",f[b[i]]),v[b[i]]=0; } int main(){ cyc::build(); int m=read(); while(m--)work(); return 0; }
upd:复习一下。
【建虚树】
第一步、所有关键点按DFS序排序,加入LCA后排序去重。
第二步、按顺序用栈维护构造出虚树。
第三步、虚树DP后清零。
【此题的虚树DP】
第一步、处理虚树上点与点。
第二步、处理虚树边上点的子树。
第三步、处理虚树点的 [ 不在虚树上的孩子 ] 子树。(通过减虚树孩子的sz得到)