题意
给出一棵 $N$($N le 10^5$)个点的树,有点权和边权。回答 $q$($q le 10^5$) 组询问:
($u, r$):距离节点 $u$ 不超过 $r$ 的点中权值最大的点
输出点的编号,如有多解,输出最小编号。
Time Limit: 每个测试点 3s
做法
离线。树的点分治。
以树的重心为根,将无根树转化为有根树。
对于询问 ($u,r$),我们把「与 $u$ 的距离不超过 $r$ 的点」按「从 $u$ 到该点是否要经过根节点」分成两类。
问题化为
求从 $u$ 出发, 经根节点,移动距离不超过 $r$ 所能到达的点中权值最大的那个点的编号。
-
用
std::map<long long, int> opt
维护 <到根节点的距离,节点编号>(key-value pair)。
opt[d]
表示到当前分治的根节点距离为d
的所有点中最优的那个点。
按 key 从小到大的顺序,对于相邻两 key-value pair 用前一 key 的 value 更新后一 key 的 value 。复杂度 $O(nlog n)$($n$ 是树的节点数,下同)
-
对于询问 ($u,r$),设 $u$ 到根的距离为 $d_u$,以 $r-d_u$ 为参数,用
std::map::upper_bound()
查询,更新该询问的答案。复杂度 $O(sumlimits_{u ext{ in the tree}}|{(u,r)}| imes log n)$ 。
总复杂度为 $O((m+n)log^2 n)$ 。
Implementation
#include <bits/stdc++.h>
using namespace std;
const int N=1e5+5;
vector<pair<int,int>> g[N], q[N];
int w[N];
bool used[N];
int size[N];
int tot;
pair<int,int> centroid(int u, int f){
size[u]=1;
int ma=0;
pair<int,int> res={INT_MAX, 0};
for(auto e: g[u]){
int v=e.first;
if(v!=f && !used[v]){
res=min(res, centroid(v, u));
size[u]+=size[v];
ma=max(ma, size[v]);
}
}
ma=max(ma, tot-size[u]);
res=min(res, {ma, u});
return res;
}
int better(int u, int v){
return w[u]>w[v] || (w[u]==w[v] && u<v) ? u: v;
}
map<long long, int> opt;
void dfs(int u, int f, long long d){
auto it=opt.find(d);
if(it==opt.end())
opt[d]=u;
else it->second=better(u, it->second);
for(auto e: g[u]){
int v=e.first, w=e.second;
if(v!=f && !used[v])
dfs(v, u, d+w);
}
}
int res[N];
void upd(int u, int f, long long d){
for(auto query: q[u]){
int r=query.first, id=query.second;
auto it=opt.upper_bound(r-d);
if(it!=opt.begin()){
res[id]=better(res[id], (--it)->second); //error-prone
}
}
for(auto e: g[u]){
int v=e.first, w=e.second;
if(v!=f && !used[v])
upd(v, u, d+w);
}
}
void DC(int u){
int root=centroid(u, u).second;
opt.clear(); // error-prone
dfs(root, root, 0);
for(auto it=opt.begin(); ;){
int tmp=it->second;
if(++it!=opt.end())
it->second=better(it->second, tmp);
else break;
}
upd(root, root, 0);
used[root]=true;
int _tot = tot;
for(auto e: g[root]){
int v=e.first;
if(!used[v]){
if(size[v] < size[root]) tot = v;
else tot = _tot - size[root];
DC(v);
}
}
int main(){
int n;
scanf("%d", &n);
for(int i=1; i<=n; i++)
scanf("%d", w+i);
for(int i=1; i<n; i++){
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
g[u].push_back({v, w});
g[v].push_back({u, w});
}
int m;
scanf("%d", &m);
for(int i=0; i<m; i++){
int u, r;
scanf("%d%d", &u, &r);
q[u].push_back({r, i});
}
tot = n;
DC(1);
for(int i=0; i<m; i++)
printf("%d
", res[i]);
return 0;
}