最近公共祖先问题(LCA)是求一颗树上的某两点距离他们最近的公共祖先节点,由于树的特性,树上两点之间路径是唯一的,所以对于很多处理关于树的路径问题的时候为了得知树两点的间的路径,LCA是几乎最有效的解法。
首先是LCA的倍增算法。算法主体是依靠首先对整个树的预处理DFS,用来预处理出每个点的直接父节点,同时可以处理出每个点的深度和与根节点的距离,然后利用类似RMQ的思想处理出每个点的 2 的幂次的祖先节点,这就可以用 nlogn 的时间完成整个预处理的工作。然后每一次求两个点的LCA时只要对两个点深度经行考察,将深度深的那个利用倍增先爬到和浅的同一深度,然后一起一步一步爬直到爬到相同节点,就是LCA了。
具体模板是从鹏神的模板小改来的。
注释方便理解版:
1 #include<stdio.h>
2 #include<string.h>
3 #include<algorithm>
4 using namespace std;
5
6 const int maxn=1e5+5;
7 const int maxm=1e5+5;
8 const int maxl=20; //总点数的log范围,一般会开稍大一点
9
10 int fa[maxl][maxn],dep[maxn],dis[maxn]; //fa[i][j]是j点向上(不包括自己)2**i 层的父节点,dep是某个点的深度(根节点深度为0),dis是节点到根节点的距离
11 int head[maxn],point[maxm],nxt[maxm],val[maxm],size;
12 int n;
13
14 void init(){
15 size=0;
16 memset(head,-1,sizeof(head));
17 }
18
19 void add(int a,int b,int v){
20 point[size]=b;
21 val[size]=v;
22 nxt[size]=head[a];
23 head[a]=size++;
24 point[size]=a;
25 val[size]=v;
26 nxt[size]=head[b];
27 head[b]=size++;
28 }
29
30 void Dfs(int s,int pre,int d){ //传入当前节点标号,父亲节点标号,以及当前深度
31 fa[0][s]=pre; //当前节点的上一层父节点是传入的父节点标号
32 dep[s]=d;
33 for(int i=head[s];~i;i=nxt[i]){
34 int j=point[i];
35 if(j==pre)continue;
36 dis[j]=dis[s]+val[i];
37 Dfs(j,s,d+1);
38 }
39 }
40
41 void Pre(){
42 dis[1]=0;
43 Dfs(1,-1,0);
44 for(int k=0;k+1<maxl;++k){ //类似RMQ的做法,处理出点向上2的幂次的祖先。
45 for(int v=1;v<=n;++v){
46 if(fa[k][v]<0)fa[k+1][v]=-1;
47 else fa[k+1][v]=fa[k][fa[k][v]]; //处理出两倍距离的祖先
48 }
49 }
50 }
51
52 int Lca(int u,int v){
53 if(dep[u]>dep[v])swap(u,v); //定u为靠近根的点
54 for(int k=maxl-1;k>=0;--k){
55 if((dep[v]-dep[u])&(1<<k)) //根据层数差值的二进制向上找v的父亲
56 v=fa[k][v];
57 }
58 if(u==v)return u; //u为v的根
59 for(int k=maxl-1;k>=0;--k){
60 if(fa[k][u]!=fa[k][v]){ //保持在相等层数,同时上爬寻找相同父节点
61 u=fa[k][u];
62 v=fa[k][v];
63 }
64 }
65 return fa[0][u]; //u离lca只差一步
66 }
木有注释版:
1 #include<stdio.h>
2 #include<string.h>
3 #include<algorithm>
4 using namespace std;
5
6 const int maxn=1e5+5;
7 const int maxm=1e5+5;
8 const int maxl=20;
9
10 int fa[maxl][maxn],dep[maxn],dis[maxn];
11 int head[maxn],point[maxm],nxt[maxm],val[maxm],size;
12 int n;
13
14 void init(){
15 size=0;
16 memset(head,-1,sizeof(head));
17 }
18
19 void add(int a,int b,int v){
20 point[size]=b;
21 val[size]=v;
22 nxt[size]=head[a];
23 head[a]=size++;
24 point[size]=a;
25 val[size]=v;
26 nxt[size]=head[b];
27 head[b]=size++;
28 }
29
30 void Dfs(int s,int pre,int d){
31 fa[0][s]=pre;
32 dep[s]=d;
33 for(int i=head[s];~i;i=nxt[i]){
34 int j=point[i];
35 if(j==pre)continue;
36 dis[j]=dis[s]+val[i];
37 Dfs(j,s,d+1);
38 }
39 }
40
41 void Pre(){
42 dis[1]=0;
43 Dfs(1,-1,0);
44 for(int k=0;k+1<maxl;++k){
45 for(int v=1;v<=n;++v){
46 if(fa[k][v]<0)fa[k+1][v]=-1;
47 else fa[k+1][v]=fa[k][fa[k][v]];
48 }
49 }
50 }
51
52 int Lca(int u,int v){
53 if(dep[u]>dep[v])swap(u,v);
54 for(int k=maxl-1;k>=0;--k){
55 if((dep[v]-dep[u])&(1<<k))
56 v=fa[k][v];
57 }
58 if(u==v)return u;
59 for(int k=maxl-1;k>=0;--k){
60 if(fa[k][u]!=fa[k][v]){
61 u=fa[k][u];
62 v=fa[k][v];
63 }
64 }
65 return fa[0][u];
66 }
静态树上路径求最小值:LCA倍增
1 #include<bits/stdc++.h>
2 using namespace std;
3
4 const int maxn=1e6+5;
5 const int maxm=2e6+5;
6 const int maxl=22;
7 const int INF = 0x3f3f3f3f;
8
9 int fa[maxl][maxn],dep[maxn],dis[maxl][maxn];
10 int head[maxn],point[maxm],nxt[maxm],val[maxm],size;
11 int vis[maxn];
12 int n,q,tmp=INF;
13
14 void init(){
15 size=0;
16 memset(head,-1,sizeof(head));
17 memset(vis,0,sizeof(vis));
18 }
19
20 void add(int a,int b){
21 point[size]=b;
22 nxt[size]=head[a];
23 head[a]=size++;
24 point[size]=a;
25 nxt[size]=head[b];
26 head[b]=size++;
27 }
28
29 void Dfs(int s,int pre,int d){
30 fa[0][s]=pre;
31 dis[0][s]=s;
32 dep[s]=d;
33 for(int i=head[s];~i;i=nxt[i]){
34 int j=point[i];
35 if(j==pre)continue;
36 Dfs(j,s,d+1);
37 }
38 }
39
40 void Pre(){
41 Dfs(1,-1,0);
42 for(int k=0;k+1<maxl;++k){
43 for(int v=1;v<=n;++v){
44 if(fa[k][v]<0)fa[k+1][v]=-1;
45 else fa[k+1][v]=fa[k][fa[k][v]];
46 if(fa[k][v]<0)dis[k+1][v]=dis[k][v];
47 else dis[k+1][v]=min(dis[k][v],dis[k][fa[k][v]]);
48 }
49 }
50 }
51
52 int Lca(int u,int v){
53 tmp = min( u, v );
54 if(dep[u]>dep[v])swap(u,v);
55 for(int k=maxl-1;k>=0;--k){
56 if((dep[v]-dep[u])&(1<<k)){
57 tmp = min( tmp, dis[k][v]);
58 v=fa[k][v];
59 }
60 }
61 tmp = min( tmp,v );
62 if(u==v)return u;
63 for(int k=maxl-1;k>=0;--k){
64 if(fa[k][u]!=fa[k][v]){
65 tmp=min(tmp,min(dis[k][u],dis[k][v]));
66 u=fa[k][u],v=fa[k][v];
67 }
68 }
69 tmp = min( tmp, min(u,v));
70 tmp = min( tmp, fa[0][u]);
71 return fa[0][u];
72 }
73 //tmp即为u、v路径上的最小值
离线Tarjan的做法主要是防止由于每个点对可能被询问多次,导致每次求都需要 logn 的时间,会超时,所以离线来一并处理所有的询问。
Tarjan的做法是通过递归到最底层,然后开始不断递归回去合并并查集,这样就能够在访问完每个点之后赋值它有关切另一个点已经被访问过的询问。
同样是鹏神的模板修改成自己的代码风格后的。
注释版:
1 #include<stdio.h> //差不多要这些头文件
2 #include<string.h>
3 #include<vector>
4 #include<algorithm>
5 using namespace std;
6
7 const int maxn=1e5+5; //点数、边数、询问数
8 const int maxm=2e5+5;
9 const int maxq=1e4+5;
10
11 int n;
12 int head[maxn],nxt[maxm],point[maxm],val[maxm],size;
13 int vis[maxn],fa[maxn],dep[maxn],dis[maxn];
14 int ans[maxq];
15 vector<pair<int,int> >v[maxn]; //记录询问、问题编号
16
17 void init(){
18 memset(head,-1,sizeof(head));
19 size=0;
20 memset(vis,0,sizeof(vis));
21 for(int i=1;i<=n;++i){
22 v[i].clear();
23 fa[i]=i;
24 }
25 dis[1]=dep[1]=0;
26 }
27
28 void add(int a,int b,int v){
29 point[size]=b;
30 val[size]=v;
31 nxt[size]=head[a];
32 head[a]=size++;
33 point[size]=a;
34 val[size]=v;
35 nxt[size]=head[b];
36 head[b]=size++;
37 }
38
39 int find(int x){
40 return x==fa[x]?x:fa[x]=find(fa[x]);
41 }
42
43 void Tarjan(int s,int pre){
44 for(int i=head[s];~i;i=nxt[i]){
45 int j=point[i];
46 if(j!=pre){
47 dis[j]=dis[s]+val[i];
48 dep[j]=dep[s]+1;
49 Tarjan(j,s); //这里Tarjan的DPS操作必须在并查集合并之前,这样才能保证求lca的时候lca是每一小部分合并时的祖先节点,如果顺序交换,那么所有的查询都会得到 1 节点,就是错误的
50 int x=find(j),y=find(s);
51 if(x!=y)fa[x]=y;
52 }
53 }
54 vis[s]=1;
55 for(int i=0;i<v[s].size();++i){
56 int j=v[s][i].first;
57 if(vis[j]){
58 int lca=find(j);
59 int id=v[s][i].second;
60 ans[id]=lca; //这里视题目要求给答案赋值
61 // ans[id]=dep[s]+dep[j]-2*dep[lca];
62 // ans[id]=dis[s]+dis[j]-2*dis[lca];
63 }
64 }
65 }
66
67
68
69 for(int i=1;i<=k;++i){ //主函数中的主要部分
70 int a,b;
71 scanf("%d%d",&a,&b);
72 v[a].push_back(make_pair(b,i)); //加问题的时候两个点都要加一次
73 v[b].push_back(make_pair(a,i));
74 }
75 Tarjan(1,0);
木有注释版:
1 #include<stdio.h>
2 #include<string.h>
3 #include<vector>
4 #include<algorithm>
5 using namespace std;
6
7 const int maxn=1e5+5;
8 const int maxm=2e5+5;
9 const int maxq=1e4+5;
10
11 int n;
12 int head[maxn],nxt[maxm],point[maxm],val[maxm],size;
13 int vis[maxn],fa[maxn],dep[maxn],dis[maxn];
14 int ans[maxq];
15 vector<pair<int,int> >v[maxn];
16
17 void init(){
18 memset(head,-1,sizeof(head));
19 size=0;
20 memset(vis,0,sizeof(vis));
21 for(int i=1;i<=n;++i){
22 v[i].clear();
23 fa[i]=i;
24 }
25 dis[1]=dep[1]=0;
26 }
27
28 void add(int a,int b,int v){
29 point[size]=b;
30 val[size]=v;
31 nxt[size]=head[a];
32 head[a]=size++;
33 point[size]=a;
34 val[size]=v;
35 nxt[size]=head[b];
36 head[b]=size++;
37 }
38
39 int find(int x){
40 return x==fa[x]?x:fa[x]=find(fa[x]);
41 }
42
43 void Tarjan(int s,int pre){
44 for(int i=head[s];~i;i=nxt[i]){
45 int j=point[i];
46 if(j!=pre){
47 dis[j]=dis[s]+val[i];
48 dep[j]=dep[s]+1;
49 Tarjan(j,s);
50 int x=find(j),y=find(s);
51 if(x!=y)fa[x]=y;
52 }
53 }
54 vis[s]=1;
55 for(int i=0;i<v[s].size();++i){
56 int j=v[s][i].first;
57 if(vis[j]){
58 int lca=find(j);
59 int id=v[s][i].second;
60 ans[id]=lca;
61 // ans[id]=dep[s]+dep[j]-2*dep[lca];
62 // ans[id]=dis[s]+dis[j]-2*dis[lca];
63 }
64 }
65 }
66
67
68
69 for(int i=1;i<=k;++i){
70 int a,b;
71 scanf("%d%d",&a,&b);
72 v[a].push_back(make_pair(b,i));
73 v[b].push_back(make_pair(a,i));
74 }
75 Tarjan(1,0);
另外,现在又有LCA用dfs序+RMQ的做法,可以实现O(nlogn)预处理,O(1)查询的LCA,基本可以完全替代倍增LCA和TarjanLCA,但是树上路径长度和树上路径最小值无法用这个来做。
1 #include <bits/stdc++.h>
2 using namespace std;
3
4 const int maxn = 2e5+5;
5 const int maxl = 20;
6 int vis[maxn],dep[maxn],dp[maxn][maxl];
7
8 int head[maxn],in[maxn],id[maxn];
9 int point[maxn],nxt[maxn],sz;
10 int val[maxn];
11 int fa[maxl][maxn]; //fa[i][j]是j点向上(不包括自己)2**i 层的父节点,dep是某个点的深度(根节点深度为0),dis是节点到根节点的距离
12 int n;
13
14 void init(){
15 sz = 0;
16 memset(head,-1,sizeof(head));
17 memset(fa,-1,sizeof(fa));
18 }
19
20 void Pre(){
21 for(int k=0;k+1<maxl;++k){ //类似RMQ的做法,处理出点向上2的幂次的祖先。
22 for(int v=1;v<=n;++v){
23 if(fa[k][v]<0)fa[k+1][v]=-1;
24 else fa[k+1][v]=fa[k][fa[k][v]]; //处理出两倍距离的祖先
25 }
26 }
27 }
28
29 void dfs(int u,int p,int d,int&k){
30 fa[0][u]=p; //当前节点的上一层父节点是传入的父节点标号
31 vis[k] = u;
32 id[u] = k;
33 dep[k++]=d;
34 for(int i = head[u];~i;i=nxt[i]){
35 int v = point[i];
36 if(v == p)continue;
37 dfs(v,u,d+1,k);
38 vis[k] = u;
39 dep[k++]=d;
40 }
41 }
42
43 void RMQ(int root){
44 int k =0 ;
45 dfs(root,-1,0,k);
46 int m = k;
47 int e= (int)(log2(m+1.0));
48 for(int i = 0 ; i < m ; ++ i)dp[i][0]=i;
49 for(int j = 1 ; j <= e ; ++ j){
50 for(int i = 0 ; i + ( 1<< j ) - 1 < m ; ++ i){
51 int N = i + (1<<(j-1));
52 if(dep[dp[i][j-1]] < dep[dp[N][j-1]]){
53 dp[i][j] = dp[i][j-1];
54 }
55 else dp[i][j] = dp[N][j-1];
56 }
57 }
58 }
59
60 void add(int a,int b){
61 point[sz] = b;
62 nxt[sz] = head[a];
63 head[a] = sz++;
64 }
65
66
67 int LCA(int u,int v){
68 int left = min(id[u],id[v]),right = max(id[u],id[v]);
69 int k = (int)(log2(right- left+1.0));
70 int pos,N = right - (1<<k)+1;
71 if(dep[dp[left][k]] < dep[dp[N][k]])pos = dp[left][k];
72 else pos = dp[N][k];
73 return vis[pos];
74 }
75
76 int q;
77
78 inline int get(int a,int k){
79 int res = a;
80 for(int i = 0 ; (1ll << i ) <= k ; ++ i){
81 if(k&(1ll<<i)){
82 res = fa[i][res];
83 }
84 }
85 return res;
86 }
87
88 void run(){
89 while(q--){
90 int a,b,k;
91 scanf("%d%d%d",&a,&b,&k);
92 int lca = LCA(a,b);
93 int num = dep[id[a]] - dep[id[lca]] + dep[id[b]] - dep[id[lca]] + 1;
94 int up = (num - 1)%k;
95 int ans = val[a];
96 // printf("a : %d
",a);
97 while(dep[id[a]] - dep[id[lca]] >= k){
98 int Id = get(a,k);
99 ans ^= val[Id];
100 // printf("a : %d
",Id);
101 a = Id;
102 }
103 if(dep[id[b]] - dep[id[lca]] > up){
104 // printf("up: %d
",up);
105 if(up == 0)ans^= val[b];
106 b = get(b,up);
107 while(dep[id[b]] - dep[id[lca]] >k ){
108 int Id = get(b,k);
109 ans ^= val[Id];
110 b = Id;
111 }
112 }
113 printf("%d
",ans);
114 }
115 }
116
117 int main(){
118 while(scanf("%d%d",&n,&q)!=EOF){
119
120 init();
121 for(int i =1 ; i < n; ++ i){
122 int a,b;
123 scanf("%d%d",&a,&b);
124 add(a,b);
125 add(b,a);
126 }
127 for(int i = 1;i <= n ; ++ i)scanf("%d",&val[i]);
128 RMQ(1);
129 Pre();
130 run();
131
132 }
133 return 0;
134 }