题目描述
给定一棵有n个点的树
询问树上距离为k的点对是否存在。
输入输出格式
输入格式:
n,m 接下来n-1条边a,b,c描述a到b有一条长度为c的路径
接下来m行每行询问一个K
输出格式:
对于每个K每行输出一个答案,存在输出“AYE”,否则输出”NAY”(不包含引号)
输入输出样例
说明
对于30%的数据n<=100
对于60%的数据n<=1000,m<=50
对于100%的数据n<=10000,m<=100,c<=1000,K<=10000000
题解
看一看$k$的范围,可以直接把所有答案预处理出来,然后$O(1)$查询
时间复杂度$O(n^2)$,随机数据可以跑
据说还有$O(nlog^2n)$的方法,然而我不会……
1 //minamoto 2 #include<cstdio> 3 #include<iostream> 4 #define inf 0x3f3f3f3f 5 #define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++) 6 char buf[1<<21],*p1=buf,*p2=buf; 7 template<class T>inline bool cmax(T&a,const T&b){return a<b?a=b,1:0;} 8 inline int read(){ 9 #define num ch-'0' 10 char ch;bool flag=0;int res; 11 while(!isdigit(ch=getc())) 12 (ch=='-')&&(flag=true); 13 for(res=num;isdigit(ch=getc());res=res*10+num); 14 (flag)&&(res=-res); 15 #undef num 16 return res; 17 } 18 const int N=10005; 19 int ans[10000005]; 20 int ver[N<<1],head[N],Next[N<<1],edge[N<<1]; 21 int sz[N],son[N],st[N];bool vis[N]; 22 int n,m,size,mx,rt,tot,top; 23 inline void add(int u,int v,int e){ 24 ver[++tot]=v,Next[tot]=head[u],head[u]=tot,edge[tot]=e; 25 ver[++tot]=u,Next[tot]=head[v],head[v]=tot,edge[tot]=e; 26 } 27 void getrt(int u,int fa){ 28 sz[u]=1,son[u]=0; 29 for(int i=head[u];i;i=Next[i]){ 30 int v=ver[i]; 31 if(vis[v]||v==fa) continue; 32 getrt(v,u); 33 sz[u]+=sz[v],cmax(son[u],sz[v]); 34 } 35 cmax(son[u],size-sz[u]); 36 if(son[u]<mx) mx=son[u],rt=u; 37 } 38 void query(int u,int fa,int d){ 39 st[++top]=d; 40 for(int i=head[u];i;i=Next[i]){ 41 int v=ver[i]; 42 if(vis[v]||v==fa) continue; 43 query(v,u,d+edge[i]); 44 } 45 } 46 void solve(int rt,int d,int f){ 47 top=0; 48 query(rt,0,d); 49 if(f){ 50 for(int i=1;i<top;++i) 51 for(int j=i+1;j<=top;++j) 52 ++ans[st[i]+st[j]]; 53 } 54 else{ 55 for(int i=1;i<top;++i) 56 for(int j=i+1;j<=top;++j) 57 --ans[st[i]+st[j]]; 58 } 59 } 60 void divide(int u){ 61 vis[u]=true; 62 solve(u,0,1); 63 for(int i=head[u];i;i=Next[i]){ 64 int v=ver[i]; 65 if(vis[v]) continue; 66 solve(v,edge[i],0); 67 mx=inf,rt=0,size=sz[v]; 68 getrt(v,0); 69 divide(rt); 70 } 71 } 72 int main(){ 73 n=read(),m=read(); 74 for(int i=1;i<n;++i){ 75 int u=read(),v=read(),e=read(); 76 add(u,v,e); 77 } 78 rt=0,mx=inf,size=n; 79 getrt(1,0),divide(rt); 80 while(m--){ 81 int k=read(); 82 puts(ans[k]?"AYE":"NAY"); 83 } 84 return 0; 85 }