点分治入门题
首先发现是树上点对的问题,那么首先想到上点分治
然后发现题目要求是求出树上点对之间距离小于等于k的对数,那么我们很自然地进行分类:
对于一棵有根树,树上的路径只有两种:一种经过根节点,另一种不经过根节点
对于经过根节点的路径,我们可以通过计算出每个点的根节点的距离,然后相加就能求出点对间距离
对于不经过根节点的路径,我们可以递归到子节点去算,去找子节点对应的子树来算即可
但是这里有两个问题:第一,如何快速算出以一个点为根的合法点对数量?
我们知道,可以在线性时间内求出每个点到根节点的距离,但如果我们逐个枚举点对的话,时间就会退化成平方级别
这显然不够优秀
所以我们将每个点到根节点距离排序,然后用两个指针,初始分别指向头和尾,如果两个指针指到的之和是合法的,那么这两个指针间的部分都是合法的(具体看代码),扫一遍即可
第二:这样做的结果是正确的吗?
我们看到,如果查到的一个点对在同一棵子树内,那么在计算以这个点为根和以这个点的子节点为根的时候,这个点对都会被计算一次!
这显然是不对的
因此我们在枚举每个子树时需要先去掉这一部分,然后再计算
#include <cstdio> #include <cmath> #include <cstring> #include <cstdlib> #include <iostream> #include <algorithm> #include <queue> #include <stack> #define ll long long using namespace std; const int inf=0x3f3f3f3f; struct Edge { int next; int to; int val; }edge[200005]; int head[100005]; int rt,s; int n,k; int maxp[100005]; int siz[100005]; bool vis[100005]; int dis[100005]; int used[100005]; int ans=0; int cnt=1; void init() { memset(head,-1,sizeof(head)); memset(vis,0,sizeof(vis)); ans=0; cnt=1; } void add(int l,int r,int w) { edge[cnt].next=head[l]; edge[cnt].to=r; edge[cnt].val=w; head[l]=cnt++; } void get_rt(int x,int fa) { siz[x]=1,maxp[x]=0; for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(vis[to]||to==fa)continue; get_rt(to,x); siz[x]+=siz[to]; maxp[x]=max(maxp[x],siz[to]); } maxp[x]=max(maxp[x],s-siz[x]); if(maxp[x]<maxp[rt])rt=x; } void get_dis(int x,int fa) { used[++used[0]]=dis[x]; for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(vis[to]||to==fa)continue; dis[to]=dis[x]+edge[i].val; get_dis(to,x); } } int calc(int x,int val) { dis[x]=val; used[0]=0; get_dis(x,0); sort(used+1,used+used[0]+1); int l=1,r=used[0]; int ret=0; while(l<r) { if(used[l]+used[r]<=k)ret+=r-l,l++; else r--; } return ret; } void solve(int x) { vis[x]=1; ans+=calc(x,0); for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(vis[to])continue; ans-=calc(to,edge[i].val); rt=0,s=siz[to],maxp[rt]=inf; get_rt(to,0); solve(rt); } } int main() { while(1) { scanf("%d%d",&n,&k); init(); if(!n&&!k)return 0; for(int i=1;i<n;i++) { int x,y,z; scanf("%d%d%d",&x,&y,&z); add(x,y,z),add(y,x,z); } maxp[rt]=s=n; get_rt(1,0); solve(rt); printf("%d ",ans); } return 0; }