首先一开始有个误区,对于排列u的时候,把他的父节点和子节点放进来一起考虑了。
但是其实父节点到父节点的时候考虑,就是用u变成两种排列中另一种排列来考虑。
然后我们考虑一下合并子树:这里我们假设当前已经合并了的序列A大小为n,对于要合并的子树序列B为m
钦定根u在序列A的第i个位置,根v在序列B的第j个位置。
显然u和v只有u < v和u > v的情况,这里我们只需要算一种就可以了。最后方案 * 2即可。
我们计算u > v。
现在我们钦定ai即A的根u在排列后的序列的第i + j个位置。
枚举ai的原位置来转移,即a 在原序列第i个位置。
那么对于序列的排列情况,我们以ai为分解划分为:
前半部分即为A{a1...ai - 1} II B{b1.b2...bj}.
考虑求这个序列的组合方案:因为A,B序列的内部情况我们已经处理出来了,所以我们可以当成无差别的元素来计算方案即C(i - 1 + j,i - 1),即a序列插入位置中,然后乘上两序列的方案数dp.
对于后半部分A{ai + 1,ai + 2...an} II B{bj + 1,b2....bm},方案同理前面C(n - i +m - j,n - i)
因为我们计算的是u > v,那么对于u = ai,v = b1 ~ bj的情况都满足这个限制条件.
所以对于每个i方案为C(i - 1 + j,i - 1) * C(n - i +m - j,n - i) * dpA[i] * (dpB[1] + dpB[2] + ... dpB[j])
显然对于dpB可以前缀和预处理一下,然后我们就可以枚举i,j来解决这个问题。
// Author: levil #include<bits/stdc++.h> using namespace std; typedef long long LL; typedef unsigned long long ULL; typedef long double ld; typedef pair<int,int> pii; const int N = 3e3 + 5; const int M = 1e4 + 5; const LL Mod = 998244353; #define INF 1e9 #define dbg(ax) cout << "now this num is " << ax << endl; inline long long ADD(long long x,long long y) { if(x + y < 0) return ((x + y) % Mod + Mod) % Mod; return (x + y) % Mod; } inline long long MUL(long long x,long long y) { if(x * y < 0) return ((x * y) % Mod + Mod) % Mod; return x * y % Mod; } inline long long DEC(long long x,long long y) { if(x - y < 0) return (x - y + Mod) % Mod; return (x - y) % Mod; } LL fac[N],dp[N][N],inv[N],pre[N],f[N];//dp[i][j] - i在序列第j个位置 vector<int> G[N]; int n,sz[N]; LL quick_mi(LL a,LL b) { LL re = 1; while(b) { if(b & 1) re = re * a % Mod; a = a * a % Mod; b >>= 1; } return re; } void init() { fac[0] = 1; for(int i = 1;i < N;++i) fac[i] = fac[i - 1] * i % Mod; inv[N - 1] = quick_mi(fac[N - 1],Mod - 2) % Mod; for(int i = N - 2;i >= 0;--i) inv[i] = inv[i + 1] * (i + 1) % Mod; } LL C(int n,int m) { return fac[n] * inv[m] % Mod * inv[n - m] % Mod; } void dfs(int u,int fa) { sz[u] = 1; dp[u][1] = 1; for(auto v : G[u]) { if(v == fa) continue; dfs(v,u); memset(pre,0,sizeof(pre)); memset(f,0,sizeof(f)); for(int i = 1;i <= sz[v];++i) pre[i] = ADD(pre[i - 1],dp[v][sz[v] - i + 1]); for(int i = 1;i <= sz[u];++i) { for(int j = 1;j <= sz[v];++j) { LL tmp1 = C(i - 1 + j,i - 1),tmp2 = C(sz[u] - i + sz[v] - j,sz[u] - i); f[i + j] = ADD(f[i + j],MUL(MUL(MUL(C(i - 1 + j,i - 1),C(sz[u] - i + sz[v] - j,sz[u] - i)),dp[u][i]),pre[j])); } } sz[u] += sz[v]; for(int i = 1;i <= sz[u];++i) dp[u][i] = f[i]; } } void solve() { init(); scanf("%d",&n); for(int i = 1;i < n;++i) { int u,v;scanf("%d %d",&u,&v); G[u].push_back(v); G[v].push_back(u); } dfs(1,0); LL ans = 0; for(int i = 1;i <= sz[1];++i) ans = ADD(ans,dp[1][i]); printf("%lld\n",MUL(2,ans)); } int main() { //int _; //for(scanf("%d",&_);_;_--) { solve(); //} //system("pause"); return 0; }