原题链接:http://acm.uestc.edu.cn/#/problem/show/92
题意:
给你一棵树,然后在树上连接一条边。现在有若干次询问,每次问你两个点(u,v)之间的距离在加那条边之后减小了多少。
题解:
对于那条加入的边,只有两种情况,要么走,要么不走。不走的距离就是$dis[u]+dis[v]-2*dis[LCA(u,v)]$,其中$dis$表示点到根节点的距离,LCA表示最近公共祖先。现在考虑走的情况:设加入的那条边是$(a,b)$,边权是c,那么答案显然是:
$$min(DIS(a,u)+DIS(b,v)+c,DIS(a,v)+DIS(b,u)+c)$$
其中DIS表示两点间在树上的最短距离。
代码:
#include<iostream> #include<cstdio> #include<cstring> #include<vector> #include<algorithm> #define MAX_N 100005 #define MAX_D 22 using namespace std; struct edge { public: int to, cost; edge(int t, int c) : to(t), cost(c) { } edge() { } }; vector<edge> G[MAX_N]; int n,q; int ancestor[MAX_N][MAX_D]; int depth[MAX_N]; int dis[MAX_N]; void init(){ memset(dis,0,sizeof(dis)); memset(ancestor,0,sizeof(ancestor)); memset(depth,0,sizeof(depth)); for(int i=0;i<=n;i++)G[i].clear(); } void dfs(int u,int p) { for (int i = 0; i < G[u].size(); i++) { int v = G[u][i].to; if (v == p)continue; dis[v]=dis[u]+G[u][i].cost; depth[v] = depth[u] + 1; ancestor[v][0] = u; dfs(v, u); } } void getAncestor() { for (int j = 1; j < MAX_D; j++) for (int i = 1; i <= n; i++) ancestor[i][j] = ancestor[ancestor[i][j - 1]][j - 1]; } int LCA(int u,int v) { if (depth[u] < depth[v])swap(u, v); for (int i = MAX_D - 1; i >= 0; i--) { if (depth[ancestor[u][i]] >= depth[v]) { u = ancestor[u][i]; if (depth[u] == depth[v])break; } } if (u == v)return u; for (int i = MAX_D - 1; i >= 0; i--) { if (ancestor[u][i] != ancestor[v][i]) { u = ancestor[u][i]; v = ancestor[v][i]; } } return ancestor[u][0]; } int getDis(int u,int v) { int L = LCA(u, v); return dis[u] + dis[v] - 2 * dis[L]; } int T; int cas=0; int main() { cin >> T; while (T--) { printf("Case #%d: ", ++cas); scanf("%d%d", &n, &q); init(); for (int i = 0; i < n - 1; i++) { int u, v, c; scanf("%d%d%d", &u, &v, &c); G[u].push_back(edge(v, c)); G[v].push_back(edge(u, c)); } int x, y, z; scanf("%d%d%d", &x, &y, &z); dfs(1, 0); getAncestor(); while (q--) { int u, v; scanf("%d%d", &u, &v); int tmp, ans; ans = tmp = getDis(u, v); ans = min(ans, getDis(u, x) + getDis(y, v) + z); ans = min(ans, getDis(u, y) + getDis(x, v) + z); printf("%d ", tmp - ans); } } return 0; }