Link
一条Hamilton回路可以被拆分成由若干条树上路径组成的环,其中相邻两条树上路径不能同属于一棵树。
假如我们求出了将一棵树分为若干链的方案数,那么剩下的就是求给环染色,相邻位置颜色不同的方案数。
第一部分可以用树形背包简单求出,设(f_{u,i,0/1/2})表示(u)的子树内选了(i)条链,(u)的状态为:在一条进入了两个不同的子树的链上/只有一个单点/在一条至少有两个点的单链上,转移分类讨论一下就行了。(一条长度不小于(2)的链应该算两遍!!1)
现在我们已经求出了(f_k)表示选出(k)条链的带权方案数,直接EGF组合的话会出现环上相邻两条链属于同一棵树的情况。
考虑容斥,钦定最后环上属于该棵树的链构成(j)个极长连续段,那么(f_k)对该项的贡献为((-1)^{k-j}{k-1choose j-1}f_kk!),所以EGF为
[widehat f(x)=sumlimits_{i=1}^nf_ii!sumlimits_{j=1}^i(-1)^{i-j}{i-1choose j-1}frac{x^j}{j!}
]
对于第一棵树而言,限制(1)必须是环的开头,同时首尾不能相同,推完式子可以发现EGF恰好是(frac{widehat f(x)}x)。
最后把所有EGF乘起来就可以得到答案的EGF了。
注意这个做法在(m=1)的时候有问题的。
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using i64=long long;
const int N=5007,P=998244353;
int size[N];i64 s[N][N],f[N][N][3];std::vector<int>e[N];
int read(){int x;scanf("%d",&x);return x;}
void inc(i64&a,i64 b){a+=b-P,a+=a>>63&P;}
void dec(i64&a,i64 b){a-=b,a+=a>>63&P;}
struct poly{int deg;i64 a[N];poly(int n){deg=n,memset(a,0,8*n+8);}i64&operator[](const int&x){return a[x];}};
poly operator*(poly f,poly g)
{
poly a(f.deg+g.deg);
for(int i=0;i<=f.deg;++i) for(int j=0;j<=g.deg;++j) inc(a[i+j],f[i]*g[j]%P);
return a;
}
void dfs(int u,int fa)
{
static i64 t[N][3];
size[u]=1,memset(f[u],0,48),f[u][0][1]=1;
for(int v:e[u])
{
if(v==fa) continue;
dfs(v,u),memset(t,0,24*(size[u]+size[v]+1));
for(int i=0;i<=size[u];++i)
for(int j=0;j<=size[v];++j)
{
for(int k=0;k<3;++k) inc(t[i+j][k],f[u][i][k]*f[v][j][0]%P);
inc(t[i+j][2],f[u][i][1]*f[v][j][1]%P),inc(t[i+j+1][0],2*f[u][i][2]*f[v][j][1]%P);
}
size[u]+=size[v],memcpy(f[u],t,24*(size[u]+1));
}
for(int i=0;i<=size[u];++i) (f[u][i+1][0]+=2*f[u][i][2]+f[u][i][1])%=P,inc(f[u][i][1],f[u][i][2]);
}
poly solve()
{
int n=read();poly a(n);
for(int i=1;i<=n;++i) e[i].clear();
for(int i=1,u,v;i<n;++i) u=read(),v=read(),e[u].push_back(v),e[v].push_back(u);
dfs(1,0);
for(int i=1;i<=n;++i) a[i]=f[1][i][0];
for(int i=n;i;--i) for(int j=i+1;j<=n;++j) dec(a[i],a[j]*s[j][i]%P);
return a;
}
int main()
{
int m=read();poly ans(0);i64 sum=0,fac=1;ans[0]=s[1][1]=1;
for(int i=2;i<=5000;++i) for(int j=1;j<=i;++j) s[i][j]=(s[i-1][j-1]+(i-1+j)*s[i-1][j])%P;
for(int i=1;i<=m;++i) ans=ans*solve();
for(int i=1;i<=ans.deg;++i) inc(sum,ans[i]*fac%P),fac=fac*i%P;
printf("%lld",sum);
}