矩阵树
题目描述
给定一个 \(n\) 个点的无向完全图,其中边 \((i,j)\) 的个数是 \(a(i,j)\);有 \(k\) 个要求,第 \(i\) 个要求是点集 \(S\) 的导出子图要连通,问满足条件的生成树个数,答案对 \(998244353\) 取模。
\(n\leq 500,k\leq 2000\)
解法
如果没有限制就是裸的矩阵树定理,这其实我们往矩阵树的方向思考。
首先观察 \(S\) 的导出子图连通有什么性质,我们可以将限制转化到边上,那么就相当于有恰好 \(|S|-1\) 条边(忽略 \(|S|=0\) 的情况),满足其的两个端点都在 \(S\) 中。并且有一个关键的 \(\tt observation\):无论合法还是不合法的情况,这样的边数最多有 \(|S|-1\) 个。
那么说明本题的判据跟最值有一定关联了,那么对于全部的 \(k\) 个限制,设 \(w(i,j)\) 表示边 \((i,j)\) 的两个端点都出现在了多少个 \(S\) 中,那么只有最大生成树才可能成为答案。
可以用 \(\tt bitset\) 以 \(O(\frac{n^2k}{w})\) 的时间求出每条边的边权,最大生成树计数是经典问题。由于每种边权的数量固定,对于每种边权的每个连通块,我们单独跑矩阵树定理,限制好矩阵大小时间复杂度就是 \(O(n^3)\) 的。
#include <cstdio>
#include <bitset>
#include <iostream>
#include <algorithm>
using namespace std;
const int M = 505;
const int N = 2005;
const int MOD = 998244353;
#define int long long
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,k,sum,ans,a[M][M],fa[M],fn[M],vis[M],id[M];
bitset<N> g[M];char s[N];
struct edge{int u,v,c;};vector<edge> e[N],G[M];
struct matrix
{
int n,a[M][M];
void clear()
{
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
a[i][j]=0;
n=0;
}
void add(int u,int v,int c)
{
a[u][u]=(a[u][u]+c)%MOD;
a[v][v]=(a[v][v]+c)%MOD;
a[u][v]=(a[u][v]+MOD-c)%MOD;
a[v][u]=(a[v][u]+MOD-c)%MOD;
}
int qkpow(int a,int b)
{
int r=1;
while(b>0)
{
if(b&1) r=r*a%MOD;
a=a*a%MOD;
b>>=1;
}
return r;
}
int gauss()
{
int ans=1;
for(int i=2;i<=n;i++)
{
for(int j=i+1;j<=n;j++)
if(!a[i][i] && a[j][i])
{
ans=MOD-ans;
swap(a[i],a[j]);
break;
}
ans=ans*a[i][i]%MOD;
int inv=qkpow(a[i][i],MOD-2);
for(int j=i+1;j<=n;j++)
{
int tmp=a[j][i]*inv%MOD;
for(int k=i;k<=n;k++)
a[j][k]=(a[j][k]-a[i][k]*tmp
%MOD+MOD)%MOD;
}
}
return ans;
}
}z;
int find(int x)
{
if(x==fa[x]) return x;
return fa[x]=find(fa[x]);
}
int zxy(int x)
{
if(x==fn[x]) return x;
return fn[x]=zxy(fn[x]);
}
void kruskal()
{
ans=1;
for(int i=1;i<=n;i++) fa[i]=fn[i]=i;
for(int i=k;i>=0;i--) if(e[i].size())
{
for(edge x:e[i])
fn[zxy(x.u)]=fn[zxy(x.v)];
for(edge x:e[i])
{
int u=find(x.u),v=find(x.v);
if(u^v) G[zxy(u)].push_back(x);
}
for(int R=1;R<=n;R++) if(G[R].size())
{
for(int j=1;j<=n;j++) vis[j]=id[j]=0;
int m=0;
for(edge x:G[R])
{
int u=find(x.u),v=find(x.v);
if(!vis[u]) vis[u]=1,id[u]=++m;
if(!vis[v]) vis[v]=1,id[v]=++m;
z.add(id[u],id[v],x.c);
}
z.n=m;
ans=ans*z.gauss()%MOD;
z.clear();
G[R].clear();
}
for(edge x:e[i])
{
int u=find(x.u),v=find(x.v);
if(u^v) sum-=i,fa[u]=v;
}
}
printf("%lld\n",(sum==0)?ans:0);
}
signed main()
{
freopen("treecnt.in","r",stdin);
freopen("treecnt.out","w",stdout);
n=read();k=read();
for(int i=1;i<=n;i++)
for(int j=i+1;j<=n;j++)
a[i][j]=read();
for(int i=1;i<=k;i++)
{
scanf("%s",s+1);
int fl=0;
for(int j=1;j<=n;j++) if(s[j]=='1')
g[j][i]=1,fl=1,sum++;
sum-=fl;
}
for(int i=1;i<=n;i++)
for(int j=i+1;j<=n;j++)
{
int w=(g[i]&g[j]).count();
e[w].push_back({i,j,a[i][j]});
}
kruskal();
}