题目背景
感谢hzwer的点分治互测。
题目描述
给定一棵有n个点的树
询问树上距离为k的点对是否存在。
输入输出格式
输入格式:
n,m 接下来n-1条边a,b,c描述a到b有一条长度为c的路径
接下来m行每行询问一个K
输出格式:
对于每个K每行输出一个答案,存在输出“AYE”,否则输出”NAY”(不包含引号)
思路:
点分治裸题。。。
什么是点分治?
一种优化的暴力,可以方便的求出树上两点间的距离
怎么做呢?qwq?
我们知道,树的重心可以使树的深度最小(最坏n/2,期望logn)
我们在点分治时就可以这么来优化
每次找到树的重心,然后遍历子树,dfs计算出每一条经过重心的路径长度
然后删掉重心,每个新形成的树再跑一遍
直到无法递归为止
长度记下来,o(1)查询即可
代码:
#include<iostream> #include<cstdio> #include<cstring> #define rii register int i #define rij register int j using namespace std; bool bj[10000005]; struct ljb{ int to,val,next; }x[20005]; int sd,lsl,cnt,head[10005],last[10005],ans,n,m,k; int cd[10005],ks[10005],js[10005],zg[10005],size[10005]; int judge[10005]; void dfs1(int wz,int fa,int sd) { cnt++; judge[cnt]=sd; int ltt=head[wz]; while(ltt!=0) { if(x[ltt].to==fa||zg[x[ltt].to]==1) { ltt=x[ltt].next; continue; } dfs1(x[ltt].to,wz,sd+x[ltt].val); ltt=x[ltt].next; } return; } int dfs(int wz,int fa) { size[wz]=1; int ltt=head[wz]; while(ltt!=0) { if(fa!=x[ltt].to&&zg[x[ltt].to]==0) { dfs(x[ltt].to,wz); size[wz]+=size[x[ltt].to]; } ltt=x[ltt].next; } if(size[wz]*2>=n&&!ans) { ans=wz; } return 0; } void updata(int l,int r) { for(rii=ks[l];i<=js[l];i++) { for(rij=ks[r];j<=js[r];j++) { bj[judge[i]+judge[j]]=1; } } return; } void solve(int wz) { memset(judge,0,sizeof(judge)); memset(ks,0,sizeof(ks)); memset(js,0,sizeof(js)); memset(size,0,sizeof(size)); ans=0; lsl=0; cnt=0; dfs(wz,wz); int zx=ans; if(zx==0) { zx=wz; } int ltt=head[zx]; while(ltt!=0) { if(zg[x[ltt].to]==0) { lsl++; ks[lsl]=cnt+1; dfs1(x[ltt].to,zx,x[ltt].val); js[lsl]=cnt; } ltt=x[ltt].next; } for(rii=1;i<=lsl;i++) { for(rij=i+1;j<=lsl;j++) { updata(i,j); } } for(rii=1;i<=cnt;i++) { bj[judge[i]]=1; } zg[zx]=1; ltt=head[wz]; while(ltt!=0) { if(zg[x[ltt].to]==0) { solve(x[ltt].to); } ltt=x[ltt].next; } return; } int main() { scanf("%d%d",&n,&m); for(rii=1;i<=n-1;i++) { int ltt,kkk,val; scanf("%d%d%d",<t,&kkk,&val); if(head[ltt]==0) { head[ltt]=i*2-1; } x[i*2-1].to=kkk; x[i*2-1].val=val; if(last[ltt]!=0) { x[last[ltt]].next=i*2-1; } last[ltt]=i*2-1; if(head[kkk]==0) { head[kkk]=i*2; } x[i*2].to=ltt; x[i*2].val=val; if(last[kkk]!=0) { x[last[kkk]].next=i*2; } last[kkk]=i*2; bj[val]=1; } solve(1); for(rii=1;i<=m;i++) { int ltt; scanf("%d",<t); if(bj[ltt]==1) { printf("AYE "); } else { printf("NAY "); } } }