考虑维护原树的lct,在上面dp,由于dp方程特殊,均为异或卷积或加法,计算中可以只使用fwt后的序列
v[w]表示联通子树的最浅点为w,且不选w的splay子树中的点
l[w]表示联通子树的最浅点在w的lct子树中,且选w的splay子树中极左点(w的splay子树为{w}+{u的splay子树,满足u==ch[w][0]||u==ch[w][1]})
r[w]表示联通子树的最浅点在w的lct子树中,且选w的splay子树中极右点
lr[w]表示联通子树的最浅点为w,且选w的splay子树中所有的点
s[w]表示联通子树的最浅点在w的lct子树中(w的lct子树为{w}+{u的lct子树,满足fa[u]==w})
ss[w]表示联通子树的最浅点在w的lct子树中,且不选w
时间复杂度O((n+q)mlogm),由于dp方程复杂以及lct自带的常数,总体常数较大
#include<bits/stdc++.h> #define fm(i) for(int i=0;i<m;++i) typedef int arr[131]; const int N=31007,P=10007; int n,m,v0[N],ivs[25007],*iv=ivs+P+10; arr x0[131],ans,tmp; std::vector<int>e[N]; struct node{ node*c[2],*f; arr v,vt,l,r,lr,s,ss; bool nrt(); void up(); void setrc(node*w); }ns[N],*nil=ns,*rt=ns+1; void cpy(int*a,int*b){memcpy(a,b,sizeof(int)*m);} bool node::nrt(){return this==f->c[0]||this==f->c[1];} void node::up(){ fm(i){ int vi=vt[i]?0:v[i]; int vc0lr=vi*c[0]->lr[i]%P; int vc0r=vi*c[0]->r[i]%P; l[i]=(vc0lr*c[1]->l[i]+c[0]->l[i])%P; r[i]=(vc0r*c[1]->lr[i]+c[1]->r[i])%P; lr[i]=vc0lr*c[1]->lr[i]%P; s[i]=(vc0r*c[1]->l[i]+ss[i])%P; } } void mul(int*v1,int*t,int*v2){fm(i)v2[i]?v1[i]=v1[i]*v2[i]%P:++t[i];} void div(int*v1,int*t,int*v2){fm(i)v2[i]?v1[i]=v1[i]*iv[v2[i]]%P:--t[i];} void node::setrc(node*w){ if(c[1]!=nil)mul(v,vt,c[1]->l); c[1]=w; if(c[1]!=nil)div(v,vt,c[1]->l); up(); } void rot(node*w){ node*f=w->f,*g=f->f; int d=w==f->c[1]; if(f->nrt())g->c[g->c[1]==f]=w; w->f=g; (f->c[d]=w->c[d^1])->f=f; (w->c[d^1]=f)->f=w; fm(i){ int x=f->ss[i],y=f->c[d]->s[i]; f->ss[i]=(x-w->s[i]+y)%P; w->s[i]=f->s[i]; w->ss[i]-=y; } f->up(); fm(i)w->ss[i]=(w->ss[i]+f->s[i])%P; } void sp(node*w){ if(!w->nrt())return; do{ node*f=w->f; if(f->nrt())rot((f->c[1]==w)==(f->f->c[1]==f)?f:w); rot(w); }while(w->nrt()); w->up(); } void acs(node*x){ rt=x; for(node*y=nil;x!=nil;sp(x),x->setrc(y),y=x,x=x->f); sp(rt); } void dfs(int w,int pa){ for(int i=0;i<e[w].size();++i){ int u=e[w][i]; if(u!=pa){ dfs(u,w); ns[u].f=ns+w; mul(ns[w].v,ns[w].vt,ns[u].l); fm(j)ns[w].ss[j]+=ns[u].s[j]; } } ns[w].up(); } void dwt(int*a){ for(int i=1;i<m;i<<=1){ for(int j=0;j<m;j+=i<<1){ int*b=a+j,*c=b+i; for(int k=0;k<i;++k){ int x=b[k],y=c[k]; b[k]=x+y; c[k]=x-y; } } } fm(i)a[i]%=P; } char buf[2000007],*ptr=buf; int _(){ int x=0; while(*ptr<48)++ptr; while(*ptr>47)x=x*10+*ptr++-48; return x; } int main(){ fread(buf,1,sizeof(buf),stdin); n=_();m=_(); for(int i=0;i<m;++i){ x0[i][i]=1; dwt(x0[i]); } iv[1-P]=iv[1]=1; for(int i=2;i<P;++i)iv[i-P]=iv[i]=(P-P/i)*iv[P%i]%P; for(int i=0;i<=n;++i)ns[i]=(node){nil,nil,nil}; cpy(nil->l,x0[0]); cpy(nil->r,x0[0]); cpy(nil->lr,x0[0]); for(int i=1;i<=n;++i){ v0[i]=_(); cpy(ns[i].v,x0[v0[i]]); } for(int i=1,a,b;i<n;++i){ a=_(),b=_(); e[a].push_back(b); e[b].push_back(a); } dfs(1,0); for(int q=_(),ed=1;q;--q){ if(_()==405033){ if(ed){ cpy(ans,rt->s); dwt(ans); ed=0; } printf("%d ",(ans[_()]*iv[m]%P+P)%P); }else{ ed=1; int w=_(),x=_(); acs(ns+w); int*v1=x0[v0[w]],*v2=x0[x]; fm(i)rt->v[i]=rt->v[i]*v1[i][iv]%P*v2[i]%P; fm(i)ans[i]=rt->vt[i]?0:rt->v[i]; dwt(ans); v0[w]=x; rt->up(); } } return 0; }
由于树形态不改变,也可以建一棵深度不超过logn+O(1)的静态lct,可以明显减小常数
#include<bits/stdc++.h> #define fm(i) for(int i=0;i<m;++i) typedef int arr[131]; const int N=31007,P=10007; int n,m,v0[N],ivs[25007],*iv=ivs+P+10; arr x0[131],ans,tmp; std::vector<int>e[N]; struct node{ node*c[2],*f; arr v,vt,l,r,lr,s,ss; bool isrt(); void up(); }ns[N],*nil=ns,*rt; bool node::isrt(){return this!=f->c[0]&&this!=f->c[1];} void cpy(int*a,int*b){memcpy(a,b,sizeof(int)*m);} void node::up(){ fm(i){ int vi=vt[i]?0:v[i]; int vc0lr=vi*c[0]->lr[i]%P; int vc0r=vi*c[0]->r[i]%P; l[i]=(vc0lr*c[1]->l[i]+c[0]->l[i])%P; r[i]=(vc0r*c[1]->lr[i]+c[1]->r[i])%P; lr[i]=vc0lr*c[1]->lr[i]%P; s[i]=(vc0r*c[1]->l[i]+ss[i])%P; } } void mul(int*v1,int*t,int*v2){fm(i)v2[i]?v1[i]=v1[i]*v2[i]%P:++t[i];} void div(int*v1,int*t,int*v2){fm(i)v2[i]?v1[i]=v1[i]*iv[v2[i]]%P:--t[i];} int fa[N],sz[N],son[N],dep[N],top[N],ws[N],wp; void f1(int w,int pa){ dep[w]=dep[fa[w]=pa]+(sz[w]=1); for(int i=0;i<e[w].size();++i){ int u=e[w][i]; if(u!=pa){ f1(u,w); sz[w]+=sz[u]; if(sz[u]>sz[son[w]])son[w]=u; } } ns[w].up(); } node*build(int L,int R,node*f){ if(L>R)return nil; int L0=L,R0=R; for(int M,s0=ws[L][sz]+ws[R+1][sz];L<R;ws[M=L+R+1>>1][sz]*2<s0?R=M-1:L=M); node*w=ns+ws[L]; w->c[0]=build(L0,L-1,w); w->c[1]=build(L+1,R0,w); w->f=f; w->up(); if(f!=nil){ fm(i)f->ss[i]+=w->s[i]; if(L0==1&&R0==wp)mul(f->v,f->vt,w->l); } return w; } void f2(int w,int tp){ top[w]=tp; for(int i=0;i<e[w].size();++i){ int u=e[w][i]; if(u!=fa[w]&&u!=son[w])f2(u,u); } if(son[w])f2(son[w],tp); else{ wp=0; for(int a=tp;a;a=son[a])ws[++wp]=a; ws[wp+1]=0; rt=build(1,wp,ns+fa[tp]); } } void dwt(int*a){ for(int i=1;i<m;i<<=1){ for(int j=0;j<m;j+=i<<1){ int*b=a+j,*c=b+i; for(int k=0;k<i;++k){ int x=b[k],y=c[k]; b[k]=x+y; c[k]=x-y; } } } fm(i)a[i]%=P; } char buf[2000007],*ptr=buf; int _(){ int x=0; while(*ptr<48)++ptr; while(*ptr>47)x=x*10+*ptr++-48; return x; } int main(){ fread(buf,1,sizeof(buf),stdin); n=_();m=_(); for(int i=0;i<m;++i){ x0[i][i]=1; dwt(x0[i]); } iv[1-P]=iv[1]=1; for(int i=2;i<P;++i)iv[i-P]=iv[i]=(P-P/i)*iv[P%i]%P; for(int i=0;i<=n;++i)ns[i]=(node){nil,nil,nil}; cpy(nil->l,x0[0]); cpy(nil->r,x0[0]); cpy(nil->lr,x0[0]); for(int i=1;i<=n;++i){ v0[i]=_(); cpy(ns[i].v,x0[v0[i]]); } for(int i=1,a,b;i<n;++i){ a=_(),b=_(); e[a].push_back(b); e[b].push_back(a); } f1(1,0);f2(1,1); for(int q=_(),ed=1;q;--q){ if(_()==405033){ if(ed){ cpy(ans,rt->s); dwt(ans); ed=0; } printf("%d ",(ans[_()]*iv[m]%P+P)%P); }else{ ed=1; int w=_(),x=_(),stp=0; node*_w=ns+w,*stk[37]; for(node*a=_w;a!=nil;stk[++stp]=a,a=a->f); for(int i=stp;i>1;--i){ int*v1=stk[i]->ss,*v2=stk[i-1]->s; fm(j)v1[j]-=v2[j]; if(stk[i-1]->isrt())div(stk[i]->v,stk[i]->vt,stk[i-1]->l); } int*v1=x0[v0[w]],*v2=x0[x]; fm(i)_w->v[i]=_w->v[i]*v1[i][iv]%P*v2[i]%P; _w->up(); for(int i=2;i<=stp;++i){ int*v1=stk[i]->ss,*v2=stk[i-1]->s; fm(j)v1[j]=(v1[j]+v2[j])%P; if(stk[i-1]->isrt())mul(stk[i]->v,stk[i]->vt,stk[i-1]->l); stk[i]->up(); } v0[w]=x; } } return 0; }