题解:
因为这是一棵树,所以我们可以假设根是1号节点
我们设S(x)表示1号节点到x号节点的路径上第t种花的数量,v(x)表示x号节点上是第几种花
那么,x号节点到y号节点的路径上第t种花的数量为S(x)+S(y)−S(lca)+[v(lca)==t]
所以我们就只需要求出S(x)就可以了
咋求呢?线段树!
所以我们可以对每一种花开一棵线段树,由于内存的限制,我们需要使用动态开点
问题是如何修改呢?
我们可以对于每一个节点,记录下这个节点的DFS序,就可以进行修改了!
修改时只需要把这个点DFS序的起始到结束中间的所有数都+1就可以了
#include<iostream> #include<cstdio> #include<map> using namespace std; #define l(x) t[x].l #define r(x) t[x].r #define a(x) t[x].add #define v(x) t[x].val struct Segment_Tree { int l,r,add,val; } t[10000010]; map<int,int> val; int n,q,ans,sum,size,T[100010],root[300010]; int tot,top,head[100010],to[200010],nxt[200010],dep[100010],fa[100010][20],dfn[200010],b[100010],e[100010]; void add(int u,int v) { nxt[++tot]=head[u],head[u]=tot,to[tot]=v; } void dfs(int x) { dfn[++top]=x; for(int i=1; i<=16; i++) if((1<<i)<=dep[x]) fa[x][i]=fa[fa[x][i-1]][i-1]; else break; for(int i=head[x]; i; i=nxt[i]) if(to[i]!=fa[x][0]) fa[to[i]][0]=x,dep[to[i]]=dep[x]+1,dfs(to[i]); dfn[++top]=x; } int lca(int u,int v) { if(dep[u]<dep[v]) swap(u,v); int temp=dep[u]-dep[v]; for(int i=0; i<=16; i++) if(temp&(1<<i)) u=fa[u][i]; for(int i=16; i>=0; i--) if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i]; return u==v?u:fa[u][0]; } void spread(int p,int l,int r) { if(!a(p)||l==r) return; int temp=a(p); if(!l(p)) l(p)=++size; if(!r(p)) r(p)=++size; a(p)=0,v(l(p))+=temp,a(l(p))+=temp,v(r(p))+=temp,a(r(p))+=temp; } void change(int &p,int l,int r,int x,int y,int d) { if(!p) p=++size; spread(p,l,r); if(x==l&&y==r) { v(p)+=d,a(p)+=d; return; } int mid=(l+r)>>1; if(x<=mid) change(l(p),l,mid,x,min(y,mid),d); if(y>mid) change(r(p),mid+1,r,max(x,mid+1),y,d); } int ask(int p,int l,int r,int x) { if(!p) return 0; spread(p,l,r); if(l==r) return v(p); int mid=(l+r)>>1; if(x<=mid) return ask(l(p),l,mid,x); else return ask(r(p),mid+1,r,x); } int main() { scanf("%d%d",&n,&q); for(int i=1; i<=n; i++) { scanf("%d",&T[i]); if(!val[T[i]]) val[T[i]]=++sum; T[i]=val[T[i]]; } for(int i=1,u,v; i<n; i++) scanf("%d%d",&u,&v),add(u,v),add(v,u); dfs(1); for(int i=1; i<=top; i++) if(!b[dfn[i]]) b[dfn[i]]=i; else e[dfn[i]]=i; for(int i=1; i<=n; i++) change(root[T[i]],1,top,b[i],e[i],1); for(int i=1,x,y,z; i<=q; i++) { char op[3]; scanf("%s %d%d",op,&x,&y),x^=ans,y^=ans; if(op[0]=='Q') { scanf("%d",&z),z^=ans; int LCA=lca(x,y); if(!val[z]) { ans=0,printf("0 "); continue; } z=val[z],ans=ask(root[z],1,top,b[x])+ask(root[z],1,top,b[y])-2*ask(root[z],1,top,b[LCA]); if(T[LCA]==z) ans++; printf("%d ",ans); } if(op[0]=='C') { if(!val[y]) val[y]=++sum; y=val[y],change(root[T[x]],1,top,b[x],e[x],-1),change(root[y],1,top,b[x],e[x],1),T[x]=y; } } }