题目描述
给定一个N个结点的树,结点用正整数1..N编号。每条边有一个正整数权值。用d(a,b)表示从结点a到结点b路边上经过边的权值。其中要求a<b.将这n*(n-1)/2个距离从大到小排序,输出前M个距离值。
题解
把每次点分治时的dfs序写下来,假设我们在一个位置找能够和它拼成一条链的另一个位置,可以发现那些位置的顺序在dfs序上构成了一段连续区间,用ST表+堆维护。
注意在进队列之前先内啥一下。
代码
#include<iostream> #include<cstdio> #include<queue> #include<cmath> #define N 50002 #define M 16 using namespace std; int tot,head[N],lo[N*M],st[M][N*M],size[N],dp[N],sum,now,deep[N],root,n,p[M][N*M]; bool vis[N]; int start,ed; inline int rd(){ int x=0;char c=getchar();bool f=0; while(!isdigit(c)){if(c=='-')f=1;c=getchar();} while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();} return f?-x:x; } struct edge{int n,to,l;}e[N<<1]; inline void add(int u,int v,int l){e[++tot].n=head[u];e[tot].to=v;head[u]=tot;e[tot].l=l;} struct node{ int now,l,r,sum; node(int nownum=0,int num1=0,int num2=0){ now=nownum;l=num1;r=num2; int loo=lo[r-l+1]; sum=now+max(st[loo][l],st[loo][r-(1<<loo)+1]); } int calc(){ int loo=lo[r-l+1]; if(st[loo][l]>=st[loo][r-(1<<loo)+1])return p[loo][l];else return p[loo][r-(1<<loo)+1]; } bool operator <(const node &b)const{return sum<b.sum;} }pa[N*M]; priority_queue<node>q; void getroot(int u,int fa){ size[u]=1;dp[u]=0; for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa&&!vis[e[i].to]){ int v=e[i].to; getroot(v,u); size[u]+=size[v];dp[u]=max(dp[u],size[v]); } dp[u]=max(dp[u],sum-size[u]); if(dp[u]<dp[root])root=u; } void getsize(int u,int fa){ size[u]=1; for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa&&!vis[e[i].to]){ int v=e[i].to; getsize(v,u); size[u]+=size[v]; } } void getdeep(int u,int fa){ st[0][++now]=deep[u];p[0][now]=now; pa[now]=node(deep[u],start,ed); for(int i=head[u];i;i=e[i].n)if(!vis[e[i].to]&&e[i].to!=fa){ int v=e[i].to; deep[v]=deep[u]+e[i].l; getdeep(v,u); } } inline void calc(int u){ st[0][++now]=0;p[0][now]=now; pa[now]=node{0,now,now}; start=now;ed=now; for(int i=head[u];i;i=e[i].n)if(!vis[e[i].to]){ int v=e[i].to; deep[v]=e[i].l; getdeep(v,u); ed=now; } } void solve(int u){ calc(u);vis[u]=1; for(int i=head[u];i;i=e[i].n)if(!vis[e[i].to]){ int v=e[i].to; root=n+1;sum=size[v]; getroot(v,u);getsize(root,0); solve(root); } } int main(){ n=rd();int k=rd();int u,v,w; for(int i=1;i<n;++i){ u=rd();v=rd();w=rd(); add(u,v,w);add(v,u,w); } dp[root=n+1]=n+1;sum=n; getroot(1,0);getsize(root,0); solve(root); for(int i=1;(1<<i)<=now&&i<M;++i) for(int j=1;j+(1<<i)-1<=now;++j) st[i][j]=max(st[i-1][j],st[i-1][j+(1<<i-1)]),p[i][j]=st[i-1][j]>=st[i-1][j+(1<<i-1)]?p[i-1][j]:p[i-1][j+(1<<i-1)]; for(int i=2;i<=now;++i)lo[i]=lo[i>>1]+1; for(int i=1;i<=now;++i)pa[i]=node(pa[i].now,pa[i].l,pa[i].r),q.push(pa[i]);///care !!!! for(int i=1;i<=k;++i){ node x=q.top();q.pop(); printf("%d ",x.sum); int mid=x.calc(); if(x.l<mid)q.push(node(x.now,x.l,mid-1)); if(x.r>mid)q.push(node(x.now,mid+1,x.r)); } return 0; }