草,遇到这种时候,上来就应该说一句 Sooke 牛逼!
考虑这个 (m leq 35) 那么很显然是一个 mid in middle 的范围
压成线性基,搜两次,算一下高位的 bit数,由于低位不存在高位的bit,拿来和后边存在高位的bit卷一下就可以了。
我们把 (>=17) 的位拿来 (FWT),然后卷积,没了。
需要注意的是,加入一共有 (n) 个数字,线性基里面有 (k) 个数字,那么线性基的每个能构造出来的数字数量是 (2^{n-k})
所以答案应该乘上一个 (2^{n-k})
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod = 998244353;
const int maxn = 4e5 + 54;
int n, m;
int a[maxn];
int f[28][maxn], g[maxn];
int d[2333], c = 0;
int cnt[maxn << 1];
int ans[maxn], tmp[maxn];
void ins(int x) {
for(int i = m - 1 ; ~ i ; i --) {
if(x & (1ll << i)) {
if(! d[i]) {
d[i] = x;
++ c;
}
x ^= d[i];
}
}
}
int mid;
void dfs(int u, int num) {
if(u == mid - 1) {
f[cnt[num >> mid]][num & ((1 << mid) - 1)] ++;
return;
}
dfs(u - 1, num);
if(d[u]) {
dfs(u - 1, num ^ d[u]);
}
}
void dfs2(int u, int num) {
if(u == mid) {
g[num] ++;
return ;
}
dfs2(u + 1, num);
if(d[u]) {
dfs2(u + 1, num ^ d[u]);
}
}
int qpow(int x,int y) {
int ans = 1;
for(; y; y >>= 1, x = x * x % mod)
if(y & 1)
ans = ans * x % mod;
return ans;
}
const int inv2 = qpow(2, mod - 2);
void fwt(int *f, int type, int n) {
for(int len = 1 ; len < n ; len <<= 1) {
for(int i = 0 ; i < n ; i += len << 1) {
for(int j = 0; j < len ; j ++) {
int x = f[i + j];
int y = f[i + j + len];
f[i + j] = (x + y) % mod;
f[i + j + len] = (x - y + mod) % mod;
if(type == -1)
f[i + j] = f[i + j] * inv2 % mod,
f[i + j + len] = f[i + j + len] * inv2 % mod;
}
}
}
}
signed main() {
ios :: sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
cin >> n >> m;
for(int i = 1 ; i <= n ; i ++)
cin >> a[i];
for(int i = 1 ; i <= n ; i ++)
ins(a[i]);
mid = m + 1 >> 1;
for(int i = 1 ; i <= 300000 ; i ++)
cnt[i] = cnt[i >> 1] + (i & 1);
dfs(m - 1, 0);
dfs2(0, 0);
memcpy(tmp, g, sizeof(g));
for(int i = 0; i <= m - mid ; i ++) {
memcpy(g, tmp, sizeof(g));
fwt(f[i], 1, 1 << mid);
fwt(g, 1, 1 << mid);
for(int j = 0; j < (1 << mid); ++j)
f[i][j] = f[i][j] * g[j] % mod;
fwt(f[i], -1, 1 << mid);
for(int j = 0; j < (1 << mid); ++j)
ans[i + cnt[j]] = (ans[i + cnt[j]] + f[i][j]) % mod;
}
int qwq = qpow(2, n - c) ;
for(int i = 0 ; i <= m ; i++)
ans[i] = ans[i] * qwq % mod;
for(int i = 0 ; i <= m ; i ++)
cout << ans[i] << ' ';
cout << '
';
return 0;
}