一棵树,每个点有点权,多次操作
1.单点修改一个点的点权
2.询问有多少个非空连通子图（这里的“子树”指任意连通点集，不要求包含某点的全部后代）的点权异或和为 $k$，答案对 $10007$ 取模
$n \leq 30000,\ k \leq 128,\ q \leq 30000$
sol:
动态 dp
为防止自己忘,再写一遍
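一个点的 dp 值拆成轻儿子部分与重儿子部分（以下均为 FWT 之后的逐点运算）：记 $f_x = \hat{v}_x \prod_{c \in 轻儿子} (1 + F_c)$（自身点权与所有轻儿子的贡献），则以 $x$ 为最高点的连通块 $F_x = f_x \, (1 + F_{重儿子})$，子树内的答案 $g_x = F_x + \sum g_{轻儿子} + g_{重儿子}$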
这样就可以一条重链一起转移
用线段树维护重链上的转移，每次修改最多跳 $O(\log n)$ 条重链
这题还要额外套一层 FWT；而且轻儿子的贡献 $1 + F_c$ 在模 $10007$ 下可能为 $0$，乘上之后就无法再除回去，所以还要写一个模意义下的数类，单独维护乘了多少个 $0$
// Dynamic tree dp (heavy-light decomposition + one segment tree per heavy
// chain) that counts, modulo 10007, the connected vertex sets of a tree whose
// vertex weights XOR to a queried value k, while vertex weights change.
//
// Every dp vector is kept in the XOR Walsh-Hadamard (FWT) domain, where XOR
// convolution is a pointwise product, so each recurrence below acts
// independently on each of the m transform coordinates:
//   f[x] = FWT(unit vector at val[x]) * prod over light children c of (1 + F[c])
//   F[x] = f[x] * (1 + F[heavy child])   -- connected sets whose topmost vertex is x
//   ans  = sum of F[x] over all vertices
//
// Input:  n m, the n weights, n-1 edges, q, then q commands. A command word
//         starting with 'C' is followed by "x y" (set val[x] = y); any other
//         word is followed by "k" and prints the current count for k.
// Output: one answer per line.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <vector>

namespace {

constexpr int kMod = 10007;
// Bound for vertex ids AND segment-tree node ids: a chain of L vertices uses
// 2L - 1 nodes, so all chains together need fewer than 2n nodes.
constexpr int kMaxN = 65010;
constexpr int kMaxK = 256;  // capacity for the power-of-two transform length

int n, m;             // vertex count; transform length after rounding up
int inv[kMod];        // inv[i] = modular inverse of i, for 1 <= i < kMod
int val[kMaxN];       // current vertex weights
int e[kMaxK][kMaxK];  // e[v] = FWT of the unit vector at v (entries are +-1)
int ans[kMaxK];       // FWT-domain answer: sum of F[x] over all vertices

// Reads a decimal integer (optionally preceded by '-') from stdin, skipping
// any other leading characters. Returns 0 if EOF is reached first.
int read_int() {
    int x = 0;
    int sign = 1;
    int ch = std::getchar();  // int, so EOF is distinguishable and isdigit is defined
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-') sign = -sign;
        ch = std::getchar();
    }
    while (std::isdigit(ch)) {  // isdigit(EOF) is false
        x = 10 * x + (ch - '0');
        ch = std::getchar();
    }
    return x * sign;
}

// Residue modulo kMod that also counts exact zero factors: the represented
// value is nonzero * 0^zeros. Multiplying by 0 and later dividing by 0 is
// therefore undone exactly, which plain modular arithmetic cannot do.
struct Mint {
    int nonzero = 1;  // product of all non-zero factors
    int zeros = 1;    // number of zero factors; a default Mint is 0

    Mint() = default;
    explicit Mint(int x) : nonzero(x ? x : 1), zeros(x ? 0 : 1) {}

    // Multiplies by b; expects 0 <= b < kMod.
    Mint &operator*=(int b) {
        if (b == 0)
            ++zeros;
        else
            nonzero = nonzero * b % kMod;
        return *this;
    }

    // Divides by b, undoing an earlier *= b; expects 0 <= b < kMod.
    Mint &operator/=(int b) {
        if (b == 0)
            --zeros;
        else
            nonzero = nonzero * inv[b] % kMod;
        return *this;
    }

    int value() const { return zeros ? 0 : nonzero; }
};

Mint f[kMaxN][kMaxK];  // f[x] as defined in the header comment

inline int add_mod(int x, int y) {
    x += y;
    return x >= kMod ? x - kMod : x;
}

inline int sub_mod(int x, int y) {
    x -= y;
    return x < 0 ? x + kMod : x;
}

// In-place XOR Walsh-Hadamard transform of a[0..len), len a power of two.
// dir == 1 applies the forward transform; any other value applies the
// inverse (every butterfly output is halved via inv[2]).
void fwt(int *a, int len, int dir) {
    const int half = inv[2];
    for (int step = 1; step < len; step <<= 1) {
        for (int j = 0; j < len; j += step << 1) {
            for (int k = 0; k < step; ++k) {
                const int s = add_mod(a[j + k], a[j + k + step]);
                const int d = sub_mod(a[j + k], a[j + k + step]);
                if (dir == 1) {
                    a[j + k] = s;
                    a[j + k + step] = d;
                } else {
                    a[j + k] = s * half % kMod;
                    a[j + k + step] = d * half % kMod;
                }
            }
        }
    }
}

// ---- Heavy-light decomposition (vertex 1 is the root) ----
std::vector<int> adj[kMaxN];    // adjacency lists
std::vector<int> chain[kMaxN];  // chain[t]: vertices of the chain topped by t, top to bottom
int parent[kMaxN], depth[kMaxN], chain_top[kMaxN], heavy[kMaxN], subtree_size[kMaxN];
int tops[kMaxN], top_count;     // tops[1..top_count]: every chain top

// Computes parent, depth, subtree size and heavy child for x's subtree.
void dfs_sizes(int x) {
    subtree_size[x] = 1;
    for (int to : adj[x]) {
        if (to == parent[x]) continue;
        depth[to] = depth[x] + 1;
        parent[to] = x;
        dfs_sizes(to);
        if (subtree_size[to] > subtree_size[heavy[x]]) heavy[x] = to;
        subtree_size[x] += subtree_size[to];
    }
}

// Assigns x to the chain topped by `top`, then decomposes x's subtree.
void dfs_chains(int x, int top) {
    chain_top[x] = top;
    chain[top].push_back(x);
    if (x == top) tops[++top_count] = x;
    if (!heavy[x]) return;
    dfs_chains(heavy[x], top);
    for (int to : adj[x])
        if (to != parent[x] && to != heavy[x]) dfs_chains(to, to);
}

// ---- Segment tree over each heavy chain ----
// For a chain segment v_l..v_r with per-vertex weights w = f.value():
//   prod   = prod of w over l..r
//   prefix = sum over j of prod w[l..j]     (paths starting at the segment top)
//   suffix = sum over i of prod w[i..r]     (paths ending at the segment bottom)
//   total  = sum over i <= j of prod w[i..j] (all contiguous paths)
// so prefix at a chain's root node is F[top] and total is the sum of F over the chain.
int root[kMaxN];                    // root[t]: segment-tree root of the chain topped by t
int lc[kMaxN], rc[kMaxN];           // children of a segment node
int up[kMaxN];                      // parent of a segment node (0 above a root)
int leaf_of[kMaxN];                 // leaf node holding a given vertex
int node_count;
int total[kMaxN][kMaxK], prefix[kMaxN][kMaxK], suffix[kMaxN][kMaxK], prod[kMaxN][kMaxK];

// Recomputes segment node x from its two children.
void pull(int x) {
    const int l = lc[x], r = rc[x];
    for (int i = 0; i < m; ++i) {
        total[x][i] = (total[l][i] + total[r][i] + suffix[l][i] * prefix[r][i]) % kMod;
        prefix[x][i] = (prefix[l][i] + prod[l][i] * prefix[r][i]) % kMod;
        suffix[x][i] = (suffix[r][i] + prod[r][i] * suffix[l][i]) % kMod;
        prod[x][i] = prod[l][i] * prod[r][i] % kMod;
    }
}

// Loads vertex v's current f into leaf `node`.
void set_leaf(int node, int v) {
    for (int i = 0; i < m; ++i)
        total[node][i] = prefix[node][i] = suffix[node][i] = prod[node][i] = f[v][i].value();
}

// Builds the tree for positions l..r (1-based) of chain[top] into x.
void build(int &x, int l, int r, int top) {
    x = ++node_count;
    if (l == r) {
        const int v = chain[top][l - 1];
        set_leaf(x, v);
        leaf_of[v] = x;
        return;
    }
    const int mid = (l + r) >> 1;
    build(lc[x], l, mid, top);
    build(rc[x], mid + 1, r, top);
    up[lc[x]] = up[rc[x]] = x;
    pull(x);
}

// (1 + F[top]) at coordinate i: the factor a chain contributes to f of its
// top's parent; the 1 stands for taking no vertex of that light subtree.
int light_factor(int top, int i) { return (prefix[root[top]][i] + 1) % kMod; }

// Propagates a change of f[u] through u's chain: refreshes the segment tree,
// the chain's share of ans, and the factor stored in f of the top's parent.
void update_chain(int u) {
    const int top = chain_top[u];
    const int p = parent[top];
    if (p)
        for (int i = 0; i < m; ++i) f[p][i] /= light_factor(top, i);
    for (int i = 0; i < m; ++i) ans[i] = sub_mod(ans[i], total[root[top]][i]);

    set_leaf(leaf_of[u], u);
    for (int x = up[leaf_of[u]]; x; x = up[x]) pull(x);

    if (p)
        for (int i = 0; i < m; ++i) f[p][i] *= light_factor(top, i);
    for (int i = 0; i < m; ++i) ans[i] = add_mod(ans[i], total[root[top]][i]);
}

}  // namespace

int main() {
    n = read_int();
    m = read_int();
    // Smallest power of two strictly greater than m, so every value in
    // [0, m] fits in the transform (assumes the result is <= kMaxK).
    int width = 1;
    while (width <= m) width <<= 1;
    m = width;

    for (int i = 1; i <= n; ++i) val[i] = read_int();
    inv[1] = 1;
    for (int i = 2; i < kMod; ++i) inv[i] = (kMod - kMod / i) * inv[kMod % i] % kMod;
    for (int i = 2; i <= n; ++i) {
        const int u = read_int(), v = read_int();
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    dfs_sizes(1);
    dfs_chains(1, 1);

    for (int i = 0; i < m; ++i) {
        e[i][i] = 1;
        fwt(e[i], m, 1);
    }
    for (int i = 1; i <= n; ++i)
        for (int j = 0; j < m; ++j) f[i][j] = Mint(e[val[i]][j]);

    // Deeper chain tops first: every light child's chain is finished before
    // the chain containing its parent is built.
    std::sort(tops + 1, tops + top_count + 1,
              [](int a, int b) { return depth[a] > depth[b]; });
    for (int t = 1; t <= top_count; ++t) {
        const int top = tops[t];
        build(root[top], 1, static_cast<int>(chain[top].size()), top);
        for (int j = 0; j < m; ++j) ans[j] = add_mod(ans[j], total[root[top]][j]);
        if (parent[top])
            for (int j = 0; j < m; ++j) f[parent[top]][j] *= light_factor(top, j);
    }

    int q = read_int();
    char opt[16];
    while (q--) {
        if (std::scanf("%15s", opt) != 1) break;
        if (opt[0] == 'C') {
            int x = read_int();
            const int y = read_int();
            // Swap x's own weight factor; the e entries are +-1, never 0.
            for (int i = 0; i < m; ++i) f[x][i] /= e[val[x]][i];
            val[x] = y;
            for (int i = 0; i < m; ++i) f[x][i] *= e[val[x]][i];
            for (; x; x = parent[chain_top[x]]) update_chain(x);
        } else {
            const int k = read_int();
            int tmp[kMaxK];
            std::copy(ans, ans + m, tmp);
            fwt(tmp, m, -1);
            std::printf("%d\n", tmp[k]);
        }
    }
    return 0;
}