题:https://ac.nowcoder.com/acm/contest/7831/H
题意:给定n个点的树,每个节点都有颜色;
- 询问[Q,y]:求把所有y颜色的节点联通起来用的最少的边数。
- 更新[U,x,y]:将x节点的颜色改为y;
分析:
- 对于询问,我们可以假象为有俩个点作为总边,剩余颜色的点就连上这条总边;
- 关键在于如何确定这条总边,让其他相同颜色的点x连上这条总边的点y,dis(x,y)不与其他“边”重合;
- 这条总边就可以用最小的dfs序的点和最大的dfs序的点连接的边,就可以保证statement.2;
- 关于更新答案可以发现,插入一个点x的影响只与和x的dfs序相邻的点(l和r)有关,+dis(l,x)+dis(r,x)-dis(l,r)这部分就可以用重载排序的set来维护;
- 删除就上面表达式去反号,简单看出统计的答案的“分支”是2倍的,而总边只有1倍,那么查询就是(ans[col[y]]+总边长度)/2;
#include<bits/stdc++.h> using namespace std; #define pb push_back #define MP make_pair #define lson root<<1,l,midd #define rson root<<1|1,midd+1,r typedef long long ll; const int mod=1e9+7; const int M=1e5+5; const int inf=0x3f3f3f3f; const ll INF=1e18; vector<int>g[M]; int tot; int dfn[M],f[M],deep[M],sz[M],son[M],top[M],ans[M],col[M]; struct cmp{ bool operator() (const int &x,const int &y)const{ return dfn[x]<dfn[y]; } }; set<int,cmp>st[M]; void dfs1(int u){///cout<<u<<"!!"<<endl; deep[u]=deep[f[u]]+1; sz[u]=1; for(auto v:g[u]){ if(v!=f[u]){ f[v]=u; dfs1(v); sz[u]+=sz[v]; if(!son[u]||sz[v]>sz[son[u]]) son[u]=v; } } } void dfs2(int u,int tp){ dfn[u]=++tot; top[u]=tp; if(son[u]) dfs2(son[u],tp); for(auto v:g[u]){ if(v!=f[u]&&v!=son[u]) dfs2(v,v); } } int LCA(int u,int v){ while(top[u]!=top[v]){ if(deep[top[u]]<deep[top[v]]) swap(u,v); u=f[top[u]]; } if(deep[u]>deep[v]) swap(u,v); return u; } int dis(int u,int v){ return deep[u]+deep[v]-2*deep[LCA(u,v)]; } void add(int x){ st[col[x]].insert(x); auto it=st[col[x]].find(x); int l=0,r=0; ++it; if(it!=st[col[x]].end()){ r=*it; } it--; if(it!=st[col[x]].begin()){ it--; l=*it; } if(l) ans[col[x]]+=dis(l,x); if(r) ans[col[x]]+=dis(r,x); if(l&&r) ans[col[x]]-=dis(l,r); } void del(int x){ auto it=st[col[x]].find(x); it++; int l=0,r=0; if(it!=st[col[x]].end()){ r=*it; } it--; if(it!=st[col[x]].begin()){ it--; l=*it; } if(l) ans[col[x]]-=dis(l,x); if(r) ans[col[x]]-=dis(r,x); if(l&&r) ans[col[x]]+=dis(l,r); st[col[x]].erase(x); } char s[2]; int main(){ int n; scanf("%d",&n); for(int u,v,i=1;i<n;i++){ scanf("%d%d",&u,&v); g[u].pb(v); g[v].pb(u); } dfs1(1); dfs2(1,1); for(int i=1;i<=n;i++){ scanf("%d",&col[i]); add(i); } int m; scanf("%d",&m); while(m--){ scanf("%s",s); int x,y; if(s[0]=='U'){ scanf("%d%d",&x,&y); del(x); col[x]=y; add(x); } else{ scanf("%d",&y); if(st[y].size()==0) puts("-1"); else printf("%d ",(ans[y]+dis(*st[y].begin(),*st[y].rbegin()))/2); } } return 0; }