Weak Pair
Time Limit: 4000/2000 MS (Java/Others) Memory Limit: 262144/262144 K (Java/Others)
Problem Description
You are given a rooted tree of N nodes, labeled from 1 to N. The i-th node is assigned a non-negative value a_i. An ordered pair of nodes (u, v) is said to be weak if
(1) u is an ancestor of v (Note: in this problem a node u is not considered an ancestor of itself);
(2) a_u × a_v ≤ k.
Can you find the number of weak pairs in the tree?
Input
There are multiple cases in the data set.
The first line of input contains an integer T denoting the number of test cases.
For each case, the first line contains two space-separated integers, N and k, respectively.
The second line contains N space-separated integers, denoting a_1 to a_N.
Each of the next N − 1 lines contains two space-separated integers u and v describing an edge, where node u is the parent of node v.
Constraints:
1 ≤ N ≤ 10^5
0 ≤ a_i ≤ 10^9
0 ≤ k ≤ 10^18
Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
Sample Input
1
2 3
1 2
1 2
Sample Output
1
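Idea: transform a_u × a_v ≤ k into a_v ≤ ⌊k / a_u⌋ (valid for a_u > 0; if a_u = 0 the product is 0, so every pair with that ancestor is weak). During a DFS, each node u counts its descendants v that satisfy this bound. The values of finished nodes are kept in a Fenwick tree (BIT) over coordinate-compressed values; u queries the number of values ≤ ⌊k / a_u⌋ once before visiting its subtree and once after, and the difference is exactly the number of qualifying descendants. u's own value is inserted only after its subtree is finished, so u is never paired with itself.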
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;

typedef long long LL;
const int MAXN = 100005;

int n, bit[MAXN + MAXN], deg[MAXN], len;
LL val[MAXN], k;
vector<int> arc[MAXN];   // arc[u]: children of u
LL res;
LL buf[MAXN + MAXN];     // values a_i and thresholds k / a_i, sorted for coordinate compression

// Fenwick tree: add x at 1-based position i
void add(int i, int x) {
    while (i < MAXN + MAXN) {
        bit[i] += x;
        i += (i & (-i));
    }
}

// Fenwick tree: sum of positions 1..i
int sum(int i) {
    int s = 0;
    while (i > 0) {
        s += bit[i];
        i -= (i & (-i));
    }
    return s;
}

// Note: recursion depth can reach N on a path-shaped tree.
void dfs(int u) {
    // Compressed index of the threshold k / a_u; if a_u == 0, every value qualifies.
    int id = (val[u] == 0) ? len : lower_bound(buf, buf + len, k / val[u]) - buf + 1;
    int pre = sum(id);    // qualifying values inserted before entering u's subtree
    for (int i = 0, size = arc[u].size(); i < size; i++) {
        dfs(arc[u][i]);
    }
    int post = sum(id);   // qualifying values after finishing u's subtree
    res += (post - pre);  // = descendants v of u with a_v <= k / a_u
    int index = lower_bound(buf, buf + len, val[u]) - buf + 1;
    add(index, 1);        // insert a_u only after its own subtree is done
}

int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        scanf("%d %lld", &n, &k);
        memset(bit, 0, sizeof(bit));
        memset(deg, 0, sizeof(deg));
        len = 0;
        res = 0;
        for (int i = 1; i <= n; i++) arc[i].clear();
        for (int i = 1; i <= n; i++) {
            scanf("%lld", &val[i]);
            buf[len++] = val[i];
            if (val[i] != 0) buf[len++] = k / val[i];  // avoid division by zero
        }
        for (int i = 0; i < n - 1; i++) {
            int u, v;
            scanf("%d %d", &u, &v);
            arc[u].push_back(v);
            deg[v]++;
        }
        sort(buf, buf + len);
        for (int i = 1; i <= n; i++) {  // the root is the only node without a parent
            if (deg[i] == 0) {
                dfs(i);
                break;
            }
        }
        printf("%lld\n", res);
    }
    return 0;
}