对每个结点建立两棵线段树,一棵记录该结点的子树下每种颜色对应的最小深度,另一棵记录子树下的每个深度有多少结点(每种颜色的结点只保留最浅的深度即可),自底而上令父节点继承子结点的线段树,如果合并两棵颜色线段树时发现某种颜色重复,则在深度线段树上把较深的深度对应的位置-1。
注意由于强制在线,深度线段树的合并以及更新都需要可持久化。
(ps:不能用map代替颜色线段树,会TLE~~)
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 typedef double db; 5 const int N=1e5+10; 6 int n,m,hd[N],ne,rt1[N],rt2[N],tot1,tot2,a[N],dep[N]; 7 struct E {int v,nxt;} e[N]; 8 struct D1 {int ls,rs,x;} tr1[N*40]; 9 struct D2 {int ls,rs,x;} tr2[N*80]; 10 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;} 11 #define mid ((l+r)>>1) 12 int newnode1() {tr1[++tot1]= {0,0,0}; return tot1;} 13 int newnode2() {tr2[++tot2]= {0,0,0}; return tot2;} 14 void upd1(int& u,int p,int x,int l=1,int r=n) { 15 if(!u)u=newnode1(); 16 if(l==r) {tr1[u].x=x; return;} 17 p<=mid?upd1(tr1[u].ls,p,x,l,mid):upd1(tr1[u].rs,p,x,mid+1,r); 18 } 19 void upd2(int& w,int u,int p,int x,int l=1,int r=n) { 20 w=newnode2(); 21 tr2[w].x=tr2[u].x+x; 22 if(l==r)return; 23 if(p<=mid)upd2(tr2[w].ls,tr2[u].ls,p,x,l,mid),tr2[w].rs=tr2[u].rs; 24 else upd2(tr2[w].rs,tr2[u].rs,p,x,mid+1,r),tr2[w].ls=tr2[u].ls; 25 } 26 void mg1(int uu,int& u,int v,int l=1,int r=n) { 27 if(!u) {u=v; return;} 28 if(!v)return; 29 if(l==r) { 30 if(!tr1[u].x)tr1[u].x=tr1[v].x; 31 else if(!tr1[v].x); 32 else { 33 int mx=max(tr1[u].x,tr1[v].x),mi=min(tr1[u].x,tr1[v].x); 34 upd2(rt2[uu],rt2[uu],mx,-1),tr1[u].x=mi; 35 } 36 return; 37 } 38 mg1(uu,tr1[u].ls,tr1[v].ls,l,mid); 39 mg1(uu,tr1[u].rs,tr1[v].rs,mid+1,r); 40 } 41 void mg2(int& w,int u,int v,int l=1,int r=n) { 42 if(!u) {w=v; return;} 43 if(!v) {w=u; return;} 44 w=newnode2(); 45 tr2[w].x=tr2[u].x+tr2[v].x; 46 if(l==r)return; 47 mg2(tr2[w].ls,tr2[u].ls,tr2[v].ls,l,mid); 48 mg2(tr2[w].rs,tr2[u].rs,tr2[v].rs,mid+1,r); 49 } 50 void dfs(int u,int d) { 51 rt1[u]=rt2[u]=0,dep[u]=d; 52 upd1(rt1[u],a[u],dep[u]),upd2(rt2[u],rt2[u],dep[u],1); 53 for(int i=hd[u]; ~i; i=e[i].nxt) { 54 int v=e[i].v; 55 dfs(v,d+1); 56 mg1(u,rt1[u],rt1[v]); 57 mg2(rt2[u],rt2[u],rt2[v]); 58 } 59 } 60 int qry(int u,int L,int R,int l=1,int r=n) { 61 if(l>=L&&r<=R)return tr2[u].x; 62 if(l>R||r<L)return 0; 63 return qry(tr2[u].ls,L,R,l,mid)+qry(tr2[u].rs,L,R,mid+1,r); 64 } 65 int main() { 66 int T; 67 for(scanf("%d",&T); T--;) { 68 memset(hd,-1,sizeof hd),ne=tot1=tot2=0; 69 scanf("%d%d",&n,&m); 70 for(int i=1; i<=n; ++i)scanf("%d",&a[i]); 71 for(int i=2; i<=n; ++i) { 72 int f; 73 scanf("%d",&f); 74 addedge(f,i); 75 } 76 dfs(1,1); 77 for(int ans=0; m--;) { 78 int u,d; 79 scanf("%d%d",&u,&d),u^=ans,d^=ans; 80 printf("%d ",ans=qry(rt2[u],1,min(dep[u]+d,n))); 81 } 82 } 83 return 0; 84 }