题目链接:http://poj.org/problem?id=1741
Time Limit: 1000MS Memory Limit: 30000K
Description
Give a tree with n vertices, each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
Input
The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
题意:
给出一棵 $n(n le 10000)$ 个节点的树,树上的每条边都有一个长度(最大不超过 $1000$),求整棵树上长度不超过 $k$ 的路径有多少条。
题解:
对于题目所给出的无根树,任意指定一个根,不会影响结果,不妨选取树的重心 $p$ 作为根节点,
对于根节点 $p$ 来说,所有路径不外乎以下两种:
1、经过节点 $p$ 的;
2、不经过节点 $p$ 的,显然,这样的路径必然是在 $p$ 的某一棵子树中。
对于第 2 种路径,基于分治思想,可以把 $p$ 的每棵子树看做子问题,递归求解。那么问题就落在,怎么求第 1 种路径的数目。
对于第 1 种路径,每一条路径都可以分成两段:$p sim x$ 和 $p sim y$;
假设 $dist[x]$ 表示从根节点 $p$ 到当前节点 $x$ 的距离;而 $sub[x]$ 表示当前节点 $x$ 是 $p$ 的哪一棵子树,特别地令 $sub[p] = p,dist[p] = 0$;
容易求路径数目转化为求满足如下条件的节点对 $(x,y)$ 的数目:
1、$dist[x] + dist[y] le k$;
2、$sub[x] e sub[y]$。
不妨设函数 ${ m{calc}}(p)$ 能够返回满足上述条件的节点对数目,那么如何实现 ${ m{calc}}(p)$ 呢?据书上说……一般有两种方式:
一、树上直接统计:
设 $p$ 的子树为 $s_1,s_2, cdots ,s_m$,那么对于 $s_i$ 中的每一个节点 $x$,统计子树 $s_1,s_2, cdots ,s_{i-1}$ 中所有满足 $dist[x]+dist[y] le k$ 的节点 $y$ 的数目即可。
具体地,可以建立一个树状数组,依次处理每棵子树 $s_i$:
1、对于 $s_i$ 中的每一个节点 $x$,${ m{ask}}(k - dist[x])$ 即可统计节点 $y$ 的数目;
2、对于 $s_i$ 中的每一个节点 $x$,执行一次 ${ m{add}}(dist[x],1)$ 代表与根节点 $p$ 距离为 $dist[x]$ 的节点增加一个。
二、指针扫描数组
对于以节点 $p$ 为根的子树,把该子树上任意节点 $x$ 到 $p$ 的距离 $dist[x]$ 放入一个数组 $d$ 中,并且对数组 $d$ 进行升序排序。
使用两个指针 $l,r$ 分别从头尾开始扫描数组 $d$,直到 $l==r$ 时停止,容易发现当指针 $l$ 从左向右扫描的过程中,恰好使得满足 $dist[l] + dist[r] le k$ 的右指针 $r$ 是从右向左单调递减的。
那么,显然当路径的一端 $x$ 固定时(换句话说就是左指针固定时),左移右指针直到满足 $dist[l] + dist[r] le k$,此时即可知满足题目要求的路径的另一端 $y$ 的个数为 $r-l$。
因此,每次左指针 $l$ 指向一个位置就会得到一个值 $r-l$,把它们全部累加起起来返回即可。
这时,你可能要问了:这样不对啊,这样的话 ${ m{calc}}(p)$ 所统计的有些节点对 $(x,y)$ 的路径有可能会是 $x o p o x o y$ 或者 $x o y o p o y$ 这样啊。
确实是这样的,因为我们没有迫使两个端点 $x,y$ 不属于 $p$ 的同一棵子树,这样统计下来显然是有多算若干条形如 $x o p o x o y$ 或 $x o y o p o y$ 的路径的,换句话说,我们将属于 $p$ 的同一棵子树的端点对 $(x,y)$ 都算进去了,这显然是不符合我们最初对于 ${ m{calc}}(p)$ 的定义的。
那么要如何才能去掉这些不符合要求的端点对呢?假设 $p$ 的子树为 $s_1,s_2, cdots ,s_m$(用其根节点编号表示),
显然,这些“坏的”端点对都是属于 $p$ 的同一棵子树下的,因此上述那些坏的端点对的总数目为 ${ m{calc}}(s_1) + { m{calc}}(s_1) + cdots + { m{calc}}(s_m)$(注意,对于此时的 ${ m{calc}}(s_i)$,$dist[s_i] = dist(p,s_i)$,而非 $dist[s_i] = 0$);
因此我们实际上正确的 ${ m{correct\_calc}}(p) = { m{calc}}(p) - [{ m{calc}}(s_1) + { m{calc}}(s_1) + cdots + { m{calc}}(s_m)]$。
由于方法一中BIT的范围和路径长度有关,这个范围远比 $n$ 要大,因此本题更加适用方法二。
时间复杂度(对于方法二):
考虑第一层递归:
仅有一次 ${ m{calc}}(root)$,$O(n)$ 求得所有节点的深度,$O(n log n)$ 的排序,$O(n)$ 的双指针扫描,合起来是 $O(n log n)$;
对所有 $root$ 的子树都要 ${ m{calc}}(s_i)$,但是所有的 ${ m{calc}}(s_i)$ 合起来一共也是处理 $O(n)$ 的点,因此全部合起来同 ${ m{calc}}(root)$ 是一样的复杂度。
而后面的每一层递归,虽然有多个根节点要 ${ m{calc}}(root)$,但同上面一样的道理,同一层的递归下全部 ${ m{calc}}(root)$ 合起来一共处理 $O(n)$ 的点,因此每一层递归的时间复杂度同第一层。
因此,若递归最深达到 $T$ 层,那么整个算法的时间复杂度为 $O(T imes n log n)$。
若随意取根节点,一旦每次都取到链的端点,就会使得 $T$ 达到 $n$,时间复杂度退化为 $O(n^2 log n)$;
因此,对于每层递归的每一棵树我们都要选择树的重心作为根节点,则所有子树的大小不会超过整棵树大小的一半,即可保证 $T$ 为 $O(log n)$,则算法总的时间复杂度就为 $O(n log ^2 n)$。
AC代码:
#include<cstdio> #include<cstring> #include<algorithm> #include<vector> using namespace std; const int INF=0x3f3f3f3f; const int maxn=10000+10; int n,k; int vis[maxn],dist[maxn]; vector<int> d; int ans; struct Edge{ int u,v,w; Edge(int _u=0,int _v=0,int _w=0){u=_u,v=_v,w=_w;} }; vector<Edge> E; vector<int> G[maxn]; void init(int l,int r) { E.clear(); for(int i=l;i<=r;i++) G[i].clear(); } void addedge(int u,int v,int w) { E.push_back(Edge(u,v,w)); G[u].push_back(E.size()-1); } namespace getCtr //获取重心 { int n,siz[maxn]; pair<int,int> center; void dfs(int now,int par) { siz[now]=1; int maxpart=0; for(int i=0;i<G[now].size();i++) { Edge &e=E[G[now][i]]; int nxt=e.v; if(nxt!=par && !vis[nxt]) { dfs(nxt,now); siz[now]+=siz[nxt]; maxpart=max(maxpart,siz[nxt]); } } maxpart=max(maxpart,n-siz[now]); if(maxpart<center.first) { center.first=maxpart; center.second=now; } } void work(int now,int par) { center=make_pair(INF,0); dfs(now,par); } }; void getdist(int now,int dist,int par) { d.push_back(dist); for(int i=0;i<G[now].size();i++) { Edge &e=E[G[now][i]]; int nxt=e.v; if(nxt!=par && !vis[nxt]) getdist(nxt,dist+e.w,now); } } int calc(int p,int dist_p) { d.clear(); getdist(p,dist_p,0); sort(d.begin(),d.end()); int res=0; int i=0,j=d.size()-1; while(i<j) { if(d[i]+d[j]<=k) res+=j-i, i++; else j--; } return res; } void dfs(int now) { vis[now]=1; ans+=calc(now,0); for(int i=0;i<G[now].size();i++) { Edge &e=E[G[now][i]]; int nxt=e.v; if(!vis[nxt]) ans-=calc(nxt,e.w); } for(int i=0;i<G[now].size();i++) { Edge &e=E[G[now][i]]; int nxt=e.v; if(!vis[nxt]) { getCtr::n=getCtr::siz[nxt]; getCtr::work(nxt,now); dfs(getCtr::center.second); } } } int main() { while(scanf("%d%d",&n,&k),n+k) { init(1,n); for(int i=1,u,v,w;i<n;i++) { scanf("%d%d%d",&u,&v,&w); addedge(u,v,w); addedge(v,u,w); } ans=0; memset(vis,0,sizeof(vis)); getCtr::n=n; getCtr::work(1,0); dfs(getCtr::center.second); printf("%d ",ans); } }