题意：给定一棵带边权的树，每次询问给出三个点，求把这三个点连通所需的最短总长度（即连通这三点的最小边权和）。
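利用 Tarjan 离线 LCA 算法，求出每次询问中任意两点的最近公共祖先（LCA），并在 DFS 过程中求出每个点到根的距离 dis。两点 u、v 之间的距离为 dis[u] + dis[v] - 2·dis[LCA(u, v)]。对每次询问的三个点两两求出距离，最终答案就是这 3 个距离之和的一半：连通三点的最小子树中，每条边恰好被三条两两路径中的两条经过，所以三段距离之和恰好把每条边算了两次。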
其实一开始我是用一种很笨的方法才AC的，具体思路就不细说了，写起来既麻烦又容易出错。怎么就没想到上面这种简便方法呢？
初始代码:
// Initial solution: case analysis on the three pairwise LCAs.
//
// For query (x, y, z) the node pairs (x, y), (x, z), (y, z) are stored in
// lin[i], lin[i + 1], lin[i + 2]. Tarjan's offline LCA writes lca(a, b) into
// lin[k][2] for every pair k, and the answer is derived from which of the
// three LCAs coincide.
#include <iostream>
#include <sstream>
#include <cstdio>
#include <climits>
#include <cstring>
#include <cstdlib>
#include <string>
#include <stack>
#include <map>
#include <cmath>
#include <vector>
#include <queue>
#include <algorithm>
#include <utility>

typedef std::pair<int, int> pii;
typedef std::vector<pii> VII;

// Array bounds; presumably ~50000 nodes and ~70000 queries per test case.
const int maxn = 50000 + 1000;
// Every query expands into three node pairs.
const int maxm = (70000 + 1000) * 3;

int dis[maxn];     // dis[u]: weighted distance from the DFS root (node 0) to u
int lin[maxm][3];  // lin[k] = {a, b, lca(a, b)} for node pair k
int vis[maxn];     // vis[u] = 1 once the DFS has entered u
int pa[maxn];      // union-find parent used by Tarjan's offline LCA
VII g[maxn];       // g[u]: (edge length, neighbour) for every edge at u
VII query[maxn];   // query[u]: (other endpoint, pair index) for pairs touching u
int n, m;

// Union-find "find" with full path compression. Written iteratively so a
// long parent chain cannot overflow the call stack.
int findset(int x)
{
    int root = x;
    while (pa[root] != root)
        root = pa[root];
    while (pa[x] != root)
    {
        int next = pa[x];
        pa[x] = root;
        x = next;
    }
    return root;
}

// Marks u as entered and resolves every pair (u, v) whose other endpoint v
// was entered earlier: at this moment findset(v) is lca(u, v).
static void enter(int u)
{
    vis[u] = 1;
    pa[u] = u;
    for (int i = 0; i < (int)query[u].size(); i++)
    {
        int v = query[u][i].first;
        if (vis[v])
            lin[query[u][i].second][2] = findset(v);
    }
}

// Tarjan's offline LCA over the tree reachable from u, filling dis[] along
// the way. Uses an explicit stack instead of recursion, because a path-shaped
// tree would otherwise recurse once per node.
void tarjan(int u)
{
    // Each frame is (node, index of the next edge of g[node] to examine).
    std::vector<pii> stk;
    enter(u);
    stk.push_back(std::make_pair(u, 0));
    while (!stk.empty())
    {
        pii &top = stk.back();
        int cur = top.first;
        if (top.second < (int)g[cur].size())
        {
            const pii &e = g[cur][top.second++];
            int v = e.second;
            if (!vis[v])
            {
                dis[v] = dis[cur] + e.first;
                enter(v);
                // push_back may invalidate `top`; it is not used afterwards.
                stk.push_back(std::make_pair(v, 0));
            }
        }
        else
        {
            stk.pop_back();
            // cur's subtree is finished: merge its set into its parent's.
            if (!stk.empty())
                pa[cur] = stk.back().first;
        }
    }
}

// Resets all per-test-case state.
void Init(void)
{
    for (int i = 0; i < maxn; i++)
    {
        query[i].clear();
        g[i].clear();
    }
    memset(vis, 0, sizeof(vis));
}

int main(void)
{
    bool first = true;
    while (scanf("%d", &n) == 1)
    {
        // Consecutive test cases are separated by one blank line.
        if (!first)
            puts("");
        first = false;
        Init();
        for (int i = 1; i < n; i++)
        {
            int u, v, len;
            scanf("%d%d%d", &u, &v, &len);
            g[u].push_back(std::make_pair(len, v));
            g[v].push_back(std::make_pair(len, u));
        }
        scanf("%d", &m);
        // Query q (1-based) uses pairs i, i+1, i+2 with i = 3q - 2. Each pair
        // is registered at both endpoints so Tarjan resolves it when the
        // later of the two endpoints is entered.
        for (int i = 1; i <= 3 * m; i += 3)
        {
            int x, y, z;
            scanf("%d%d%d", &x, &y, &z);
            const int ends[3][2] = {{x, y}, {x, z}, {y, z}};
            for (int k = 0; k < 3; k++)
            {
                int a = ends[k][0];
                int b = ends[k][1];
                lin[i + k][0] = a;
                lin[i + k][1] = b;
                query[a].push_back(std::make_pair(b, i + k));
                query[b].push_back(std::make_pair(a, i + k));
            }
        }
        dis[0] = 0;
        tarjan(0);
        for (int i = 1; i <= 3 * m; i += 3)
        {
            int x = lin[i][0];
            int y = lin[i][1];
            int z = lin[i + 1][1];
            int lxy = lin[i][2];
            int lxz = lin[i + 1][2];
            int lyz = lin[i + 2][2];
            int dxy = dis[x] + dis[y] - 2 * dis[lxy];
            int ans;
            if (lxz == lyz)
                // lca(x, y) is the deepest of the three: z reaches the x-y
                // path at lca(x, y), and lca(lca(x, y), z) == lxz.
                ans = dxy + dis[lxy] + dis[z] - 2 * dis[lxz];
            else
                // z reaches the x-y path at the deeper of lca(x, z), lca(y, z).
                ans = dxy + dis[z] - std::max(dis[lxz], dis[lyz]);
            printf("%d\n", ans);
        }
    }
    return 0;
}
简便方法的代码:
// Simplified solution: the minimum total length connecting x, y, z in a tree
// is half the sum of the three pairwise distances, because every edge of the
// connecting subtree lies on exactly two of the three pairwise paths.
// Pairwise distances come from Tarjan's offline LCA:
//   dist(u, v) = dis[u] + dis[v] - 2 * dis[lca(u, v)].
#include <iostream>
#include <sstream>
#include <cstdio>
#include <climits>
#include <cstring>
#include <cstdlib>
#include <string>
#include <stack>
#include <map>
#include <cmath>
#include <vector>
#include <queue>
#include <algorithm>
#include <utility>

typedef std::pair<int, int> pii;
typedef std::vector<pii> VII;

// Array bounds; presumably ~50000 nodes and ~70000 queries per test case.
const int maxn = 50000 + 100;
const int maxm = (70000 + 100) * 3;

int dis[maxn];     // dis[u]: weighted distance from the DFS root (node 0) to u
int vis[maxn];     // vis[u] = 1 once the DFS has entered u
int pa[maxn];      // union-find parent used by Tarjan's offline LCA
VII g[maxn];       // g[u]: (edge length, neighbour) for every edge at u
VII query[maxn];   // query[u]: (other node, query id) for pairs touching u
int ans[maxm];     // ans[q]: sum of the three pairwise distances of query q
int n, m;

// Union-find "find" with full path compression. Written iteratively so a
// long parent chain cannot overflow the call stack.
int findset(int x)
{
    int root = x;
    while (pa[root] != root)
        root = pa[root];
    while (pa[x] != root)
    {
        int next = pa[x];
        pa[x] = root;
        x = next;
    }
    return root;
}

// Marks u as entered and, for every pair (u, v) whose other endpoint v was
// entered earlier, adds dist(u, v) to that query's sum. At this moment
// findset(v) is lca(u, v). Each pair is therefore counted exactly once.
static void enter(int u)
{
    vis[u] = 1;
    pa[u] = u;
    for (int i = 0; i < (int)query[u].size(); i++)
    {
        int v = query[u][i].first;
        if (vis[v])
            ans[query[u][i].second] += dis[u] + dis[v] - 2 * dis[findset(v)];
    }
}

// Tarjan's offline LCA over the tree reachable from u, filling dis[] along
// the way. Uses an explicit stack instead of recursion, because a path-shaped
// tree would otherwise recurse once per node.
void tarjan(int u)
{
    // Each frame is (node, index of the next edge of g[node] to examine).
    std::vector<pii> stk;
    enter(u);
    stk.push_back(std::make_pair(u, 0));
    while (!stk.empty())
    {
        pii &top = stk.back();
        int cur = top.first;
        if (top.second < (int)g[cur].size())
        {
            const pii &e = g[cur][top.second++];
            int v = e.second;
            if (!vis[v])
            {
                // dis[v] must be set before enter(v), which reads it.
                dis[v] = dis[cur] + e.first;
                enter(v);
                // push_back may invalidate `top`; it is not used afterwards.
                stk.push_back(std::make_pair(v, 0));
            }
        }
        else
        {
            stk.pop_back();
            // cur's subtree is finished: merge its set into its parent's.
            if (!stk.empty())
                pa[cur] = stk.back().first;
        }
    }
}

// Resets all per-test-case state, including the per-query sums.
void Init(void)
{
    for (int i = 0; i < maxn; i++)
    {
        query[i].clear();
        g[i].clear();
    }
    memset(vis, 0, sizeof(vis));
    memset(ans, 0, sizeof(ans));
}

int main(void)
{
    bool first = true;
    while (scanf("%d", &n) == 1)
    {
        // Consecutive test cases are separated by one blank line.
        if (!first)
            puts("");
        first = false;
        Init();
        for (int i = 1; i < n; i++)
        {
            int u, v, len;
            scanf("%d%d%d", &u, &v, &len);
            g[u].push_back(std::make_pair(len, v));
            g[v].push_back(std::make_pair(len, u));
        }
        scanf("%d", &m);
        // Register the pairs (x, y), (x, z), (y, z) of query i at both
        // endpoints, all tagged with the same query id.
        for (int i = 1; i <= m; i++)
        {
            int x, y, z;
            scanf("%d%d%d", &x, &y, &z);
            const int ends[3][2] = {{x, y}, {x, z}, {y, z}};
            for (int k = 0; k < 3; k++)
            {
                int a = ends[k][0];
                int b = ends[k][1];
                query[a].push_back(std::make_pair(b, i));
                query[b].push_back(std::make_pair(a, i));
            }
        }
        dis[0] = 0;
        tarjan(0);
        for (int i = 1; i <= m; i++)
        {
            // Every edge of the connecting subtree was counted twice.
            printf("%d\n", ans[i] / 2);
        }
    }
    return 0;
}