【传送门:BZOJ1468&BZOJ3365】
简要题意:
给出一棵n个点的树,和每条边的边权,求出有多少个点对的距离<=k
题解:
点分治模板题
点分治的主要步骤:
1、首先选取一个点,把无根树变成有根树。 那么如何选点呢? ——树形DP
因为树是递归定义的,所以我们当然希望递归的层数最小。 每次选取的点,要保证与此点相连的结点数最多的连通块的结点数最小,我们把这个点叫做“重心”
那么找到一颗树的重心有以下算法:
(1)dfs一次,算出以每个点为根的子树大小
(2)记录以每个结点为根的最大子树的大小
(3)判断:如果以当前结点为根的最大子树大小比当前根更优,更新当前根
2、处理联通块中通过根结点的路径
3、标记根结点(相当于处理过后,将根结点从子树中删除)
4、递归处理以当前点的儿子为根的每棵子树
参考代码(一):
#include<cstdio> #include<cstring> #include<cstdlib> #include<cmath> #include<algorithm> using namespace std; struct node { int x,y,d,next; }a[81000];int len,last[41000]; void ins(int x,int y,int d) { len++; a[len].x=x;a[len].y=y;a[len].d=d; a[len].next=last[x];last[x]=len; } int tot[41000],root,sum,ms[41000]; bool v[41000]; void getroot(int x,int fa) { tot[x]=1;ms[x]=0; for(int k=last[x];k;k=a[k].next) { int y=a[k].y; if(y!=fa&&v[y]==false) { getroot(y,x); tot[x]+=tot[y]; ms[x]=max(ms[x],tot[y]); } } ms[x]=max(ms[x],sum-tot[x]); if(ms[root]>ms[x]) root=x; } int dep[41000],id; int dd[41000]; void getdep(int x,int fa) { dep[++id]=dd[x]; for(int k=last[x];k;k=a[k].next) { int y=a[k].y; if(y!=fa&&v[y]==false) { dd[y]=dd[x]+a[k].d; getdep(y,x); } } } int ans=0; int k; int cal(int x,int d) { dd[x]=d;id=0; getdep(x,0); sort(dep+1,dep+id+1); int l=1,r=id,c=0; while(l<r) { if(dep[l]+dep[r]<=k){c+=r-l;l++;} else r--; } return c; } void solve(int x) { ans+=cal(x,0); v[x]=true; for(int k=last[x];k;k=a[k].next) { int y=a[k].y; if(v[y]==false) { ans-=cal(y,a[k].d); sum=tot[y]; root=0;getroot(y,x); solve(root); } } } int main() { int n; scanf("%d",&n); len=0;memset(last,0,sizeof(last)); for(int i=1;i<n;i++) { int x,y,d; scanf("%d%d%d",&x,&y,&d); ins(x,y,d);ins(y,x,d); } scanf("%d",&k); memset(v,false,sizeof(v)); ans=0; sum=tot[0]=n; ms[0]=1<<31-1; root=0;getroot(1,0); solve(root); printf("%d ",ans); return 0; }
参考代码(二):
#include<cstdio> #include<cstring> #include<cstdlib> #include<cmath> #include<algorithm> using namespace std; struct node { int x,y,d,next; }a[81000];int len,last[41000]; void ins(int x,int y,int d) { len++; a[len].x=x;a[len].y=y;a[len].d=d; a[len].next=last[x];last[x]=len; } int tot[41000],root,sum,ms[41000]; bool v[41000]; void getroot(int x,int fa) { tot[x]=1;ms[x]=0; for(int k=last[x];k;k=a[k].next) { int y=a[k].y; if(y!=fa&&v[y]==false) { getroot(y,x); tot[x]+=tot[y]; ms[x]=max(ms[x],tot[y]); } } ms[x]=max(ms[x],sum-tot[x]); if(ms[root]>ms[x]) root=x; } int dep[41000],id; int dd[41000]; void getdep(int x,int fa) { dep[++id]=dd[x]; for(int k=last[x];k;k=a[k].next) { int y=a[k].y; if(y!=fa&&v[y]==false) { dd[y]=dd[x]+a[k].d; getdep(y,x); } } } int ans=0; int k; int cal(int x,int d) { dd[x]=d;id=0; getdep(x,0); sort(dep+1,dep+id+1); int l=1,r=id,c=0; while(l<r) { if(dep[l]+dep[r]<=k){c+=r-l;l++;} else r--; } return c; } void solve(int x) { ans+=cal(x,0); v[x]=true; for(int k=last[x];k;k=a[k].next) { int y=a[k].y; if(v[y]==false) { ans-=cal(y,a[k].d); sum=tot[y]; root=0;getroot(y,x); solve(root); } } } int main() { int n,m; scanf("%d%d",&n,&m); len=0;memset(last,0,sizeof(last)); for(int i=1;i<=m;i++) { int x,y,d;char st[3]; scanf("%d%d%d%s",&x,&y,&d,st+1); ins(x,y,d);ins(y,x,d); } scanf("%d",&k); memset(v,false,sizeof(v)); ans=0; sum=tot[0]=n; ms[0]=1<<31-1; root=0;getroot(1,0); solve(root); printf("%d ",ans); return 0; }