题目:http://acm.hdu.edu.cn/showproblem.php?pid=5909
点分治的话,每次要做一次树形DP;但时间应该是 siz*m2 的。可以用 FWT 变成 siz*mlogm ,但这里写的是把树变成序列来 DP 的方法,应该是 nlogn*m 的。
树上的一个点,如果选,就可以选它的孩子,所以它向它的第一个孩子连边;如果不选,就会跳到它的下一个兄弟或者是父亲的下一个兄弟之类的,向那边连一条边。
做出树的 dfs 序,把边都连在 dfs 序上;其实那个第一条边一定连向自己 dfs 序+1,即使自己没有孩子也是符合的,所以可以不用连了;第二条边可以通过传父亲的连边对象来解决。
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int N=1005,M=1025,mod=1e9+7; int T,n,m,w[N],hd[N],xnt,to[N<<1],nxt[N<<1],siz[N],rt,mn; int dfn[N],tot,sta[N],top,f[N][M],g[N],nt[N],ans[M]; bool vis[N]; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } int Mx(int a,int b){return a>b?a:b;} int Mn(int a,int b){return a<b?a:b;} void upd(int &x){x>=mod?x-=mod:0;} void init() { xnt=0;memset(hd,0,sizeof hd); memset(ans,0,sizeof ans); memset(vis,0,sizeof vis); } void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} void getrt(int cr,int fa,int s) { siz[cr]=1; int mx=0; for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa) { getrt(v,cr,s);siz[cr]+=siz[v]; mx=Mx(mx,siz[v]); } mx=Mx(mx,s-siz[cr]);if(mx<mn)mn=mx,rt=cr; } void dfs(int cr,int fa) { dfn[cr]=++tot;g[tot]=w[cr]; for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa)dfs(v,cr); } void dfsx(int cr,int fa,int lst) { nt[dfn[cr]]=lst; int l=top+1; for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa)sta[++top]=v; int r=top; for(int i=hd[cr],v,p0=l;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa) { dfsx(v,cr,p0==r?lst:dfn[sta[p0+1]]);p0++; } } void solve(int cr,int s) { vis[cr]=1; tot=0;dfs(cr,0);top=0;dfsx(cr,0,s+1); for(int i=1;i<=s+1;i++)memset(f[i],0,sizeof f[i]); f[1][0]=1; for(int i=1;i<=s;i++) for(int j=0;j<m;j++) { if(!f[i][j])continue; f[i+1][j^g[i]]+=f[i][j];upd(f[i+1][j^g[i]]); f[nt[i]][j]+=f[i][j];upd(f[nt[i]][j]); } f[s+1][0]--;//dec the empty for(int j=0,k=s+1;j<m;j++)ans[j]+=f[k][j],upd(ans[j]); for(int i=hd[cr],v,ts;i;i=nxt[i]) if(!vis[v=to[i]]) { ts=(siz[cr]>siz[v]?siz[v]:s-siz[cr]); mn=N;getrt(v,cr,ts);solve(rt,ts); } } int main() { T=rdn(); while(T--) { n=rdn();m=rdn();for(int i=1;i<=n;i++)w[i]=rdn(); init(); for(int i=1,u,v;i<n;i++)u=rdn(),v=rdn(),add(u,v),add(v,u); mn=N;getrt(1,0,n);solve(rt,n); for(int i=0,j=m-1;i<j;i++)printf("%d ",ans[i]); printf("%d ",ans[m-1]); } return 0; }