虚树小记
简明地讲解了虚树的构建过程。
引入
比如给你一棵树,然后有多次询问,每次询问都给出树上若干个点 \(|m_i|\) 个(称为关键点),然后对这些点的性质作查询(比如求它们的最远点对什么的)
注意到如果每次询问都扫描整棵树时间复杂度会爆炸,但是如果能够保证总查询扫描的点与 \(\sum |m_i|\) 为同一个数量级就没问题了,故考虑构建虚树。
构建
考虑使用单调栈构建——核心是:使用单调栈维护一条树链。
-
首先,对关键点按照 \(dfs\) 序(\(dfn\))排序。
-
方便起见,先将点 \(1\) 入栈。
-
若当前点与栈顶节点 \(LCA\)(记为 \(anc\)) 为栈顶节点,说明二者在同一条链,故将当前点入栈即可。
-
否则,一直执行出栈操作。直到栈顶下一节点的 \(dfn\) 小于等于 \(anc\)。(注意在出栈的过程中弹出的点与栈顶节点建边)
-
如果等于,说明 \(anc\) 在栈中,对 \(anc\) 与栈顶建边然后弹出栈顶。
-
否则,\(anc\) 与栈顶建边并且将栈顶修改为 \(anc\)(等价于将栈顶弹出然后将 \(anc\) 入栈)
-
-
最后,将还在栈中的节点(构成一条树链)连边即可。
可以结合代码理解:
void build_tree(){
sort(pt+1, pt+1+m, cmp);
cur=1, stk[1]=1, g[1].clear();
rep(i,1,m){
int u=pt[i];
if(u==1) continue;
int anc=lca(u, stk[cur]);
if(anc!=stk[cur]){
while(id[stk[cur-1]]>id[anc]) Add(stk[cur-1], stk[cur]), cur--;
if(anc==stk[cur-1]) Add(anc, stk[cur]), cur--;
else g[anc].clear(), Add(anc, stk[cur]), stk[cur]=anc;
}
stk[++cur]=u, g[u].clear();
}
rep(i,1,cur-1) Add(stk[i], stk[i+1]);
}
性质
为了保留原树的性质,虚树自然需要保证点之间的祖孙后代关系不变。
可以发现,构建的过程中会引入一些非关键点(\(LCA\)),但注意到每次加入点最多引入一个 \(anc\),所以加入的点规模仍然是 \(O(m)\) 级别的。
例题
https://codeforces.com/gym/102784/problem/K
分析
求每种颜色的点最远点对长度。
事实上就是求对应的虚树的直径,但不能够将非关键点纳入贡献中。
因此可以进行两次 \(dfs\),第一次任意挑一个关键点并找到离它最远的关键点 \(U\),第二次从 \(U\) 出发找到离它最远的关键点的距离即可。
代码:
// Problem: K. Territorial Tarantulas
// Contest: Codeforces - UTPC Contest 10-23-20 Div. 1
// URL: https://codeforces.com/gym/102784/problem/K
// Memory Limit: 256 MB
// Time Limit: 3000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#pragma GCC optimize("O3")
#include<bits/stdc++.h>
using namespace std;
#define endl '\n'
#define debug(x) cerr << #x << ": " << (x) << endl
#define pb push_back
#define eb emplace_back
#define set0(a) memset(a,0,sizeof(a))
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define dwn(i,a,b) for(int i=(a);i>=(b);i--)
#define ceil(a,b) (a+(b-1))/(b)
#define all(x) (x).begin(), (x).end()
#define INF 0x3f3f3f3f
#define ll_INF 0x7f7f7f7f7f7f7f7f
#define x first
#define y second
using pii = pair<int, int>;
using pdd = pair<double, double>;
using vi = vector<int>;
using vvi = vector<vi>;
using vb = vector<bool>;
using vpii = vector<pii>;
using ll = long long;
using ull = unsigned long long;
inline void read(int &x){
int s=0; x=1;
char ch=getchar();
while(ch<'0' || ch>'9') {if(ch=='-')x=-1;ch=getchar();}
while(ch>='0' && ch<='9') s=(s<<3)+(s<<1)+ch-'0',ch=getchar();
x*=s;
}
const int N=4e5+5, M=N<<1;
struct Edge{
int to, next;
}e[M];
int h[N], tot;
void add(int u, int v){
e[tot].to=v, e[tot].next=h[u], h[u]=tot++;
}
int n, m;
int pt[N];
int fa[N], son[N], sz[N], dep[N];
int cnt, id[N], top[N];
void dfs1(int u, int father, int depth){
fa[u]=father, sz[u]=1, dep[u]=depth;
for(int i=h[u]; ~i; i=e[i].next){
int go=e[i].to;
if(go==father) continue;
dfs1(go, u, depth+1);
sz[u]+=sz[go];
if(sz[go]>sz[son[u]]) son[u]=go;
}
}
void dfs2(int u, int t){
id[u]=++cnt, top[u]=t;
if(!son[u]) return;
dfs2(son[u], t);
for(int i=h[u]; ~i; i=e[i].next){
int go=e[i].to;
if(go==fa[u] || go==son[u]) continue;
dfs2(go, go);
}
}
int lca(int u, int v){
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u, v);
u=fa[top[u]];
}
return dep[u]<dep[v]? u: v;
}
int dis(int u, int v){
return dep[u]+dep[v]-(dep[lca(u, v)]<<1);
}
vpii g[N];
int cur_col;
int cw[N];
void Add(int u, int v){
int w=dis(u, v);
g[u].pb({v, w}), g[v].pb({u, w});
}
int stk[N], cur;
bool cmp(int a, int b){
return id[a]<id[b];
}
void build_tree(){
sort(pt+1, pt+1+m, cmp);
cur=1, stk[1]=1, g[1].clear();
rep(i,1,m){
int u=pt[i];
if(u==1) continue;
int anc=lca(u, stk[cur]);
if(anc!=stk[cur]){
while(id[stk[cur-1]]>id[anc]) Add(stk[cur-1], stk[cur]), cur--;
if(anc==stk[cur-1]) Add(anc, stk[cur]), cur--;
else g[anc].clear(), Add(anc, stk[cur]), stk[cur]=anc;
}
stk[++cur]=u, g[u].clear();
}
rep(i,1,cur-1) Add(stk[i], stk[i+1]);
}
vi col[N];
int C;
int D[N];
void dfs(int u, int fa=-1, int dist=0){
D[u]=dist;
for(auto &[go, w]: g[u]){
if(go==fa) continue;
dfs(go, u, dist+w);
}
}
int main(){
memset(h, -1, sizeof h);
cin>>n>>C;
rep(i,1,n){
read(cw[i]);
col[cw[i]].pb(i);
}
rep(i,1,n-1){
int u, v; read(u), read(v);
add(u, v), add(v, u);
}
dfs1(1, -1, 1), dfs2(1, 1);
rep(c,1,C){
cur_col=c;
m=0;
for(auto i: col[c]) pt[++m]=i;
build_tree();
rep(i,1,m) D[pt[i]]=0;
dfs(pt[1], -1, 0);
int U, val=0;
rep(i,1,m){
int u=pt[i];
if(D[u]>=val && cw[u]==c) val=D[u], U=u;
}
dfs(U, -1, 0);
int res=0;
rep(i,1,m) res=max(res, D[pt[i]]);
cout<<res<<endl;
}
return 0;
}