树上边差分。需要注意:根节点 1 上方并没有树边(它实际上并没有连向 0 号点),所以 cnt[1] == 0 的情况不能计入答案。
// Tree edge-difference ("边差分") solution.
//
// Input: n m, then n-1 tree edges, then m additional edges.
// For each additional edge (u, v) the difference array gets +1 at u, +1 at v
// and -2 at lca(u, v); after summing over subtrees, cnt[x] is the number of
// additional-edge paths that cover the tree edge (x, fa[x]).
// Each tree edge then contributes: +m if covered 0 times, +1 if covered
// exactly once, +0 otherwise. (Presumably this is the "cut one tree edge and
// one extra edge" counting problem, e.g. POJ 3417 -- confirm against source.)
// The root has no edge above it, so it is never counted.
#include <cstdio>
#include <utility>

namespace {

const int N = 1e5 + 7;  // capacity: vertices are numbered 1..n with n < N

int n, m;

// Forward-star adjacency list; every undirected tree edge is stored twice.
struct edge {
    int v, nxt;
} e[N << 1];
int tot, head[N];

int fa[N];      // parent in the rooted tree (0 for the root)
int sz[N];      // subtree size (not named `size`: that collides with std::size in C++17)
int heavy[N];   // child with the largest subtree, 0 for a leaf
int top[N];     // head of the heavy chain that contains the vertex
int dep[N];     // depth; the root has depth 1
int ord[N];     // vertices in DFS preorder, ord[0] is the root
int ord_len;    // number of vertices reached from the root
int cnt[N];     // edge-difference array, see the header comment
long long ans;  // (n-1)*m can exceed INT_MAX when n and m are both near 1e5

// Reads a decimal integer from stdin. A '-' seen while skipping non-digits
// makes the result negative. Returns 0 if EOF is reached before any digit
// (the original looped forever there).
int read_int() {
    int x = 0, sign = 1;
    int ch = std::getchar();  // int, not char, so EOF is distinguishable
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * sign;
}

// Adds the directed arc u -> v to the adjacency list.
void add(int u, int v) {
    e[++tot] = edge{v, head[u]};
    head[u] = tot;
}

// Roots the tree at `root` and fills fa, dep, ord, sz, heavy and top.
// Iterative (explicit stack) so that path-like trees with ~1e5 vertices do
// not overflow the call stack.
void build(int root) {
    static int stk[N];
    int sp = 0;
    ord_len = 0;
    fa[root] = 0;
    dep[root] = 1;
    stk[sp++] = root;
    while (sp > 0) {
        const int u = stk[--sp];
        ord[ord_len++] = u;  // parent is always recorded before its children
        sz[u] = 1;
        heavy[u] = 0;
        for (int i = head[u]; i; i = e[i].nxt) {
            const int v = e[i].v;
            if (v == fa[u]) continue;
            fa[v] = u;
            dep[v] = dep[u] + 1;
            stk[sp++] = v;
        }
    }
    // Reverse preorder visits children before parents: accumulate subtree
    // sizes and pick the heavy child (sz[0] == 0 serves as the "no child" size).
    for (int i = ord_len - 1; i > 0; --i) {
        const int u = ord[i];
        const int p = fa[u];
        sz[p] += sz[u];
        if (sz[u] > sz[heavy[p]]) heavy[p] = u;
    }
    // Preorder visits parents first: a heavy child continues its parent's chain.
    for (int i = 0; i < ord_len; ++i) {
        const int u = ord[i];
        top[u] = (u != root && heavy[fa[u]] == u) ? top[fa[u]] : u;
    }
}

// Lowest common ancestor via heavy-light chains; requires build() first.
int lca(int u, int v) {
    while (top[u] != top[v]) {
        // Lift whichever vertex sits on the chain whose head is deeper.
        if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
        u = fa[top[u]];
    }
    return dep[u] < dep[v] ? u : v;
}

// Turns the difference array into per-edge cover counts (children before
// parents) and adds each non-root edge's contribution to `ans`.
void count_cuts() {
    for (int i = ord_len - 1; i > 0; --i) {  // i == 0 is the root: no edge above it
        const int u = ord[i];
        if (cnt[u] == 0) ans += m;
        else if (cnt[u] == 1) ++ans;
        cnt[fa[u]] += cnt[u];
    }
}

// Reads the m additional edges, applies the differences and prints the answer.
void work() {
    build(1);
    for (int i = 1; i <= m; ++i) {
        const int u = read_int();
        const int v = read_int();
        ++cnt[u];
        ++cnt[v];
        cnt[lca(u, v)] -= 2;
    }
    count_cuts();
    std::printf("%lld\n", ans);
}

// Reads n, m and the n-1 tree edges.
void readdata() {
    n = read_int();
    m = read_int();
    for (int i = 1; i < n; ++i) {
        const int u = read_int();
        const int v = read_int();
        add(u, v);
        add(v, u);
    }
}

}  // namespace

int main() {
    // Local-debugging convenience: read from input.txt only when it exists.
    // An unconditional freopen would leave stdin unusable on a judge
    // where the file is absent.
    if (std::FILE* probe = std::fopen("input.txt", "r")) {
        std::fclose(probe);
        if (!std::freopen("input.txt", "r", stdin)) return 1;
    }
    readdata();
    work();
    return 0;
}