摘要:
本文主要介绍了解决LCA(最近公共祖先问题)的两种算法,分别是离线Tarjan算法和在线算法,着重展示了在具体题目中的应用细节。
最近公共祖先是指对于一棵有根树T的两个结点u和v,它们的LCA(T,u,v)表示一个结点x,满足x是u和v的公共祖先且x深度尽可能的大(也即最近)。
求最近公共祖先有两种方法:一种是离线求解算法,也就是将询问全部存起来,处理完之后一次回答所有询问;另一种方法就是在线求解算法,对于每次询问,动态地回答。
离线算法Tarjan
Tarjan算法就是利用深度优先搜索的框架,对于新搜索到的一个结点,首先创建由这个结点构成的集合,再对当前结点的每一个子树进行搜索,每搜索完一棵子树,则可以确定这棵子树之内的LCA问题都已经解决。其他的LCA问题肯定都在这个子树之外,这时把子树所形成集合与当前结点的集合合并,并将当前结点设为这个集合的祖先。
之后继续搜索下一棵子树,直到当前结点的所有子树搜索完,这时把当前结点也设为已经被检查过的,同时可以处理有关当前结点的LCA询问,如果有一个从当前结点到v的询问,且v已经被检查过,则由于进行的是深度优先搜索,当前结点和v的LCA一定还没有被检查过,然而这个最近公共祖先的包含v的子树一定已经搜索过了,那么这个最近公共祖先一定是v所在集合的祖先。
算法实现如下:
1 const int maxn = 10005;//结点数 2 bool vis[maxn]; 3 int tree[maxn][maxn], ans[maxn][maxn], fa[maxn]; 4 //tree[u][0]表示结点u有几个孩子,分别是 5 //ans[u][v]表示u和v的LCA 6 //fa[i]表示i的祖先 7 set<int> query[maxn];//保存关于u结点的询问 8 9 int n; 10 void init() { 11 for(int i = 0; i <= n; i++) { 12 fa[i] = i; 13 vis[i] = 0; 14 } 15 } 16 17 int Find(int x) {//查找一个集合的祖先 18 return fa[x] == x ? x : fa[x] = Find(fa[x]); 19 } 20 21 void Union(int x, int y) {//合并两个集合,将x并入y中 22 int fx = Find(x); 23 int fy = Find(y); 24 fa[fx] = fy; 25 } 26 27 void dfs(int u) {//dfs遍历树 28 vis[u] = 1; 29 for(int i = 1; i <= tree[u][0]; i++) { 30 int v = tree[u][i]; 31 if(vis[v]) continue; 32 dfs(v); 33 Union(v, u);//当遍历完一棵子树的时候,就将子树和父亲合并 34 } 35 for(set<int>::iterator it = query[u].begin(); 36 it != query[u].end(); it++) {//当处理完u结点的子树时,就可以回答部分有关u的询问 37 int v = *it; 38 if(vis[v]) { //如果v被访问过,就可以回答这个询问 39 ans[u][v] = Find(v); //u and v 的LCA就是v所在集合的公共祖先 40 query[v].erase(query[u].find(u));//将v中的此询问删掉 41 } 42 } 43 }
Tarjan算法可以解决LCA查询要求实现知道全部查询提问,如果LCA要求即问即答,就需要使用在线算法。
在线算法
在线算法需要对该树进行预处理,生成三个序列:欧拉序列、深度序列、遍历结点第一次出现的时间序列,然后通过RMQ(区间最值查询)来O(1)地回答问题。
巧妙的是只用对树进行一次深度优先遍历,就可以得到这三个序列了。
结点第一次出现的时间:就是深度优先遍历的过程中第一次遍历到这个结点的时间,该序列的长度是n,记为pos数组,即pos[u] = 3,表示u结点是第三个遍历到的。
欧拉序列:按照深度优先遍历,依次经过的结点按照遍历顺序全部记录下来,包括回溯的过程,也就是一个点可能被记录多次。该序列的长度由深搜的过程决定,记为t数组。
深度序列:该序列的长度和欧拉序列的长度一致,记录的是欧拉序列中对应结点的深度,记为dep。
有了这三个序列,假设我们需要查询LCA(T,u,v),通过pos[u]和pos[v]可以知道u和v结点在t数组和dep数组中是第几个,利用深度优先遍历的过程,可以知道在dep[pos[u]]~dep[pos[v]]中深度最小的结点就是LCA(T,u,v)了。
算法实现如下:
1 const int maxn = 1006; 2 3 int tot, ans[maxn][maxn], link[maxn][maxn];//link存树的结构 4 int dep[maxn * 4], pos[maxn], t[maxn * 4]; 5 int dp[maxn * 4][12];//存储区间最值 6 bool v[maxn]; 7 8 void dfs(int u, int dfn) { 9 if(!v[u]) { 10 v[u] = 1; 11 pos[u] = tot; 12 } 13 dep[tot] = dfn; //深度序列 14 t[tot++] = u; //欧拉序列 15 for(int i = 1; i <= link[u][0]; i++) { 16 dfs(link[u][i], dfn + 1); 17 18 dep[tot] = dfn; 19 t[tot++] = u; 20 } 21 return; 22 } 23 24 void init() { 25 for(int j = 0; (1 << j) <= tot; j++) { 26 for(int i = 0; i + (1 << j) <= tot; i++) { 27 if(j == 0) 28 dp[i][j] = i; 29 else { 30 if(dep[dp[i][j - 1]] < dep[dp[i + (1 << (j - 1))][j - 1]]) 31 dp[i][j] = dp[i][j - 1]; 32 else 33 dp[i][j] = dp[i + (1 << (j - 1))][j - 1]; 34 } 35 } 36 } 37 } 38 39 int RMQ(int p1, int p2) { 40 int k = log2(p2 - p1 + 1); 41 if( (1<<k) < p2 - p1 + 1) k++; 42 if(dep[dp[p1][k]] < dep[dp[p2 - (1 << k) + 1][k]]) 43 return t[dp[p1][k]]; 44 else 45 return t[dp[p2 - (1 << k) + 1][k]]; 46 } 47 48 int lca(int v1, int v2) { 49 if(pos[v1] < pos[v2]) 50 return RMQ(pos[v1], pos[v2]); 51 else 52 return RMQ(pos[v2], pos[v1]); 53 }
下面看一道例题:HDU 2586 How far away ?
题意
输出一棵有根数,问任意两个结点间的距离
解题思路
首先问一次计算一次不是什么好的办法,我们可以将每个结点到根结点的距离预处理出来,然后找到两个结点的最近公共祖先,然后答案就是dis[u] + dis[v] - 2 * dis[lca(u, v)]。因为是即问即答,所以采用在线的方法。
注意RMQ中k的计算方式有所不同,采用之前的方法计算会发生数组访问越界。
代码如下:
1 #include <cstdio> 2 #include <vector> 3 #include <cmath> 4 #include <cstring> 5 6 using namespace std; 7 8 const int maxn = 41010; 9 struct E{ 10 int v, ne, d; 11 E(){} 12 E(int _v, int _n, int _d): v(_v), ne(_n), d(_d){} 13 }e[maxn * 2]; 14 15 int t[maxn * 2], dep[maxn * 2], pos[maxn], dis[maxn]; 16 int head[maxn], esize; 17 bool vis[maxn]; 18 int dp[maxn * 2][30]; 19 int n, m, tot; 20 21 void init() { 22 esize = tot = 0; 23 memset(vis, 0, sizeof(vis)); 24 memset(dis, 0, sizeof(dis)); 25 memset(head, -1, sizeof(head)); 26 } 27 28 void add(int u, int v, int d) { 29 e[esize] = E(v, head[u], d); 30 head[u] = esize++; 31 } 32 33 void dfs(int u, int de) { 34 if(!vis[u]) { 35 vis[u] = 1; 36 pos[u] = tot; 37 } 38 dep[tot] = de; 39 t[tot++] = u; 40 41 for(int i = head[u]; i != -1; i = e[i].ne) { 42 int v = e[i].v; 43 int d = e[i].d; 44 if(vis[v]) continue; 45 dis[v] = dis[u] + d; 46 dfs(v, de + 1); 47 48 dep[tot] = de; 49 t[tot++] = u; 50 } 51 return; 52 } 53 54 void cdep() { 55 for(int j = 0; (1 << j) < tot; j++) { 56 for(int i = 1; i + (1 << j) < tot; i++) { 57 if(j == 0) 58 dp[i][j] = i; 59 else { 60 if(dep[dp[i][j - 1]] < dep[dp[i + (1 << (j - 1))][j - 1]]) 61 dp[i][j] = dp[i][j - 1]; 62 else 63 dp[i][j] = dp[i + (1 << (j - 1))][j - 1]; 64 } 65 } 66 } 67 } 68 69 int RMQ(int u, int v) { 70 int k = 0; 71 k = log2(v - u + 1); 72 if((1 << k) < v - u + 1) k++; 73 /*int len = v - u + 1, k = 0; 74 k = log(len * 1.0)/log(2.0);*/ 75 if(dep[dp[u][k]] < dep[dp[v - (1 << k) + 1][k]]) 76 return t[dp[u][k]]; 77 else 78 return t[dp[v - (1 << k) + 1][k]]; 79 } 80 81 int lca(int u, int v) { 82 if(pos[u] < pos[v]) 83 return RMQ(pos[u], pos[v]); 84 else 85 return RMQ(pos[v], pos[u]); 86 } 87 88 int main() 89 { 90 int T; 91 scanf("%d", &T); 92 while(T--) { 93 scanf("%d%d", &n, &m); 94 init(); 95 for(int i = 1; i < n; i++) { 96 int u, v, w; 97 scanf("%d%d%d", &u, &v, &w); 98 add(u, v, w); 99 add(v, u, w); 100 } 101 dfs(1, 0); 102 cdep(); 103 104 int u, v; 105 while(m--) { 106 scanf("%d%d", &u, &v); 107 printf("%d ", dis[u] + dis[v] - 2 * dis[lca(u, v)]); 108 } 109 } 110 return 0; 111 }