Tree
POJ 1741　题意：统计树上有多少点对的距离不超过 k。
树分治模板题
// POJ 1741 "Tree": count unordered vertex pairs whose path length is <= k.
// Centroid decomposition with inclusion-exclusion: at centroid u, cal(u, 0)
// counts pairs in u's component whose depths from u sum to <= k; the pairs
// that lie inside one child subtree v are then removed with cal(v, w(u, v)),
// because their "path through u" is not a real tree path.
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 10;
const int inf = 0x3f3f3f3f;
int n, k;
int rt, snode, ans;   // current centroid, size of the current component, answer
// Named `sz` rather than `size`: with `using namespace std;` a global `size`
// is ambiguous with std::size (C++17) and the program fails to compile.
int sz[maxn], d[maxn], vis[maxn];
int dep[maxn];        // dep[0] = number of collected depths, dep[1..dep[0]] = the depths
int maxson[maxn];     // size of the largest piece left if this vertex were removed
struct Edge{
    int v, w, nxt;
    Edge(int v = 0, int w = 0, int nxt = 0) : v(v), w(w), nxt(nxt) {}
}e[maxn<<1];
int head[maxn], cnt;

// Reset the adjacency lists (head[] / nxt chains) for a new test case.
void init(){
    cnt = 0;
    memset(head, -1, sizeof head);
}
// Append the directed edge u -> v with weight w.
void add(int u, int v, int w){
    e[cnt] = Edge(v, w, head[u]);
    head[u] = cnt++;
}

// DFS over the unvisited component containing u (f = parent). Fills sz[] and
// maxson[], and keeps in rt the vertex with the smallest maxson (the centroid).
// Callers set rt = 0, maxson[0] = inf and snode = component size beforehand.
void getrt(int u, int f){
    sz[u] = 1;
    maxson[u] = 0;
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(v == f || vis[v]) continue;
        getrt(v, u);
        sz[u] += sz[v];
        maxson[u] = max(maxson[u], sz[v]);
    }
    maxson[u] = max(maxson[u], snode - sz[u]);
    if(maxson[u] < maxson[rt]) rt = u;
}

// Append d[x] of every unvisited vertex x reachable from u to dep[1..].
void getdep(int u, int f){
    dep[++dep[0]] = d[u];
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(v == f || vis[v]) continue;
        d[v] = d[u] + e[i].w;
        getdep(v, u);
    }
}

// Count pairs (x, y), x != y, in u's unvisited component whose depths sum to
// <= k, where u itself is given depth w.
int cal(int u, int w){
    d[u] = w;
    dep[0] = 0;
    getdep(u, 0);
    sort(dep + 1, dep + 1 + dep[0]);
    int sum = 0;
    int l = 1, r = dep[0];
    // Two pointers on the sorted depths: when dep[l] + dep[r] <= k,
    // dep[l] pairs with every index in (l, r].
    while(l < r){
        if(dep[l] + dep[r] <= k){
            sum += r - l;
            l++;
        }else r--;
    }
    return sum;
}

// Handle the component whose centroid is u, then recurse into each piece.
void solve(int u){
    vis[u] = 1;
    ans += cal(u, 0);
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(vis[v]) continue;
        ans -= cal(v, e[i].w);   // drop pairs that stay inside v's subtree
        rt = 0;
        // sz[v] comes from the previous getrt pass, which may not have been
        // rooted at u; an imprecise snode only changes which vertex gets
        // picked as the next centroid, never the count.
        snode = sz[v];
        getrt(v, u);
        solve(rt);
    }
}

int main(){
    // Input ends with "0 0". Also stop on EOF / malformed input: testing the
    // scanf result for truthiness would treat EOF (-1) as success and repeat
    // the previous case forever.
    while(scanf("%d %d", &n, &k) == 2 && (n || k)){
        init();
        memset(vis, 0, sizeof vis);
        int u, v, w;
        for(int i = 1; i < n; i++){
            scanf("%d %d %d", &u, &v, &w);
            add(u, v, w);
            add(v, u, w);
        }
        rt = ans = 0;
        snode = n;
        maxson[0] = inf;
        getrt(1, 0);
        solve(rt);
        printf("%d\n", ans);
    }
    return 0;
}
Distance in Tree（Codeforces 161D）
题意：统计树上有多少点对的距离恰好为 k。
点分治
// Codeforces 161D "Distance in Tree": count unordered vertex pairs at distance
// exactly k (every edge has length 1), using centroid decomposition that merges
// the centroid's child subtrees one at a time.
#include <bits/stdc++.h>
using namespace std;
const int maxn = 50010;
const int maxd = 510;            // depth buckets; requires k < maxd
const int inf = 0x3f3f3f3f;
int n, k;
// ct[d]: vertices at depth d among the centroid and the subtrees merged so far;
// temp[d]: vertices at depth d in the subtree currently being scanned.
int ct[maxd], temp[maxd];
int rt, snode;
long long ans;                   // 64-bit, like the tree-DP solution
// Named `sz` rather than `size`: with bits/stdc++.h and `using namespace std;`
// a global `size` is ambiguous with std::size (C++17) and fails to compile.
int sz[maxn];
int maxson[maxn], vis[maxn];

struct Edge{
    int v, nxt;
    Edge(int v = 0, int nxt = 0) : v(v), nxt(nxt) {}
}e[maxn<<1];
int head[maxn], cnt;

// Reset the adjacency lists for a new test case.
void init(){
    cnt = 0;
    memset(head, -1, sizeof head);
}
// Append the directed edge u -> v.
void add(int u, int v){
    e[cnt] = Edge(v, head[u]);
    head[u] = cnt++;
}

// DFS over the unvisited component containing u (f = parent). Fills sz[] and
// maxson[], keeping in rt the vertex with the smallest maxson (the centroid).
void getrt(int u, int f){
    sz[u] = 1;
    maxson[u] = 0;
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(vis[v] || v == f) continue;
        getrt(v, u);
        sz[u] += sz[v];
        maxson[u] = max(maxson[u], sz[v]);
    }
    maxson[u] = max(maxson[u], snode - sz[u]);
    if(maxson[u] < maxson[rt]) rt = u;
}

// Walk one child subtree; u is at depth d from the centroid.
void dfs(int u, int f, int d){
    if(d > k) return;            // deeper vertices cannot be part of a length-k pair
    ans += ct[k - d];            // partners at depth k - d already merged (or the centroid)
    ++temp[d];
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(vis[v] || v == f) continue;
        dfs(v, u, d + 1);
    }
}

// Count all pairs whose path passes through the centroid u.
void cal(int u){
    memset(ct, 0, sizeof ct);
    ct[0] = 1;                   // the centroid itself, at depth 0
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(vis[v]) continue;
        memset(temp, 0, sizeof temp);
        dfs(v, u, 1);
        // Merge only after the whole subtree was scanned, so two vertices of
        // the same subtree are never paired with each other.
        for(int j = 1; j <= k; j++) ct[j] += temp[j];
    }
}

// Find the centroid of u's component, count through it, and recurse.
void divide(int u){
    getrt(u, u);                 // f == u: no neighbour is skipped as the parent
    u = rt;
    vis[u] = 1;
    cal(u);
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(vis[v]) continue;
        rt = 0;
        snode = sz[v];
        divide(v);
    }
}

int main(){
    while(scanf("%d %d", &n, &k) == 2){
        int u, v;
        init();
        memset(vis, 0, sizeof vis);
        for(int i = 1; i < n; i++){
            scanf("%d %d", &u, &v);
            add(u, v);
            add(v, u);
        }
        rt = 0;
        ans = 0;
        snode = n;
        maxson[0] = inf;
        divide(1);
        printf("%lld\n", ans);
    }
    return 0;
}
树DP
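题解：（原文此处为外部题解链接，转载时链接已丢失）思路：dp[u][i] 表示与 u 距离恰为 i 的点数；dp1 先自底向上统计 u 子树内的数目，dp2 再换根补上子树外的部分；答案为 Σ dp[u][k] / 2（每个点对被两个端点各统计一次）。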
// Tree DP for Codeforces 161D: count unordered vertex pairs at distance k.
// dp[u][i] = number of vertices at distance exactly i from u.
// dp1 fills it with counts inside u's subtree (tree rooted at vertex 1);
// dp2 then reroots so that dp[u][i] covers the whole tree.
// Each pair is seen from both endpoints, so the answer is sum(dp[u][k]) / 2.
#include <bits/stdc++.h>
using namespace std;
const int maxn = 50010;
const int maxd = 510;        // requires k < maxd
int dp[maxn][maxd];          // about 100 MB of ints for these bounds
struct Edge{
    int v, nxt;
    Edge(int v = 0, int nxt = 0) : v(v), nxt(nxt){}
}e[maxn<<1];
int head[maxn], cnt;

// Reset the adjacency lists for a new test case.
void init(){
    cnt = 0;
    memset(head, -1, sizeof head);
}
// Append the directed edge u -> v.
void add(int u, int v){
    e[cnt] = Edge(v, head[u]);
    head[u] = cnt++;
}

int n, k;

// Bottom-up pass: dp[u][j] = vertices in u's subtree at distance j from u.
void dp1(int u, int f){
    dp[u][0] = 1;
    for(int j = 1; j <= k; j++) dp[u][j] = 0;
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(v == f) continue;
        dp1(v, u);
        // a vertex at distance j-1 below v is at distance j from u
        for(int j = 1; j <= k; j++) dp[u][j] += dp[v][j - 1];
    }
}

// Top-down rerooting pass; dp[u] already covers the whole tree on entry.
void dp2(int u, int f){
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(v == f) continue;
        // Vertices at distance j from v reached through u = vertices at
        // distance j-1 from u, minus those inside v's subtree (distance j-2
        // from v). Iterating j downward keeps dp[v][j-2] at its
        // subtree-only value when it is read.
        for(int j = k; j >= 1; j--){
            dp[v][j] += dp[u][j - 1];
            if(j > 1) dp[v][j] -= dp[v][j - 2];
        }
        dp2(v, u);
    }
}

int main(){
    while(scanf("%d %d", &n, &k) == 2){
        init();
        int u, v;
        for(int i = 1; i < n; i++){
            scanf("%d %d", &u, &v);
            add(u, v); add(v, u);
        }
        dp1(1, 0);
        dp2(1, 0);
        long long ans = 0;
        for(int i = 1; i <= n; i++){
            ans += dp[i][k];
        }
        printf("%lld\n", ans / 2);
    }
    return 0;
}
聪聪可可
HYSBZ 2152　题意：统计树上距离为 3 的倍数的点对所占的比例（按有序点对计算，允许两点相同，即满足条件的有序点对数 / n²），以最简分数输出。
两种写法：容斥式点分治（先统计经过重心的全部点对，再减去落在同一棵子树内的点对），以及逐棵子树合并计数的点分治。
// HYSBZ 2152: count ordered pairs (x, y), x == y allowed, whose path length is
// a multiple of 3, and print that count / n^2 as a reduced fraction.
// Inclusion-exclusion centroid decomposition on distances taken mod 3.
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 10;
const int inf = 0x3f3f3f3f;
int n;
int rt, snode, ans;   // current centroid, size of the current component, answer
// Named `sz` rather than `size`: with `using namespace std;` a global `size`
// is ambiguous with std::size (C++17) and the program fails to compile.
int sz[maxn], d[maxn], vis[maxn];
int dep[3];           // dep[r] = collected vertices whose depth % 3 == r
int maxson[maxn];
struct Edge{
    int v, w, nxt;
    Edge(int v = 0, int w = 0, int nxt = 0) : v(v), w(w), nxt(nxt) {}
}e[maxn<<1];
int head[maxn], cnt;

// Reset the adjacency lists for a new test case.
void init(){
    cnt = 0;
    memset(head, -1, sizeof head);
}
// Append the directed edge u -> v with weight w (already reduced mod 3).
void add(int u, int v, int w){
    e[cnt] = Edge(v, w, head[u]);
    head[u] = cnt++;
}

// DFS over the unvisited component containing u (f = parent). Fills sz[] and
// maxson[], keeping in rt the vertex with the smallest maxson (the centroid).
void getrt(int u, int f){
    sz[u] = 1;
    maxson[u] = 0;
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(v == f || vis[v]) continue;
        getrt(v, u);
        sz[u] += sz[v];
        maxson[u] = max(maxson[u], sz[v]);
    }
    maxson[u] = max(maxson[u], snode - sz[u]);
    if(maxson[u] < maxson[rt]) rt = u;
}

// Tally depth residues (mod 3) of all unvisited vertices reachable from u.
void getdep(int u, int f){
    dep[d[u]]++;
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(v == f || vis[v]) continue;
        d[v] = (d[u] + e[i].w) % 3;
        getdep(v, u);
    }
}

// Ordered pairs (x, y) in u's unvisited component, x == y included, whose
// depths (u itself at depth w) sum to 0 mod 3.
int cal(int u, int w){
    dep[0] = dep[1] = dep[2] = 0;
    d[u] = w;
    getdep(u, 0);
    // residue 0 with 0 (any order, incl. x == y), plus 1 with 2 in both orders
    return dep[0] * dep[0] + dep[1] * dep[2] * 2;
}

// Handle the component whose centroid is u, then recurse into each piece.
void solve(int u){
    vis[u] = 1;
    ans += cal(u, 0);
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(vis[v]) continue;
        ans -= cal(v, e[i].w);   // drop pairs that stay inside v's subtree
        rt = 0;
        snode = sz[v];
        getrt(v, u);
        solve(rt);
    }
}

int main(){
    while(scanf("%d", &n) == 1){
        init();
        memset(vis, 0, sizeof vis);
        int u, v, w;
        for(int i = 1; i < n; i++){
            scanf("%d %d %d", &u, &v, &w);
            w %= 3;
            add(u, v, w);
            add(v, u, w);
        }
        rt = ans = 0;
        snode = n;
        maxson[0] = inf;
        getrt(1, 0);
        solve(rt);
        int total = n * n;       // all ordered pairs
        int g = __gcd(total, ans);
        printf("%d/%d\n", ans / g, total / g);
    }
    return 0;
}
// HYSBZ 2152, second approach: centroid decomposition that merges the
// centroid's child subtrees one at a time and counts unordered pairs of
// distinct vertices whose distance is a multiple of k (= 3). The ordered-pair
// count, x == y included, is then 2 * that + n.
#include <bits/stdc++.h>
using namespace std;
const int maxn = 20010;
const int maxd = 4;
const int inf = 0x3f3f3f3f;
int n, k;           // k is the modulus (3), set in main
// Keep these tiny (only k residues are used): they are memset once per
// centroid and once per child subtree, so a large size makes the repeated
// memsets dominate the running time (TLE).
int ct[maxd], temp[maxd];
int rt, snode, ans;
// Named `sz` rather than `size`: with bits/stdc++.h and `using namespace std;`
// a global `size` is ambiguous with std::size (C++17) and fails to compile.
int sz[maxn];
int maxson[maxn], vis[maxn];

struct Edge{
    int v, w, nxt;
    Edge(int v = 0, int w = 0, int nxt = 0) : v(v), w(w), nxt(nxt) {}
}e[maxn<<1];
int head[maxn], cnt;

// Reset the adjacency lists for a new test case.
void init(){
    cnt = 0;
    memset(head, -1, sizeof head);
}
// Append the directed edge u -> v with weight w (already reduced mod k).
void add(int u, int v, int w){
    e[cnt] = Edge(v, w, head[u]);
    head[u] = cnt++;
}

// DFS over the unvisited component containing u (f = parent). Fills sz[] and
// maxson[], keeping in rt the vertex with the smallest maxson (the centroid).
void getrt(int u, int f){
    sz[u] = 1;
    maxson[u] = 0;
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(vis[v] || v == f) continue;
        getrt(v, u);
        sz[u] += sz[v];
        maxson[u] = max(maxson[u], sz[v]);
    }
    maxson[u] = max(maxson[u], snode - sz[u]);
    if(maxson[u] < maxson[rt]) rt = u;
}

// Walk one child subtree; d = depth of u from the centroid, mod k.
void dfs(int u, int f, int d){
    // A partner must have residue (k - d) % k; residue 0 pairs with residue 0.
    ans += ct[(k - d) % k];
    ++temp[d];
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(vis[v] || v == f) continue;
        dfs(v, u, (d + e[i].w) % k);
    }
}

// Count all pairs whose path passes through the centroid u.
void cal(int u){
    memset(ct, 0, sizeof ct);
    ct[0] = 1;                   // the centroid itself, residue 0
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(vis[v]) continue;
        memset(temp, 0, sizeof temp);
        dfs(v, u, e[i].w);
        // Merge only after the subtree is fully scanned, so two vertices of
        // the same subtree are never paired with each other.
        for(int j = 0; j < k; j++) ct[j] += temp[j];
    }
}

// Find the centroid of u's component, count through it, and recurse.
void divide(int u){
    getrt(u, u);                 // f == u: no neighbour is skipped as the parent
    u = rt;
    vis[u] = 1;
    cal(u);
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].v;
        if(vis[v]) continue;
        rt = 0;
        snode = sz[v];
        divide(v);
    }
}

int main(){
    k = 3;
    while(scanf("%d", &n) == 1){
        int u, v, w;
        init();
        memset(vis, 0, sizeof vis);
        for(int i = 1; i < n; i++){
            scanf("%d %d %d", &u, &v, &w);
            w %= k;
            add(u, v, w);
            add(v, u, w);
        }
        rt = ans = 0;
        snode = n;
        maxson[0] = inf;
        divide(1);
        int good = ans * 2 + n;  // ordered pairs, plus the n pairs (x, x)
        int total = n * n;       // all ordered pairs
        int g = __gcd(good, total);
        printf("%d/%d\n", good / g, total / g);
    }
    return 0;
}