@description@
给定一棵 n 个点带点权的树。对于 [0, m) 这个值域中的每一个 i,求这棵树有多少连通块的异或和等于 i。
input
多组数据。第一行给出数据组数 T。1 <= T <= 10。
对于每组数据,第一行两个整数 n, m,含义如上。n <= 1000, 1 <= m <= 2^10。
第二行 n 个整数 v1, v2, ... ,vn,表示每个点的权值。0 <= vi < m。
接下来 n-1 行每行两个整数 ai, bi,描述了树上的边。1 <= ai, bi <= n。
保证 m 是 2 的非负次幂。
output
对于每个数据输出 m 个数,第 i 个数表示异或和等于 i 的连通块个数。
模 10^9 + 7。
sample Input
2
4 4
2 0 1 3
1 2
1 3
1 4
4 4
0 1 3 1
1 2
1 3
1 4
sample output
3 3 2 3
2 4 2 3
@solution@
有这样一个树形 dp。
定义 dp(i, j) 表示以 i 为连通块中深度最低的点(其实就是以 i 为根的一棵树),异或和为 j 的方案数。我们逐个考虑每一个 i 的儿子 c,有转移:
前一个表示不选择这棵子树,后一个表示选择这棵子树。初始时 (dp(i, v[i]) = 1)。
这样,把所有结点的 dp 加起来就可以求到答案了。
这个 dp 是 O(n^3) 的。考虑优化。
很容易发现后面那个式子“长得像个卷积”,而且是一个异或卷积。所以我们就使用 FWT 进行优化。先将初始状态通过 FWT 正变换,然后转移就可以是 O(n^2) 的了。求解完所有的 dp 值过后,再通过 FWT 逆变换回来。
总时间复杂度为 (O(n^2log n)),瓶颈在于 FWT 的时间。
@accepted code@
#include<cstdio>
const int MAXN = 1000;
const int MAXK = (1<<10);
const int MOD = int(1E9) + 7;
const int INV = (MOD + 1) >> 1;
struct edge{
int to; edge *nxt;
}edges[2*MAXN + 5], *adj[MAXN + 5], *ecnt;
void addedge(int u, int v) {
edge *p = (++ecnt);
p->to = v, p->nxt = adj[u], adj[u] = p;
p = (++ecnt);
p->to = u, p->nxt = adj[v], adj[v] = p;
}
int n, m;
int v[MAXN + 5], f[MAXN + 5][MAXK + 5], ans[MAXK + 5];
void init() {
for(int i=1;i<=n;i++)
adj[i] = NULL;
for(int i=0;i<m;i++)
ans[i] = 0;
ecnt = &edges[0];
}
void fwt(int *a, int n, int type) {
for(int s=2;s<=n;s<<=1)
for(int t=(s>>1),i=0;i<n;i+=s)
for(int j=0;j<t;j++) {
int x = a[i+j], y = a[i+j+t];
a[i+j] = 1LL*(x+y)%MOD*(type == 1 ? 1 : INV)%MOD;
a[i+j+t] = 1LL*(x+MOD-y)%MOD*(type == 1 ? 1 : INV)%MOD;
}
}
void dfs(int rt, int pre) {
for(int i=0;i<m;i++)
f[rt][i] = 0;
f[rt][v[rt]] = 1;
fwt(f[rt], m, 1);
for(edge *p=adj[rt];p!=NULL;p=p->nxt) {
if( p->to == pre ) continue;
dfs(p->to, rt);
for(int i=0;i<m;i++)
f[rt][i] = (f[rt][i] + 1LL*f[rt][i]*f[p->to][i]%MOD)%MOD;
}
}
void solve() {
scanf("%d%d", &n, &m);
for(int i=1;i<=n;i++)
scanf("%d", &v[i]);
init();
for(int i=1;i<n;i++) {
int a, b;
scanf("%d%d", &a, &b);
addedge(a, b);
}
dfs(1, 0);
for(int i=1;i<=n;i++) {
fwt(f[i], m, -1);
for(int j=0;j<m;j++)
ans[j] = (ans[j] + f[i][j]) % MOD;
}
for(int i=0;i<m;i++)
printf("%d", ans[i]), putchar(i == m-1?'
':' ');
}
int main() {
int T; scanf("%d", &T);
for(int i=1;i<=T;i++)
solve();
}
@details@
他们好像说这个题他们先 FWT 再全部 dp 最后 FWT 回来会 WA ?
明明我没有 WA 啊?奇怪嘞奇怪嘞。