SP10707 COT2 - Count on a tree II
参考:树上莫队
树上莫队和普通的莫队差不多,只是把区间从普通的数组,转到欧拉序上(其实也就是括号序)
该问题求解的是 x,y 两个点之间的最短路径上,有多少个不同颜色的点
对于这个问题,分两种情况讨论
- \(lca(x,y)=x\ or\ y\),我们只需要记录\([b[x],b[y]]\),这段区间的贡献即可,(默认\(b[x]<b[y]\),不然交换)
- 若不成立,则记录\(e[x],b[y]\)的贡献
b[i] 表示 i 这个点 dfs 时开始的时间,e[i] 表示 i 这个点 dfs 结束的时间。
需要注意的几个点
- 块的大小为\(\frac{n}{\sqrt{m}}\)时最佳,本题为\(\frac{2n}{\sqrt{m}}\)
- 在写分块,求 belo 数组的时候,不要写假了,不然复杂度也是假的。
- 在比较两个点 x,y 的先后次序的时候,要用 b[x] 和 b[y] 来进行比较,不能用 dep 来直接进行比较。
- dfs 取根的时候,可以用 rand 来随机取根,这样可以防止出题人卡数据。
//Created by CAD
#include <bits/stdc++.h>
using namespace std;
const int maxn=1e5+5;
int b[maxn],e[maxn],id[maxn<<1],a[maxn];
int belo[maxn<<1];
vector<int> g[maxn];
struct query{
int l,r,x,y,LCA,id;
bool operator<(const query& q){
if(belo[l]!=belo[q.l]) return belo[l]<belo[q.l];
return (belo[l]&1)?r<q.r:r>q.r;
}
}q[maxn];
int fa[maxn][30],dep[maxn],lg[maxn];
int tin=0;
void dfs(int x,int o){
fa[x][0]=o,dep[x]=dep[o]+1;
for(int i=1;i<=lg[dep[x]];++i)
fa[x][i]=fa[fa[x][i-1]][i-1];
id[b[x]=++tin]=x;
for(int i:g[x])
if(i!=o) dfs(i,x);
id[e[x]=++tin]=x;
}
inline int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
while(dep[x]>dep[y]) x=fa[x][lg[dep[x]-dep[y]]-1];
if(x==y) return x;
for(int k=lg[dep[x]]-1;k>=0;--k){
if(fa[x][k]!=fa[y][k])
x=fa[x][k],y=fa[y][k];
}
return fa[x][0];
}
int now=0;
int ans[maxn],f[maxn],cnt[maxn];
inline void modify(int &x){
int num=a[x];
if(f[x]){
cnt[num]--;
if(!cnt[num]) now--;
}
else{
if(!cnt[num]) now++;
cnt[num]++;
}
f[x]^=1;
}
unordered_map<int,int> vis;
int main(){
int n,m;
scanf("%d%d",&n,&m);
int blo=double(2*n)/sqrt(m);
for(int i=1;i<=n;++i){
scanf("%d",a+i);
if(!vis.count(a[i]))
vis[a[i]]=vis.size();
a[i]=vis[a[i]];
belo[i]=(i-1)/blo+1;
belo[i*2]=(i*2-1)/blo+1;
lg[i]=lg[i-1]+(1<<lg[i-1]==i);
}
for(int i=1;i<n;++i){
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1,0);
for(int i=1;i<=m;++i){
int l,r;
scanf("%d%d",&l,&r);
if(b[l]>b[r])
swap(l,r);
int LCA=lca(l,r);
if(LCA==l) q[i]={b[l],b[r],l,r,LCA,i};
else q[i]={e[l],b[r],l,r,LCA,i};
}
sort(q+1,q+m+1);
int l=1,r=0;
for(int i=1;i<=m;++i){
int ql=q[i].l,qr=q[i].r,x=q[i].x,y=q[i].y,LCA=q[i].LCA;
while(l<ql) modify(id[l++]);
while(l>ql) modify(id[--l]);
while(r<qr) modify(id[++r]);
while(r>qr) modify(id[r--]);
int bj=0;
if(!f[LCA]) modify(LCA),bj=1;
ans[q[i].id]=now;
if(bj) modify(LCA);
}
for(int i=1;i<=m;++i)
printf("%d\n",ans[i]);
}