题意
给一棵(n)个点的树,求删去其中(k)条边再加入(k)条边后仍然是一颗树的方案数。
(2le nle 5 imes10^4,1le kle min(n-1,100))
题解
prufer序列
prufer序列是一种将一颗(n)个节点的有标号树用唯一的一个整数序列表示的方法。
一棵树的prufer序列构造过程如下:每次选择一个编号最小的叶结点并删掉它,然后在序列中记录下它连接到的那个结点编号。重复(n-2)次后就只剩下两个结点,算法结束。所以(n)个点的有标号树共有(n^{n-2})种。
树上每个点的出现次数等于其对应的prufer序列中该数字出现次数加一
首先考虑删去的(k)条边已经确定时的答案,即为在形成的(k+1)个连通块之间连边形成树的方案数。不妨再假设这(k+1)个联通块之间的连边情况也已经确定。记连通块(i)的度数为(s_i),在对应prufer序列中的出现次数为(P_i),则此时答案为(prod_{i=1}^{k+1}s_i^{d_i}=prod_{i=1}^{k+1}s_i^{P_i+1}=prod_{i=1}^{k+1}s_iprod_{i=1}^{k+1}s_i^{P_i}),其中(prod_{i=1}^{k+1}s_i)与连通块之间的连边情况无关。那么,(k+1)个连通块之间连边形成树的方案数为
[sum_{所有长度为k-1的prufer序列}prod_{i=1}^{k+1}s_iprod_{i=1}^{k+1}s_i^{P_i}=prod_{i=1}^{k+1}s_i(sum_{sum_{i=1}^{k+1}P_i=k-1}prod_{i=1}^{k+1}s_i^{P_i})
]
考虑后一个式子的组合意义,相当于一个长度为(k-1)的序列,每个位置可以填入(1...n)的任何数,所以
[sum_{sum_{i=1}^{k+1}P_i=k-1}prod_{i=1}^{k+1}s_i^{P_i}=n^{k-1}
]
OI-wiki上关于这个结论的证明
所以这道题的需要计算的答案为
[sum_{将原树分成k+1个连通块}n^{k-1}prod_{i=1}^{k+1}s_i=n^{k-1}(sum_{将原树分成k+1个连通块}prod_{i=1}^{k+1}s_i)
]
考虑后一个式子的组合意义,相当于将原树去掉(k)条边然后在每个连通块中选一个点的方案数,而这可以用dp计算,具体细节见代码。
#include <bits/stdc++.h>
#define pb(x) emplace_back(x)
using namespace std;
using ll=long long ;
const int N=50005;
const ll M=998244353;
int n,k,sz[N],b[N];
ll f[N][102][2],g[105][2];
void MOD(ll&x){x%=M;}
ll pm(ll x,ll b){x%=M;ll res=1;while(b){if(b&1)res=res*x%M;x=x*x%M;b>>=1;}return res;}
vector<int> e[N];
//f[u][i][0/1] 表示u的子树内删了i条边,u所在的联通块没选/选了点的方案数
void dfs(int u,int fa){
sz[u]=1;
f[u][0][1]=f[u][0][0]=1;
for(auto v:e[u])if(v!=fa){
dfs(v,u);
memset(g,0,sizeof(g));
for(int i=0;i<sz[u]&&i<=k;i++){//此时最多sz[u]-1条边
for(int j=0;j<sz[v]&&i+j<=k;j++){//最多sz[v]-1条边
//不删(u,v)这条边
g[i+j][0]+=f[u][i][0]*f[v][j][0]%M;
MOD(g[i+j][0]);
g[i+j][1]+=(f[u][i][1]*f[v][j][0]%M+f[u][i][0]*f[v][j][1]%M)%M;
MOD(g[i+j][1]);
if(i+j<k){
//删掉(u,v)这条边,此情况下v所在的联通块必须已经选点
g[i+j+1][0]+=f[u][i][0]*f[v][j][1]%M;
MOD(g[i+j+1][0]);
g[i+j+1][1]+=f[u][i][1]*f[v][j][1]%M;
MOD(g[i+j+1][1]);
}
}
}
memcpy(f[u],g,sizeof(g));
sz[u]+=sz[v];
}
}
void f1(){
cin>>n>>k;
for(int i=1;i<n;i++){
int x,y;cin>>x>>y;
e[x].pb(y);e[y].pb(x);
}
dfs(1,0);
ll ans=pm(n,k-1)*f[1][k][1]%M;
cout<<ans;
}
int main(){
f1();
return 0;
}