考虑暴力的dp,即用$f_{i,j}$表示以$i$为根的子树内,强制$i$必须选且异或为$j$的方案数,转移用FWT即可,求出该dp数组的时间复杂度为$o(nmlog_{2}m)$
由于是全局的方案数,再记录一个$sum_{i,j}=f_{i,j}+sum_{son}sum_{son,j}$,那么即求$sum_{1,x}$
令$f'_{i}=FWT(f_{i})$,则有$f'_{i,j}=a_{i,j}prod_{son}(f'_{son,j}+1)$(其中$a_{i,j}$指点$i$初始的dp数组(即$f_{i,v_{i}}=1$)FWT后的结果,加1是最后对$f_{son,0}$加1,FWT后即对所有位置加1)
根据FWT的分配律,可得$sum'_{i,j}=f'_{i,j}+sum_{son}sum'_{son,j}$,最后求出$sum_{1}=IFWT(sum'_{1})$即可
这样做单次询问复杂度降为$o(nm)$,但还是无法通过
注意到这样的每一个$j$除了在最后$IFWT$以外,都是独立的,因此考虑求某一个$sum'_{1,j}$,以下就省略数组的第二维(都是$j$)
对其树链剖分,记其重儿子为$hs_{k}$,先统计轻儿子的信息,即:
令$g_{k}=a_{k}prod_{son e hs_{k}}(f'_{son}+1)$那么就有$f'_{k}=g_{k}(f'_{hs_{k}}+1)$
令$h_{k}=sum_{son e hs_{k}}sum'_{son}$,则$sum'_{k}=h_{k}+sum'_{hs_{k}}+f_{k}$
考虑一条重链的维护,构建矩阵$A_{k}=[1 f'_{k} sum'_{k}]$,那么即$A_{k}=A_{hs_{k}}egin{bmatrix}1& g_{k}&h_{k}+g_{k}\0&g_{k}&g_{k}\0&0&1end{bmatrix}$
根据矩阵乘法的结合律,用线段树维护区间转移矩阵的乘积,再通过将该点直至重链尾部的转移矩阵全部乘起来(初始状态为$[1 0 0]$),即可求出每一个$k$的$f'_{k}$以及$sum'_{k}$(询问即$k=1$)
对于修改,会改变$k$的转移矩阵,即改变了$A_{top}$(重链顶端),将其求出后再根据轻链的转移修改到$g_{fa_{top}}$和$h_{fa_{top}}$,重复此过程即可,复杂度即为$o(3^{3}qlog^{2}n)$
(特别的,对于$g_{k}$需要存储其轻儿子中0的个数,来支持除法)
事实上,矩阵只需要维护右上角的4个位置(其余位置相乘后不变),复杂度降为$o(2^{2}qlog^{2}n)$,
(另外,矩阵乘法不具备交换律,因此线段树上要右边乘左边)
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 30005 4 #define M (1<<7) 5 #define mod 10007 6 #define L (k<<1) 7 #define R (L+1) 8 #define mid (l+r>>1) 9 struct ji{ 10 int nex,to; 11 }edge[N<<1]; 12 int E,n,m,x,y,head[N],v[N],fa[N],sz[N],son[N],id[N],top[N],las[N]; 13 char s[11]; 14 int ksm(int n,int m){ 15 int s=n,ans=1; 16 while (m){ 17 if (m&1)ans=ans*s%mod; 18 s=s*s%mod; 19 m>>=1; 20 } 21 return ans; 22 } 23 void add(int x,int y){ 24 edge[E].nex=head[x]; 25 edge[E].to=y; 26 head[x]=E++; 27 } 28 void dfs1(int k,int f){ 29 fa[k]=f; 30 sz[k]=1; 31 for(int i=head[k];i!=-1;i=edge[i].nex) 32 if (edge[i].to!=f){ 33 dfs1(edge[i].to,k); 34 sz[k]+=sz[edge[i].to]; 35 if ((!son[k])||(sz[son[k]]<sz[edge[i].to]))son[k]=edge[i].to; 36 } 37 } 38 void dfs2(int k,int fa,int t){ 39 id[k]=++x; 40 top[k]=t; 41 if (!son[k])las[k]=k; 42 else{ 43 dfs2(son[k],k,t); 44 las[k]=las[son[k]]; 45 } 46 for(int i=head[k];i!=-1;i=edge[i].nex){ 47 int x=edge[i].to; 48 if ((x!=fa)&&(x!=son[k]))dfs2(x,k,x); 49 } 50 } 51 struct num{ 52 int t,v; 53 num operator * (const num &a){ 54 return num{t+a.t,v*a.v%mod}; 55 } 56 num inv(){ 57 return num{-t,ksm(v,mod-2)}; 58 } 59 int value(){ 60 if (t)return 0; 61 return v; 62 } 63 }; 64 num turn(int k){ 65 k%=mod; 66 if (!k)return num{1,1}; 67 return num{0,k}; 68 } 69 struct mat{ 70 int a,b,c,d; 71 mat operator * (const mat &k)const{ 72 mat ans; 73 ans.a=(k.a+a*k.c)%mod; 74 ans.b=(b+k.b+a*k.d)%mod; 75 ans.c=c*k.c%mod; 76 ans.d=(c*k.d+d)%mod; 77 return ans; 78 } 79 }; 80 struct Seg{ 81 int h[N]; 82 num g[N]; 83 mat f[N<<2]; 84 void init(){ 85 f[0].c=1; 86 for(int i=1;i<=n;i++)g[i]=turn(1); 87 } 88 void update(int k,int l,int r,int x){ 89 if (l==r){ 90 f[k].a=f[k].c=f[k].d=g[x].value(); 91 f[k].b=(g[x].value()+h[x])%mod; 92 return; 93 } 94 if (x<=mid)update(L,l,mid,x); 95 else update(R,mid+1,r,x); 96 f[k]=f[R]*f[L]; 97 } 98 mat query(int k,int l,int r,int x,int y){ 99 if ((l>y)||(x>r))return f[0]; 100 if ((x<=l)&&(r<=y))return f[k]; 101 return query(R,mid+1,r,x,y)*query(L,l,mid,x,y); 102 } 103 mat get(int k){ 104 return query(1,1,n,id[k],id[las[k]]); 105 } 106 void update(int k,num x,int y){ 107 while (k){ 108 mat ans=get(top[k]); 109 g[id[k]]=g[id[k]]*x; 110 h[id[k]]+=y; 111 x=turn(ans.a+1).inv(),y=mod-ans.b; 112 update(1,1,n,id[k]); 113 ans=get(top[k]); 114 x=x*turn(ans.a+1),y=(y+ans.b)%mod; 115 k=fa[top[k]]; 116 } 117 } 118 }T[M]; 119 struct FWT{ 120 int a[M]; 121 void fwt(int p){ 122 for(int i=0;i<7;i++) 123 for(int j=0;j<M;j++) 124 if (j&(1<<i)){ 125 int x=a[j^(1<<i)],y=a[j]; 126 a[j^(1<<i)]=(x+y)%mod; 127 a[j]=(x+mod-y)%mod; 128 } 129 if (p){ 130 int s=ksm(M,mod-2); 131 for(int i=0;i<M;i++)a[i]=1LL*a[i]*s%mod; 132 } 133 } 134 }ans; 135 void update(int k,int p){ 136 for(int i=0;i<M;i++)ans.a[i]=(i==v[k]); 137 ans.fwt(0); 138 for(int i=0;i<M;i++) 139 if (!p)T[i].update(k,turn(ans.a[i]),0); 140 else T[i].update(k,turn(ans.a[i]).inv(),0); 141 } 142 int main(){ 143 scanf("%d%*d",&n); 144 for(int i=1;i<=n;i++)scanf("%d",&v[i]); 145 memset(head,-1,sizeof(head)); 146 for(int i=1;i<n;i++){ 147 scanf("%d%d",&x,&y); 148 add(x,y); 149 add(y,x); 150 } 151 dfs1(1,0); 152 x=0; 153 dfs2(1,0,1); 154 for(int i=0;i<M;i++)T[i].init(); 155 for(int i=1;i<=n;i++)update(i,0); 156 scanf("%d",&m); 157 for(int i=1;i<=m;i++){ 158 scanf("%s%d",s,&x); 159 if (s[0]=='Q'){ 160 for(int j=0;j<M;j++)ans.a[j]=T[j].get(1).b; 161 ans.fwt(1); 162 printf("%d ",ans.a[x]); 163 } 164 else{ 165 update(x,1); 166 scanf("%d",&v[x]); 167 update(x,0); 168 } 169 } 170 }