https://atcoder.jp/contests/arc087/tasks/arc087_d
题解
作为3200难度的题是不是有点水。。。
考虑每条边最多会被经过多少次,假设一条边 (e) 两端的子树大小分别是 (a_e,b_e),那么它最多被经过 (2*min(a_e,b_e)) 次
所以总距离的上界即为 (sumlimits_{ein E} 2*min(a_e,b_e))
下面通过计算总距离取到这个上界的方案数来说明这个上界一定能被取到
考虑在需要保证能取到上界的情况下,对排列 (p) 有什么限制:
对于任意一条边,将其断开后形成两棵子树,则对于较小的那棵子树里的每个点 (x),必须满足 (p_x) 在较大那棵子树里面,这样当 (x) 走到 (p_x) 的时候才会经过这条边
这启发我们找到重心作为根,以重心为根的每棵子树的点数一定小于等于这棵子树外的点数
所以此时的限制就变为,找到重心当根,对于根的每个儿子的子树,这棵子树中的每个点 (x),它的 (p_x) 必须在这棵子树外
情况1
如果这棵树有两个重心,那么将连接两个重心的那条边 (e) 断开形成两棵树,那么 (x) 和 (p_x) 必须不在同一棵树中,这样才能让 (e) 取到经过次数的上界
此时方案数就是 ((dfrac{n}{2}!)^2)
情况2
如果这棵树只有一个重心,将这个重心作为根,考虑它的每个儿子的子树,子树内的点需要满足上述限制
不妨计算有 (k) 个点不满足限制的方案数:
如果一个点 (x) 不满足限制,那么就是说 (p_x) 和它在同一棵子树里
假设根的某个儿子的子树大小为 (m),那么 (m) 个点中有 (k) 个不满足限制的方案数即为 (inom{m}{k}*m^{underline k})
使用背包 (O(n^2)) 合并所有子树的方案,记共有 (k) 个点不满足限制的方案数为 (f_k),那么剩下 (n-k) 个点对于的 (p) 任选,容斥一下,那么所有点都满足限制的方案数即为
#include <bits/stdc++.h>
#define N 5005
#define pb push_back
using namespace std;
const int mod = 1000000007;
inline int qmod(int x) { return x<mod?x:x-mod; }
inline int fpow(int x, int t) { int r=1;for(;t;t>>=1,x=1ll*x*x%mod)if(t&1)r=1ll*r*x%mod;return r; }
int n, siz[N], MN, RT, fac[N], Inv[N], f[N], g[N];
inline int C(int p, int q) { return p<q?0:1ll*fac[p]*Inv[q]%mod*Inv[p-q]%mod; }
vector<int> E[N];
void findrt(int x, int fa) {
siz[x] = 1; int mx = 0;
for (auto y : E[x]) if (y != fa) {
findrt(y, x); mx = max(mx, siz[y]);
siz[x] += siz[y];
}
mx = max(mx, n-siz[x]);
if (MN > mx) MN = mx, RT = x;
}
inline int calc(int m, int i) {
return 1ll*fac[m]*Inv[m-i]%mod*C(m,i)%mod;
}
int main() {
scanf("%d", &n);
fac[0] = Inv[0] = 1;
for (int i = 1; i <= n; i++) fac[i] = 1ll*fac[i-1]*i%mod;
Inv[n] = fpow(fac[n], mod-2);
for (int i = n-1; i; i--) Inv[i] = 1ll*Inv[i+1]*(i+1)%mod;
for (int i = 1, u, v; i < n; i++) {
scanf("%d %d", &u, &v);
E[u].pb(v); E[v].pb(u);
}
MN = 0x3f3f3f3f;
findrt(1, 0); findrt(RT, 0);
f[0] = 1; int cnt = 0, flg = 0;
for (auto x : E[RT]) {
memset(g, 0, sizeof(g)); int m = siz[x];
for (int i = 0; i <= cnt; i++) for (int j = 0; j <= m; j++) {
g[i+j] = qmod(g[i+j]+1ll*f[i]*calc(m,j)%mod);
}
memcpy(f, g, sizeof(f)); cnt += m;
if (n % 2 == 0 && m == n/2) flg = 1;
}
if (flg) { printf("%lld
", 1ll*fac[n/2]*fac[n/2]%mod); return 0; }
int ans = 0;
for (int i = 0; i <= n; i++) {
f[i] = 1ll*f[i]*fac[n-i]%mod;
if (i & 1) ans = qmod(ans+mod-f[i]);
else ans = qmod(ans+f[i]);
}
printf("%d
", ans);
return 0;
}