「HNOI2018」毒瘤
解题思路
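先考虑只有一棵树的情况，这是经典的独立集计数。设 \(dp[u][0]\)、\(dp[u][1]\) 分别表示 \(u\) 不选、选时 \(u\) 子树内的方案数，\(v\) 取遍 \(u\) 的儿子：
\[
\begin{aligned}
dp[u][0] &= \prod_{v} \left(dp[v][0]+dp[v][1]\right) \\
dp[u][1] &= \prod_{v} dp[v][0]
\end{aligned}
\]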
然后考虑将所有非树边的端点（再加上根节点）建一棵虚树。虚树以外节点的 \(\text{dp}\) 值不受非树边限制的影响，因此可以预处理出虚树上每个节点对它虚树父亲贡献的系数（一个 \(2\times 2\) 的转移矩阵）。
然后枚举所有非树边两端点的合法状态（每条非树边有三种：只选左端点、只选右端点、都不选），再在虚树上计算一遍贡献。令 \(k = m-n+1\) 为非树边条数，这样复杂度是 \(\mathcal O(k\cdot 3^k+m)\)。
事实上只需要枚举每一条非树边的左端点是否选：当左端点选的时候，右端点只能不选；否则右端点可选可不选。这样涵盖了所有三种合法情况，复杂度降为 \(\mathcal O(k\cdot 2^k+m)\)。
code
/*program by mangoyang*/
#include <bits/stdc++.h>
// Generic helpers. None of inf / Max / Min is referenced elsewhere in the
// code visible here; they are kept so nothing else breaks.
// `inf` is a large sentinel whose every byte is 0x7f.
#define inf (0x7f7f7f7f)
// Expression macros: each argument may be evaluated twice.
#define Max(p, q) ((p) > (q) ? (p) : (q))
#define Min(p, q) ((p) < (q) ? (p) : (q))
// 64-bit shorthand; the later `#define int ll` depends on this name.
typedef long long ll;
using namespace std;
// Reads the next integer from stdin into x.
// Non-digit characters before the number are skipped; a '-' is taken as the
// sign only when it immediately precedes the first digit. If EOF is reached
// before any digit, x is left as 0 instead of spinning forever on EOF
// (isdigit(EOF) is false, so the old skip loop never terminated).
template <class T>
inline void read(T &x){
    int ch = getchar();
    int neg = 0;
    x = 0;
    // Skip separators, remembering whether the character just before the
    // digits was a minus sign.
    while(ch != EOF && !isdigit(ch)){
        neg = (ch == '-');
        ch = getchar();
    }
    for(; ch != EOF && isdigit(ch); ch = getchar()) x = x * 10 + (ch - '0');
    if(neg) x = -x;
}
// From here on every `int` token is rewritten to `ll` (long long), so all
// globals and functions below use 64-bit integers (hence `signed main`).
#define int ll
// N: capacity of the node-indexed arrays; it must also exceed the Euler-tour
// length 2n-1 that PR's sparse table is indexed by. mod: answer modulus.
const int N = 300005, mod = 998244353;
// ed: non-tree edges, stored as (u, v) endpoints while reading; main later
// re-maps each endpoint to its index in vec.
vector<pair<int, int> > ed;
// g: spanning-tree adjacency lists; e: virtual-tree child lists;
// vec: endpoints of non-tree edges (sorted and de-duplicated in main).
vector<int> g[N], e[N], vec;
// dfn: Euler-tour position of a node's first visit; dep: depth (root = 1);
// fa: union-find parent; dp[u][0/1]: independent sets (mod `mod`) of u's
// subtree with u unselected/selected; s[u][0/1]: the same product but only
// over children whose subtree holds no virtual-tree node; dp2: the DP
// evaluated on the virtual tree; ali[u][c] = 0 forbids state c for u in the
// current solve() call.
int dfn[N], dep[N], fa[N], s[N][2], dp[N][2], dp2[N][2], ali[N][2];
// tr[x][i][j]: coefficient from virtual node x in state i to the top of its
// chain (the original child of x's virtual parent) in state j; a: key nodes
// passed to buildtree; b: virtual-tree nodes; in: virtual-tree membership
// flag; pa: virtual parent; st: buildtree stack; n, m: graph size;
// ans: accumulated answer.
int tr[N][2][2], a[N], b[N], in[N], pa[N], st[N], n, m, ans;
// Union-find "find": returns the representative of x's set and re-points
// every node on the walked path directly at that representative.
inline int ask(int x){
    int root = x;
    while(fa[root] != root) root = fa[root];
    // Second pass: path compression.
    while(fa[x] != root){
        int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}
// Preprocessing on the spanning tree rooted at node 1: an Euler tour with an
// O(1) sparse-table LCA, plus the plain tree DP dp[u][0/1].
namespace PR{
// Log[i] = floor(log2(i)); f[i][j] = shallowest node among Euler positions
// [i, i + 2^j). They only hold small values / node ids, so they are stored as
// 32-bit `signed` rather than the file-wide 64-bit `int` (`#define int ll`),
// which halves the N x 22 table.
signed Log[N], f[N][22];
int tot;  // current Euler-tour length; requires 2n-1 < N
// Returns whichever of x, y is shallower.
inline int chkmin(int x, int y){ return dep[x] < dep[y] ? x : y; }
// Euler tour from u (parent par) recording first-visit positions in dfn and
// every visit in f[*][0]; also computes
// dp[u][1] = prod dp[v][0], dp[u][0] = prod (dp[v][0] + dp[v][1]).
inline void dfs(int u, int par){
    dep[u] = dep[par] + 1, dfn[u] = ++tot, f[tot][0] = u;
    dp[u][0] = dp[u][1] = 1;
    for(int i = 0; i < (int) g[u].size(); i++){
        int v = g[u][i];
        if(v == par) continue;
        // u is revisited after each child in the Euler sequence.
        dfs(v, u), f[++tot][0] = u;
        (dp[u][1] *= dp[v][0]) %= mod;
        (dp[u][0] *= (dp[v][1] + dp[v][0]) % mod) %= mod;
    }
}
// Runs the tour from the root and builds the RMQ sparse table over it.
inline void solve(){
    dfs(1, 0);
    for(int i = 2; i <= tot; i++) Log[i] = Log[i>>1] + 1;
    for(int j = 1; j < 22 && (1 << j) <= tot; j++)
        for(int i = 1; i + (1 << j) - 1 <= tot; i++)
            f[i][j] = chkmin(f[i][j-1], f[i+(1<<(j-1))][j-1]);
}
// LCA of u and v: the shallowest node between their first visits.
inline int Lca(int u, int v){
    int x = dfn[u], y = dfn[v];
    if(x > y) swap(x, y);
    int k = Log[y-x+1];
    return chkmin(f[x][k], f[y-(1<<k)+1][k]);
}
}
// Sort key for buildtree: orders nodes by first-visit position (DFS preorder).
inline bool cmp(int x, int y){
    return dfn[y] > dfn[x];
}
// Builds the virtual (auxiliary) tree on the key nodes a[1..len].
// Precondition: a[] contains node 1, the DFS root (main appends it when it is
// not already a key node). After sorting, st[1] is therefore always the root,
// and the pop loop below never has to empty the stack.
// Results: b[1..tot] lists every virtual-tree node (the keys plus the LCAs
// pushed here); pa[x] is x's virtual parent; x is appended to e[pa[x]];
// in[x] = 1 marks membership; ali[x][0] = ali[x][1] = 1.
// NOTE(review): the root's pa stays 0, so the root is also appended to e[0];
// nothing visible reads e[0].
inline void buildtree(int a[], int len){
// Process key nodes in DFS preorder; st[] holds the current root-to-node chain,
// and every stacked node already has pa = the node below it on the stack.
sort(a + 1, a + len + 1, cmp); int top = 0, tot = 0;
for(int i = 1; i <= len; i++){
int u = a[i];
if(!top){ st[++top] = b[++tot] = u; continue; }
int ca = PR::Lca(u, st[top]);
// Pop chain nodes deeper than ca. The last one popped (whose lower stack
// neighbour lies strictly above ca) is re-parented to ca, which then gets
// inserted between them.
for(; top > 1 && dep[st[top]] > dep[ca]; top--)
if(dep[st[top-1]] < dep[ca]) pa[st[top]] = ca;
// ca is not on the chain yet: push it as a new virtual node.
if(st[top] != ca)
pa[ca] = st[top], st[++top] = b[++tot] = ca;
pa[u] = ca, st[++top] = b[++tot] = u;
}
// Materialise child lists and reset per-node state flags.
for(int i = 1; i <= tot; i++){
in[b[i]] = 1, e[pa[b[i]]].push_back(b[i]);
ali[b[i]][0] = ali[b[i]][1] = 1;
}
}
// Post-order pass over the spanning tree from u (parent `fa`, which here
// shadows the global union-find array).
// For every node it stores in s[u] the product over the children whose
// subtrees contain no virtual-tree node:
//   s[u][0] (u unselected) = prod (dp[v][0] + dp[v][1]),
//   s[u][1] (u selected)   = prod dp[v][0].
// Returns the virtual node exposed upward from u's subtree:
//   - u itself if in[u];
//   - otherwise the single virtual node x found below u;
//   - 0 if there is none.
// A non-virtual node has at most one child subtree containing virtual nodes,
// because the virtual tree built by buildtree is closed under LCA.
// Side effect: tr[x] is grown into a 2x2 transfer matrix. tr[x][i][j] is the
// coefficient with which "x in state i" contributes to "the current top of
// x's chain in state j". It starts as the identity at x and is extended by
// one original-tree ancestor per non-virtual level.
inline int dfs(int u, int fa){
int x = 0, k0 = 1, k1 = 1;
for(int i = 0; i < (int) g[u].size(); i++){
int v = g[u][i];
if(v == fa) continue;
// Subtrees containing a virtual node are accounted for on the virtual tree
// (dfs2), so only the other children are folded into k0 / k1.
int tmp = dfs(v, u);
if(tmp) x = tmp;
else (k0 *= dp[v][0]) %= mod, (k1 *= (dp[v][0] + dp[v][1]) % mod) %= mod;
}
s[u][0] = k1, s[u][1] = k0;
// Virtual node: start a fresh chain with the identity matrix. The
// off-diagonal entries are still 0, since tr is zero-initialised and dfs
// runs once.
if(in[u]) return tr[u][0][0] = 1, tr[u][1][1] = 1, u;
if(!x) return 0;
// u is a non-virtual ancestor on x's chain; extend the chain by u.
// New top state 0 (u unselected): the old top may be 0 or 1, times k1.
// New top state 1 (u selected): the old top must be 0, times k0.
int tmp[2][2];
for(int i = 0; i < 2; i++)
for(int j = 0; j < 2; j++) tmp[i][j] = tr[x][i][j];
tr[x][0][0] = (tmp[0][0] + tmp[0][1]) % mod * k1 % mod;
tr[x][0][1] = tmp[0][0] * k0 % mod;
tr[x][1][0] = (tmp[1][0] + tmp[1][1]) % mod * k1 % mod;
tr[x][1][1] = tmp[1][0] * k0 % mod;
return x;
}
inline void dfs2(int u){
for(int i = 0; i < 2; i++) dp2[u][i] = ali[u][i] * s[u][i];
for(int i = 0; i < (int) e[u].size(); i++){
int v = e[u][i];
dfs2(v);
int k0 = (dp2[v][0] * tr[v][0][0] % mod + dp2[v][1] * tr[v][1][0] % mod) % mod;
int k1 = (dp2[v][0] * tr[v][0][1] % mod + dp2[v][1] * tr[v][1][1] % mod) % mod;
(dp2[u][0] *= (k0 + k1) % mod) %= mod, (dp2[u][1] *= k0) %= mod;
}
}
inline void solve(int mask){
for(int i = 0; i < (int) vec.size(); i++)
ali[vec[i]][0] = ali[vec[i]][1] = 1;
for(int i = 0; i < (int) ed.size(); i++){
int x = ed[i].first, y = ed[i].second;
int tmp = (1 << i) & mask;
if(tmp) ali[vec[x]][0] = ali[vec[y]][1] = 0; else ali[vec[x]][1] = 0;
}
dfs2(1), (ans += dp2[1][0] + dp2[1][1]) %= mod;
}
signed main(){
int len = 0;
read(n), read(m);
for(int i = 1; i <= n; i++) fa[i] = i;
for(int i = 1, x, y; i <= m; i++){
read(x), read(y);
if(ask(x) == ask(y)){
ed.push_back(make_pair(x, y));
vec.push_back(x), vec.push_back(y);
}
else{
fa[ask(x)] = ask(y);
g[x].push_back(y), g[y].push_back(x);
}
}
PR::solve();
if(m == n - 1) return cout << (dp[1][0] + dp[1][1]) % mod, 0;
sort(vec.begin(), vec.end());
vector<int>::iterator newend = unique(vec.begin(), vec.end());
vec.erase(newend, vec.end());
for(int i = 0; i < (int) vec.size(); i++) a[++len] = vec[i];
if(!len || vec[0] != 1) a[++len] = 1;
buildtree(a, len), dfs(1, 0);
for(int i = 0; i < (int) ed.size(); i++){
int x = lower_bound(vec.begin(), vec.end(), ed[i].first) - vec.begin();
int y = lower_bound(vec.begin(), vec.end(), ed[i].second) - vec.begin();
ed[i] = make_pair(x, y);
}
for(int i = 0; i < (1 << (m - n + 1)); i++) solve(i);
cout << ans << endl;
return 0;
}