参考了官方题解的做法。
第 3 个限制本质要求是什么?
或者说,什么样的树可以满足第 3 个限制?
- 叶向树
- 根向树
- 一棵根向树 + 一棵叶向树 + 一条边 拼起来的树
为第三种情况画了一幅直观的图:
用 (f_i) 表示 无标号、有根、边同向(即都为叶向或都为根向,但此时只算一种)、( ext{deg}(root) le 2)、最长链为 (i) 的树的数量。
转移式为:
第一个黑色的部分表示的是 ( ext{deg}(root) = 1) 的情况,这时候直接取绝于第二层子树的形态。
第二个蓝色的部分表示的是 ( ext{deg}(root) = 2),且两边子树的深度不相等的情况。显然有一边必须是深度 (=i-1),另一边深度 (<i-1),直接乘起来就好了。
第三个红色的部分表示的是 ( ext{deg}(root) = 2),且两边子树深度均为 (i - 1) 的情况。注意到儿子不考虑顺序,所以对于方案 ((a, b)) 和 ((b, a)) 应当被认为是相同的,那么答案就是 (inom{f_{i - 1}}{2} + f_{i - 1} = frac{f_{i - 1}(f_{i - 1} + 1)}{2} = inom{f_{i - 1} + 1}{2}),也就是要么两边方案不同,有 (inom{f_{i - 1}}{2}) 种;要么两边方案相同,有 (f_{i - 1}) 种。
可以利用前缀和,在 (O(n)) 的时间内预处理出 (f_i)。
用 (f'_i) 表示除了强制 ( ext{deg}(root) = 2) 外,其他限制和上面一样 的方案数。显然就是去掉了第一项,可以通过差分得到:(f_i' = f_i - f_{i - 1})。
有了 (f_i) 和 (f'_i) 后,我们考虑回答询问,先计算前两种情况(根向/叶向树),注意 ( ext{deg}(root)) 可以为 (1/2/3)。
本质上也就是分类讨论 (root) 的各个子树的情况,推导和上面 (f) 的递推式是相似的。
至于 ( imes 2) 是因为,当前 (f) 没有对边定向,而我们必须钦定根向 / 叶向。
而 (-1) 是因为当树的形态为一条链时,两者无区别。
接下来计算第三种情况,我们枚举两边的树的深度 (i) 和 (n-i-1),得到:
红边为“关键边”,而黄边不能为“关键边”(为了防止重复计数),而 (i) 指的是 ( ext{A}) 部分的深度。所以发现 ( ext{A}) 部分是一棵 ( ext{deg}(root) le 2) 的非链根向树(如果为链的话,整棵树退化为一棵叶向树),方案数为 (f_i - 1);而 ( ext{B}) 部分是一棵 ( ext{deg}(root) = 2) 的叶向树(因为红边一定是最靠右的),方案数为 (f_{n - i - 1}')。根据乘法原理乘起来即可。
两部分加起来即是答案。
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 5, MOD = 998244353;
typedef long long ll;
int n, ans, f[N], pref[N];
int C2(ll x) { return (ll)x * (x - 1) % MOD * 499122177 % MOD; }
int C3(ll x) { return (ll)x * (x - 1) % MOD * (x - 2) % MOD * 166374059 % MOD; }
int main()
{
ios::sync_with_stdio(false);
cin >> n;
f[0] = 1; f[1] = 2; pref[0] = 1; pref[1] = 3;
for(int i = 2; i <= n; i++)
{
f[i] = (f[i - 1] + (ll)f[i - 1] * pref[i - 2] % MOD + C2(f[i - 1] + 1)) % MOD;
pref[i] = (pref[i - 1] + f[i]) % MOD;
}
ans = (ans + 2ll * f[n] - 1 + MOD) % MOD;
if(n >= 1) ans = (ans + 2ll * C3(f[n - 1] + 2)) % MOD;
if(n >= 2) ans = (ans + 2ll * ((ll)f[n - 1] * C2(pref[n - 2] + 1) + (ll)pref[n - 2] * C2(f[n - 1] + 1))) % MOD;
for(int i = 0; i < n; i++) ans = (ans + (ll)(f[i] + MOD - 1) % MOD * ((n - i - 1 >= 1) ? (f[n - i - 1] - f[n - i - 2] + MOD) % MOD : 0) % MOD) % MOD;
cout << ans << endl;
return 0;
}