传送门:http://acm.hdu.edu.cn/showproblem.php?pid=5909
【题解】
设 $f_{x,i}$ 表示在以 $x$ 为根的子树内、包含 $x$ 的连通块(即题目中的"子树")中,点权异或和为 $i$ 的连通块个数。
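那么对 $x$ 的每个儿子 $y$ 依次做转移(右侧使用转移前的值):$f_{x,i} \leftarrow f_{x,i} + \sum_{j \oplus k = i} f_{x,j}\, f_{y,k}$。其中第一项对应不选 $y$ 一侧的任何点,第二项对应把一个包含 $y$ 的连通块接到 $x$ 上。初值为 $f_{x,w_x} = 1$,最终答案为 $\mathrm{ans}_i = \sum_{x} f_{x,i}$。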
第二项是异或卷积,用 FWT(快速沃尔什变换)计算即可,总复杂度为 $O(nm\log m)$。
// Solution for HDU 5909 "Tree Cutting".
// For every value k in [0, L), counts (mod 1e9+7) the connected vertex sets of
// the input tree whose XOR of vertex weights equals k.
#include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>

namespace {

const int N = 1000 + 5;          // capacity for vertices (1-indexed)
const int H = 1024 + 5;          // capacity for the value range [0, L)
const int MOD = 1000000007;

int n;                           // number of vertices in the current test
int L;                           // size of the value range; FWT assumes a power of two
int w[N];                        // vertex weights, read into w[1..n]
int inv2;                        // modular inverse of 2, set once in main()

// f[x][i]: number of connected vertex sets inside x's subtree that contain x
// and whose weight XOR equals i (filled by dfs()).
int f[N][H];
int ans[H];                      // per-test answer, accumulated by dfs()

// Edge-chain adjacency list; edge ids start at 1, so id 0 terminates a chain.
int head[N], nxt[2 * N], to[2 * N], tot = 0;

int s[H], t[H];                  // scratch buffers used by FWT_combine()

// Links the directed edge u -> v at the front of u's chain.
inline void add(int u, int v) {
    to[++tot] = v;
    nxt[tot] = head[u];
    head[u] = tot;
}

// Adds the undirected edge {u, v} as two directed edges.
inline void adde(int u, int v) {
    add(u, v);
    add(v, u);
}

// Returns a^b mod MOD by square-and-multiply.
inline int pwr(int a, int b) {
    long long result = 1;
    for (long long base = a; b != 0; b >>= 1) {
        if (b & 1) result = result * base % MOD;
        base = base * base % MOD;
    }
    return static_cast<int>(result);
}

// In-place XOR (Walsh-Hadamard) transform of a[0..L) modulo MOD.
// op == 0 applies the forward transform; any other op applies the inverse,
// which additionally multiplies every butterfly output by 1/2.
inline void FWT(int *a, int op) {
    const bool inverse = (op != 0);
    for (int half = 1; 2 * half <= L; half *= 2) {
        for (int base = 0; base < L; base += 2 * half) {
            for (int k = base; k < base + half; ++k) {
                const int x = a[k];
                const int y = a[k + half];
                long long lo = (x + y) % MOD;
                long long hi = (x - y + MOD) % MOD;
                if (inverse) {
                    lo = lo * inv2 % MOD;
                    hi = hi * inv2 % MOD;
                }
                a[k] = static_cast<int>(lo);
                a[k + half] = static_cast<int>(hi);
            }
        }
    }
}

// Replaces A with A + (A XOR-convolved with B), elementwise mod MOD.
// B is only read; the convolution is computed in the scratch buffers s and t.
inline void FWT_combine(int *A, int *B) {
    std::copy_n(A, L, s);
    std::copy_n(B, L, t);
    FWT(s, 0);
    FWT(t, 0);
    for (int i = 0; i < L; ++i) {
        s[i] = static_cast<int>(1LL * s[i] * t[i] % MOD);
    }
    FWT(s, 1);
    for (int i = 0; i < L; ++i) {
        A[i] = (A[i] + s[i]) % MOD;
    }
}

// Computes f[x] for the subtree hanging from x (fa is its parent, 0 at the
// root) and adds f[x] into ans. Each child y is merged in turn: either y's
// side contributes nothing (keep f[x]) or a set containing y is attached
// (XOR-convolve with f[y]).
void dfs(int x, int fa = 0) {
    int *cur = f[x];
    std::fill_n(cur, L, 0);
    cur[w[x]] = 1;
    for (int e = head[x]; e != 0; e = nxt[e]) {
        const int y = to[e];
        if (y != fa) {
            dfs(y, x);
            FWT_combine(cur, f[y]);
        }
    }
    for (int j = 0; j < L; ++j) {
        ans[j] = (ans[j] + cur[j]) % MOD;
    }
}

// Reads one test case (n, L, the weights, then n-1 edges), runs the DP from
// vertex 1, and prints ans[0..L) space-separated on one line.
void sol() {
    tot = 0;
    memset(head, 0, sizeof head);
    memset(ans, 0, sizeof ans);
    std::cin >> n >> L;
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &w[i]);
    }
    int u, v;
    for (int i = 1; i < n; ++i) {
        scanf("%d%d", &u, &v);
        adde(u, v);
    }
    dfs(1);
    printf("%d", ans[0]);
    for (int i = 1; i < L; ++i) {
        printf(" %d", ans[i]);
    }
    puts("");
}

}  // namespace

int main() {
    inv2 = pwr(2, MOD - 2);
    int T;
    std::cin >> T;
    while (T--) sol();
    return 0;
}