POJ_1986
没有太仔细看题意,Discuss 里说是查询树上两点间的距离,我就照此实现了。由于自己比较懒,一直没有学习 LCA 相关的算法,所以只好复习一下之前写过的树链剖分和 link-cut tree。
解法一:树链剖分 (heavy-light decomposition)
// Heavy-light decomposition solution.
// Input (repeated until EOF): N M, then M lines "x y z d" where d is a
// direction token that is read and ignored, then a query count followed by
// node pairs. For each pair the sum of edge weights on the tree path between
// the two nodes is printed.
// Each edge weight is stored at the deeper endpoint of the edge; nodes are laid
// out in heavy-chain order (wh[]) and path sums are taken from a segment tree.
#include<stdio.h>
#include<string.h>
#include<algorithm>
#define MAXD 40010
#define MAXM 80010

// Adjacency list: first[] heads, v[]/next[] edges (each input edge stored twice).
// w[] holds per-position weights fed into the segment tree sum[].
int N, M, first[MAXD], e, v[MAXM], next[MAXM], w[MAXD], sum[4 * MAXD];
// q: BFS order from node 1, pre: parent, dep: depth, size: subtree size,
// son: heavy child (0 = none), top: head of the node's heavy chain,
// wh: position of the node in the segment tree.
int q[MAXD], pre[MAXD], dep[MAXD], size[MAXD], son[MAXD], top[MAXD], wh[MAXD];

// Input edge (x, y) with weight z; prepare() swaps so that y is the deeper end.
struct Edge
{
    int x, y, z;
} edge[MAXM];

// Recompute segment-tree node cur from its two children.
void update(int cur)
{
    sum[cur] = sum[cur << 1] + sum[cur << 1 | 1];
}

// Build segment-tree node cur covering positions [x, y] from w[].
void build(int cur, int x, int y)
{
    int mid = (x + y) >> 1, ls = cur << 1, rs = cur << 1 | 1;
    if(x == y)
    {
        sum[cur] = w[x];
        return;
    }
    build(ls, x, mid), build(rs, mid + 1, y);
    update(cur);
}

// Sum of w over positions [s, t]; node cur covers [x, y].
// Callers pass a non-empty [s, t] that overlaps [x, y].
int getsum(int cur, int x, int y, int s, int t)
{
    int mid = (x + y) >> 1, ls = cur << 1, rs = cur << 1 | 1;
    if(x >= s && y <= t)
        return sum[cur];
    if(mid >= t)
        return getsum(ls, x, mid, s, t);
    else if(mid + 1 <= s)
        return getsum(rs, mid + 1, y, s, t);
    else
        return getsum(ls, x, mid, s, t) + getsum(rs, mid + 1, y, s, t);
}

// Root the tree at node 1, compute the heavy-light decomposition, place each
// edge weight at its deeper endpoint and build the segment tree.
void prepare()
{
    int i, j, x, rear = 0, cnt;
    // BFS: parents and depths; q[] ends up in BFS order.
    q[rear ++] = 1, pre[1] = dep[1] = 0;
    for(i = 0; i < rear; i ++)
    {
        x = q[i];
        for(j = first[x]; j != -1; j = next[j])
            if(v[j] != pre[x])
                q[rear ++] = v[j], pre[v[j]] = x, dep[v[j]] = dep[x] + 1;
    }
    // Subtree sizes and heavy children, children before parents.
    size[0] = 0;
    for(i = rear - 1; i >= 0; i --)
    {
        x = q[i], size[x] = 1, son[x] = 0;
        for(j = first[x]; j != -1; j = next[j])
        {
            // Skip the parent: its size is not computed yet in this pass and may
            // hold a stale value from a previous test case, which could make the
            // parent the "heavy child" and create a cycle in son[].
            if(v[j] == pre[x])
                continue;
            size[x] += size[v[j]];
            if(size[v[j]] > size[son[x]])
                son[x] = v[j];
        }
    }
    // Assign chain heads and consecutive positions along each heavy chain.
    memset(top, -1, sizeof(top[0]) * (N + 1));
    for(i = cnt = 0; i < rear; i ++)
        if(top[q[i]] == -1)
        {
            for(x = q[i]; x != 0; x = son[x])
                top[x] = q[i], wh[x] = ++ cnt;
        }
    // The root has no parent edge; every other weight goes to the deeper node.
    w[wh[1]] = 0;
    for(i = 0; i < M; i ++)
    {
        if(dep[edge[i].x] > dep[edge[i].y])
            std::swap(edge[i].x, edge[i].y);
        w[wh[edge[i].y]] = edge[i].z;
    }
    build(1, 1, N);
}

// Append directed edge x -> y to the adjacency list.
void add(int x, int y)
{
    v[e] = y;
    next[e] = first[x], first[x] = e ++;
}

// Read one tree (M edges) and preprocess it.
void init()
{
    int i;
    memset(first, -1, sizeof(first[0]) * (N + 1)), e = 0;
    for(i = 0; i < M; i ++)
    {
        // %*s skips the direction token without storing it (no buffer to overflow).
        scanf("%d%d%d%*s", &edge[i].x, &edge[i].y, &edge[i].z);
        add(edge[i].x, edge[i].y), add(edge[i].y, edge[i].x);
    }
    prepare();
}

// Print the sum of edge weights on the path between x and y.
void query(int x, int y)
{
    int fx, fy, ans = 0;
    // Climb from whichever endpoint has the deeper chain head until both
    // endpoints share a chain.
    for(fx = top[x], fy = top[y]; fx != fy; y = pre[fy], fy = top[y])
    {
        if(dep[fx] > dep[fy])
            std::swap(x, y), std::swap(fx, fy);
        ans += getsum(1, 1, N, wh[fy], wh[y]);
    }
    // Same chain: sum the positions strictly below the shallower node.
    if(x != y)
    {
        if(dep[x] > dep[y])
            std::swap(x, y);
        ans += getsum(1, 1, N, wh[son[x]], wh[y]);
    }
    printf("%d\n", ans);
}

// Read and answer all queries for the current tree.
void solve()
{
    int i, x, y, n;
    scanf("%d", &n);
    for(i = 0; i < n; i ++)
    {
        scanf("%d%d", &x, &y);
        query(x, y);
    }
}

int main()
{
    while(scanf("%d%d", &N, &M) == 2)
    {
        init();
        solve();
    }
    return 0;
}
解法二:link-cut tree
// Link-cut tree solution.
// Input (repeated until EOF): N M, then M lines "x y z d" where d is a
// direction token that is read and ignored, then a query count followed by
// node pairs. For each pair the sum of edge weights on the tree path between
// the two nodes is printed.
// The tree is rooted at node 1; each node's key is the weight of the edge to
// its parent. Index 0 is a sentinel whose sum stays 0.
#include<stdio.h>
#include<string.h>
#define MAXD 40010
#define MAXM 80010

// Adjacency list: first[] heads, next[]/v[]/w[] edges (each input edge stored twice).
int N, M, first[MAXD], e, next[MAXM], v[MAXM], w[MAXM];

// One splay-tree node. root is true when the node is the root of its splay
// tree; in that case pre is the path-parent pointer (0 for the tree root's path).
struct Splay
{
    bool root;
    int pre, ls, rs, key, sum;
    void update();
    void zig(int x);
    void zag(int x);
    void splay(int x);
    void init()
    {
        root = true;
        pre = ls = rs = key = sum = 0;
    }
} sp[MAXD];

// Recompute the subtree sum of keys.
void Splay::update()
{
    sum = sp[ls].sum + sp[rs].sum + key;
}

// Rotate this node (index x) down-left: its right child y takes its place.
// Only this node's sum is refreshed; y is refreshed later by the caller.
void Splay::zig(int x)
{
    int y = rs, fa = pre;
    rs = sp[y].ls, sp[rs].pre = x;
    sp[y].ls = x, pre = y;
    sp[y].pre = fa;
    if(root)
        root = false, sp[y].root = true;   // y inherits splay-root status and path-parent
    else if(sp[fa].rs == x)
        sp[fa].rs = y;
    else
        sp[fa].ls = y;
    update();
}

// Rotate this node (index x) down-right: its left child y takes its place.
void Splay::zag(int x)
{
    int y = ls, fa = pre;
    ls = sp[y].rs, sp[ls].pre = x;
    sp[y].rs = x, pre = y;
    sp[y].pre = fa;
    if(root)
        root = false, sp[y].root = true;
    else if(sp[fa].rs == x)
        sp[fa].rs = y;
    else
        sp[fa].ls = y;
    update();
}

// Splay this node (index x) to the root of its splay tree.
void Splay::splay(int x)
{
    int y, z;
    while(!root)
    {
        y = pre;
        if(sp[y].root)
        {
            // Single rotation: parent is the splay root.
            if(sp[y].rs == x)
                sp[y].zig(y);
            else
                sp[y].zag(y);
        }
        else
        {
            z = sp[y].pre;
            if(sp[z].rs == y)
            {
                if(sp[y].rs == x)
                    sp[z].zig(z), sp[y].zig(y);   // zig-zig
                else
                    sp[y].zag(y), sp[z].zig(z);   // zig-zag
            }
            else
            {
                if(sp[y].ls == x)
                    sp[z].zag(z), sp[y].zag(y);   // zag-zag
                else
                    sp[y].zig(y), sp[z].zag(z);   // zag-zig
            }
        }
    }
    update();
}

// Make the path from the tree root to x preferred, cutting off anything below x.
void access(int x)
{
    int u, last;
    for(u = x, last = 0; u != 0; last = u, u = sp[last].pre)
    {
        sp[u].splay(u);
        sp[sp[u].rs].root = true;     // old preferred child becomes its own splay tree
        sp[u].rs = last, sp[last].root = false;
        sp[u].update();
    }
}

// Append directed edge x -> y with weight z to the adjacency list.
void add(int x, int y, int z)
{
    v[e] = y, w[e] = z;
    next[e] = first[x], first[x] = e ++;
}

// Initialise every node below cur (whose parent is fa) as a one-node splay
// tree with pre = tree parent and key = weight of the edge to that parent.
// Uses an explicit stack so deep (path-shaped) trees cannot overflow the call
// stack; for a tree every node is pushed exactly once.
void dfs(int cur, int fa)
{
    static int stk[MAXD], par[MAXD];
    int cnt = 0, i, u, p, c;
    stk[cnt] = cur, par[cnt] = fa, cnt ++;
    while(cnt > 0)
    {
        cnt --;
        u = stk[cnt], p = par[cnt];
        for(i = first[u]; i != -1; i = next[i])
            if(v[i] != p)
            {
                c = v[i];
                sp[c].init(), sp[c].pre = u, sp[c].key = sp[c].sum = w[i];
                stk[cnt] = c, par[cnt] = u, cnt ++;
            }
    }
}

// Read one tree (M edges) and build the initial link-cut forest rooted at 1.
void init()
{
    int i, x, y, z;
    memset(first, -1, sizeof(first[0]) * (N + 1)), e = 0;
    for(i = 0; i < M; i ++)
    {
        // %*s skips the direction token without storing it (no buffer to overflow).
        scanf("%d%d%d%*s", &x, &y, &z);
        add(x, y, z), add(y, x, z);
    }
    sp[0].init(), sp[1].init();
    dfs(1, 0);
}

// Print the path length between x and y: access(y), then walk an access from
// x. The splay tree reached last (pre == 0 after splaying) is rooted at the
// LCA u: its right subtree is the LCA->y part, and sp[last].sum is the part
// hanging down towards x.
void query(int x, int y)
{
    int u, last;
    access(y);
    for(u = x, last = 0; u != 0; last = u, u = sp[last].pre)
    {
        sp[u].splay(u);
        if(sp[u].pre == 0)
            printf("%d\n", sp[sp[u].rs].sum + sp[last].sum);
        sp[sp[u].rs].root = true;
        sp[u].rs = last, sp[last].root = false;
        sp[u].update();
    }
}

// Read and answer all queries for the current tree.
void solve()
{
    int n, x, y;
    scanf("%d", &n);
    while(n --)
    {
        scanf("%d%d", &x, &y);
        query(x, y);
    }
}

int main()
{
    while(scanf("%d%d", &N, &M) == 2)
    {
        init();
        solve();
    }
    return 0;
}