[题目链接]
http://poj.org/problem?id=3417
[算法]
树上差分（边差分）+ 倍增求 LCA
[代码]
#include <algorithm>
#include <bitset>
#include <cctype>
#include <cerrno>
#include <clocale>
#include <cmath>
#include <complex>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <deque>
#include <exception>
#include <fstream>
#include <functional>
#include <limits>
#include <list>
#include <map>
#include <iomanip>
#include <ios>
#include <iosfwd>
#include <iostream>
#include <istream>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stdexcept>
#include <streambuf>
#include <string>
#include <utility>
#include <vector>
#include <cwchar>
#include <cwctype>
#include <stack>
#include <limits.h>
using namespace std;

// Edge difference on a tree.
// Each of the m extra (u, v) pairs adds +1 to every tree edge on the path u..v
// (marked as +1 at u, +1 at v, -2 at lca(u, v), then summed over subtrees).
// Afterwards, for every non-root node i, sum[i] is the count on edge (i, parent(i)).
// An edge with count 0 adds m to the answer, count 1 adds 1, otherwise 0.

#define MAXN 100010
#define MAXLOG 20 // 2^19 > MAXN, enough levels for binary lifting

struct edge
{
    int to, nxt;
} e[MAXN << 1];

int n, m, tot;
int sum[MAXN], dep[MAXN], head[MAXN];
// anc[x][k] = 2^k-th ancestor of x; 0 means "none" (node 0 is unused).
int anc[MAXN][MAXLOG];
// Visiting order produced by build_tree: every parent precedes its children.
int seq[MAXN], seqLen;

// Append the directed edge u -> v to u's adjacency list.
inline void addedge(int u, int v)
{
    tot++;
    // Plain member assignment: `(edge){...}` is a GNU extension, not standard C++.
    e[tot].to = v;
    e[tot].nxt = head[u];
    head[u] = tot;
}

// Root the tree at `root`, filling dep[], anc[][] and seq[].
// Iterative BFS, so a path-shaped tree with ~1e5 nodes cannot overflow the
// call stack the way the recursive version could.
// Requires dep[root] == 0 and anc[root][0] == 0 (main clears both).
void build_tree(int root)
{
    seqLen = 0;
    seq[seqLen++] = root;
    for (int pos = 0; pos < seqLen; pos++)
    {
        int x = seq[pos];
        // All ancestors of x were dequeued earlier, so their tables are complete.
        for (int k = 1; k < MAXLOG; k++)
        {
            if (dep[x] < (1 << k)) break;
            anc[x][k] = anc[anc[x][k - 1]][k - 1];
        }
        for (int j = head[x]; j; j = e[j].nxt)
        {
            int y = e[j].to;
            if (y != anc[x][0]) // skip the edge back to the parent
            {
                dep[y] = dep[x] + 1;
                anc[y][0] = x;
                seq[seqLen++] = y;
            }
        }
    }
}

// Replace each sum[x] by the total of the marks in x's subtree.
// Must be called after build_tree.
void accumulate_sums()
{
    // Reverse BFS order handles every child before its parent; seq[0] is the root.
    for (int pos = seqLen - 1; pos > 0; pos--)
    {
        int x = seq[pos];
        sum[anc[x][0]] += sum[x];
    }
}

// Lowest common ancestor of u and v by binary lifting; needs build_tree first.
int lca(int u, int v)
{
    if (dep[u] > dep[v]) swap(u, v);
    // Lift v up to u's depth.
    int t = dep[v] - dep[u];
    for (int k = 0; k < MAXLOG; k++)
    {
        if (t & (1 << k)) v = anc[v][k];
    }
    if (u == v) return u;
    // Jump both while their ancestors differ. Levels above the depth are 0
    // on both sides, so they compare equal and are skipped.
    for (int k = MAXLOG - 1; k >= 0; k--)
    {
        if (anc[u][k] != anc[v][k])
        {
            u = anc[u][k];
            v = anc[v][k];
        }
    }
    return anc[u][0];
}

int main()
{
    while (scanf("%d%d", &n, &m) == 2)
    {
        // Fixed-size arrays hold at most MAXN - 1 nodes.
        if (n < 1 || n >= MAXN) break;

        tot = 0;
        // Only rows 0..n are read in this test case; clearing just those avoids
        // an ~8 MB memset per case.
        memset(anc, 0, sizeof(anc[0]) * (n + 1));
        memset(dep, 0, sizeof(dep[0]) * (n + 1));
        for (int i = 1; i <= n; i++)
        {
            head[i] = 0;
            sum[i] = 0;
        }

        for (int i = 1; i < n; i++)
        {
            int u, v;
            if (scanf("%d%d", &u, &v) != 2) return 0; // truncated input
            addedge(u, v);
            addedge(v, u);
        }
        build_tree(1);

        for (int i = 1; i <= m; i++)
        {
            int u, v;
            if (scanf("%d%d", &u, &v) != 2) return 0; // truncated input
            // Edge difference: +1 at both endpoints, -2 at their LCA.
            sum[u]++;
            sum[v]++;
            sum[lca(u, v)] -= 2;
        }
        accumulate_sums();

        // (n - 1) * m can reach ~1e10, which overflows int.
        long long ans = 0;
        for (int i = 2; i <= n; i++) // node i stands for tree edge (i, parent(i))
        {
            if (sum[i] == 0) ans += m;
            if (sum[i] == 1) ans++;
        }
        printf("%lld\n", ans);
    }
    return 0;
}