Description
Crystal家有一棵树。树上有(n)个节点,编号由(1)到(n)((1)号点是这棵树的根),两点之间距离为1当且仅当它们直接相连。每个点都有各自的权值,第(i)号节点的权值为(value_i)。Crystal现在指着编号为(x)的点问,在以点(x)为根的子树中,与点(x)距离大于等于(k)的所有点的点权和是多少。
Input Format
第(1)行两个整数(n,Q),分别表示树上点的个数和Crystal有(Q)个问题。
第(2)行,(n)个整数,分别表示(1)至(n)号点的点权。
接下来的(n - 1)行,每行两个整数(u,v),表示编号为(u)的点与编号为(v)的点直接相连。
接下来(Q)行,每行两个整数(x,k),表示询问在以点(x)为根的子树中,与点(x)距离大于等于为(k)的所有点的点权和是多少。
Output Format
(Q)行,每行一个整数,表示对第(i)个询问的回答。
Sample Input
5 3
1 1 1 1 1
1 2
1 3
3 4
4 5
1 3
1 2
1 1
Sample Output
1
2
4
Hints
对于(30\%)的数据,保证(n le 1000, k < 1, Q le 1000)。
对于(60\%)的数据,保证(n le 1000, k < 1000, Q le 1000)
对于(80\%)的数据,保证(n le 1000, k < 1000, Q le 1000000);
对于最后(20\%)的数据,保证(n le 50000, k < 100, Q le 1000000);
对于(100\%)的数据,保证所有输入数据均为非负整数,且在(int)范围内。
这题(O(NK))的做法不难想(用总的减去小于(K)的),现在假设(N,K)同级怎么做。
首先考虑离线做法,我们可以考虑按照询问最深的深度从小到大一层层加点,答案还是用总的减去小于(K)的。
再考虑在线所做法,我们可以先处理出dfs序和子树和,然后对于树的每层开一个vector,vector中记录该层点的编号,按dfs序排序。对于每个询问(x,k),我们只需要跳到(dep[x]+k)层的vector中,找到在(x)子树中的点,且一定是段连续区间,二分即可。现在只需要对该区间求子树和的和即可。
代码是(O(NK))的
#include<cstdio>
#include<cstdlib>
#include<iostream>
#include<cstring>
#include<vector>
using namespace std;
typedef long long ll;
#define maxn (50010)
int cnt = 1,side[maxn],toit[maxn*2],next[maxn*2],val[maxn],N,Q,mxk;
int tx[maxn*20],tk[maxn*20],num[20],len; ll sum[maxn]; vector <ll> res[maxn];
inline int read()
{
char ch; int f = 1,ret = 0;
do ch = getchar(); while (!(ch >= '0'&&ch <= '9')&&ch != '-');
if (ch == '-') f = -1,ch = getchar();
do ret = ret*10+ch-'0',ch = getchar(); while (ch >= '0'&&ch <= '9');
return ret*f;
}
inline void add(int a,int b) { next[++cnt] = side[a]; side[a] = cnt; toit[cnt] = b; }
inline void ins(int a,int b) { add(a,b); add(b,a); }
inline void dfs(int now,int fa)
{
for (int i = 0;i <= mxk;++i) res[now].push_back(val[now]);
sum[now] = val[now];
for (int i = side[now];i;i = next[i])
{
if (toit[i] == fa) continue;
dfs(toit[i],now);
sum[now] += sum[toit[i]];
for (int j = 0;j < mxk;++j)
res[now][j+1] += res[toit[i]][j];
}
}
inline void print(ll a)
{
do num[++len] = a%10,a /= 10; while (a);
while (len) putchar('0'+num[len--]);
puts("");
}
int main()
{
//freopen("a.in","r",stdin);
//freopen("a.out","w",stdout);
N = read(); Q = read();
for (int i = 1;i <= N;++i) val[i] = read();
for (int i = 1;i < N;++i) ins(read(),read());
for (int i = 1;i <= Q;++i) tx[i] = read(),tk[i] = read(),mxk = max(mxk,tk[i]);
dfs(1,0);
// print(123456LL);
// print(0LL);
// print(12LL);
for (int i = 1;i <= Q;++i)
{
if (!tk[i]) //cout << sum[tx[i]] << endl;
print(sum[tx[i]]);
else //cout << sum[tx[i]]-res[tx[i]][tk[i]-1] << endl;
print(sum[tx[i]]-res[tx[i]][tk[i]-1]);
}
//fclose(stdin); fclose(stdout);
return 0;
}