树上路径的f(u,v)=路径上所有点的乘积。
树上每个点的权值都是由给定的k个素数组合而成的,如果f(u,v)是立方数,那么就说明f(u,v)是可行的方案。
问有多少种可行的方案。
f(u,v)可是用状态压缩来表示,因为最多只有30个素数, 第i位表示第i个素数的幂,那么每一位的状态只有0,1,2因为3和0是等价的,所以用3进制状态来表示就行了。
其他代码就是裸的树分。
另外要注意的是,因为counts函数没有统计只有一个点的情况,所以需要另外统计。
1 #pragma warning(disable:4996) 2 #pragma comment(linker, "/STACK:1024000000,1024000000") 3 #include <stdio.h> 4 #include <string.h> 5 #include <time.h> 6 #include <math.h> 7 #include <map> 8 #include <set> 9 #include <queue> 10 #include <stack> 11 #include <vector> 12 #include <bitset> 13 #include <algorithm> 14 #include <iostream> 15 #include <string> 16 #include <functional> 17 #include <unordered_map> 18 const int INF = 1 << 30; 19 typedef __int64 LL; 20 /* 21 用三进制的每一位表示第i个素数的幂 22 如果幂都是0,那么说明是立方 23 */ 24 const int N = 50000 + 10; 25 std::vector<int> g[N]; 26 std::unordered_map<LL, int> mp; 27 struct Node 28 { 29 int sta[33]; 30 }node[N]; 31 LL prime[33]; 32 std::vector<Node> dist; 33 int n, k; 34 int size[N], vis[N], total, root, mins; 35 LL _3bit[33]; 36 void init() 37 { 38 _3bit[0] = 1; 39 for (int i = 1;i <= 32;++i) 40 _3bit[i] = _3bit[i - 1] * 3; 41 } 42 void getRoot(int u, int fa) 43 { 44 int maxs = 0; 45 size[u] = 1; 46 for (int i = 0;i < g[u].size();++i) 47 { 48 int v = g[u][i]; 49 if (v == fa || vis[v]) continue; 50 getRoot(v, u); 51 size[u] += size[v]; 52 maxs = std::max(maxs, size[v]); 53 } 54 maxs = std::max(maxs, total - size[u]); 55 if (mins > maxs) 56 { 57 mins = maxs; 58 root = u; 59 } 60 } 61 void getDis(int u, int fa, Node d) 62 { 63 dist.push_back(d); 64 for (int i = 0;i < g[u].size();++i) 65 { 66 int v = g[u][i]; 67 if (v == fa || vis[v]) continue; 68 Node tmp; 69 for (int j = 0;j < k;++j) 70 tmp.sta[j] = (d.sta[j] + node[v].sta[j]) % 3; 71 getDis(v, u, tmp); 72 } 73 } 74 LL counts(int u)//计算经过u点的路径 75 { 76 mp.clear(); 77 mp[0] = 1; 78 LL ret = 0; 79 for (int i = 0;i < g[u].size();++i) 80 { 81 int v = g[u][i]; 82 if (vis[v]) continue; 83 dist.clear(); 84 getDis(v, u, node[v]); 85 for (int j = 0;j < dist.size();++j) 86 { 87 LL sta = 0; 88 for (int z = 0;z < k;++z) 89 { 90 sta += (3 - (node[u].sta[z] + dist[j].sta[z]) % 3) % 3 * _3bit[z]; 91 } 92 ret += mp[sta]; 93 } 94 for (int j = 0;j < dist.size();++j) 95 { 96 LL sta = 0; 97 for (int z = 0;z < k;++z) 98 sta += dist[j].sta[z] * _3bit[z]; 99 mp[sta]++; 100 } 101 } 102 return ret; 103 } 104 LL ans; 105 void go(int u) 106 { 107 vis[u] = true; 108 ans += counts(u); 109 for (int i = 0;i < g[u].size(); ++i) 110 { 111 int v = g[u][i]; 112 if (vis[v]) continue; 113 total = size[v]; 114 mins = INF; 115 getRoot(v, u); 116 go(root); 117 } 118 119 } 120 int main() 121 { 122 int u, v; 123 LL x; 124 init(); 125 while (scanf("%d%d", &n, &k) != EOF) 126 { 127 for (int i = 0;i < k;++i) 128 scanf("%I64d", &prime[i]); 129 ans = 0; 130 for (int i = 1;i <= n;++i) 131 { 132 g[i].clear(); 133 vis[i] = 0; 134 scanf("%I64d", &x); 135 memset(node[i].sta, 0, sizeof(node[i].sta)); 136 int tmp = 0; 137 for (int j = 0;j <k;++j) 138 { 139 140 while (x%prime[j] == 0 && x) 141 { 142 node[i].sta[j]++; 143 x /= prime[j]; 144 } 145 node[i].sta[j] %= 3; 146 if (node[i].sta[j] != 0)tmp++; 147 } 148 if (tmp == 0)//统计只有一个点的 149 ans++; 150 } 151 for (int i = 1;i < n;++i) 152 { 153 scanf("%d%d", &u, &v); 154 g[u].push_back(v); 155 g[v].push_back(u); 156 } 157 total = n; 158 mins = INF; 159 getRoot(1, -1); 160 go(root); 161 printf("%I64d ", ans); 162 } 163 return 0; 164 }