// 首先可以观察到：只有当每个连通块都存在完美匹配时（孤立点无需匹配），匹配方案才是唯一的。
// dp[i][0]、dp[i][1]、dp[i][2] 分别表示以 i 为根的子树处于以下三种状态时的方案数：
// 0：不向上连边，且 i 所在部分已完成；1：向上连边，且 i 已在子树内完成匹配；2：向上连边，但 i 尚未匹配（留给父结点匹配）。
// Tree DP, computed modulo 998244353.
//
// Input (stdin): n, followed by n - 1 edges "u v" of a tree on vertices 1..n.
// Output (stdout): dp[1][0], i.e. the count for the whole tree rooted at 1.
//
// State meaning for the subtree rooted at u (taken from the transitions below):
//   dp[u][0] : u has no edge to its parent and is finished -- either every
//              child edge is dropped as well, or u is paired with one child.
//   dp[u][1] : u keeps the edge to its parent and is already paired with a child.
//   dp[u][2] : u keeps the edge to its parent and is still unpaired
//              (left for the parent to pair with).
//
// Both tree traversals are iterative: with up to ~3e5 vertices a path-shaped
// tree would otherwise need one recursion frame per vertex.
#include <cstdio>
#include <utility>
#include <vector>

using LL = long long;

constexpr int N = 300007;      // exclusive upper bound for vertex ids
constexpr int mod = 998244353;

int n;
std::vector<int> G[N];  // adjacency lists; after go() they hold children only
LL dp[N][3];

// Union-find parent links, used only while reading the edges to reject
// input that is not a tree.
static int dsu_parent[N];

// Returns the representative of x's set, shortening the path as it walks.
static int find_root(int x) {
    while (dsu_parent[x] != x) {
        dsu_parent[x] = dsu_parent[dsu_parent[x]];
        x = dsu_parent[x];
    }
    return x;
}

// Roots the tree at u (whose parent is fa; pass 0 for the root) by deleting
// the parent entry from every adjacency list reachable from u, so that G[x]
// afterwards lists exactly the children of x.
void go(int u, int fa) {
    std::vector<std::pair<int, int>> pending;  // (vertex, its parent)
    pending.emplace_back(u, fa);
    while (!pending.empty()) {
        const int node = pending.back().first;
        const int parent = pending.back().second;
        pending.pop_back();

        std::vector<int>& adj = G[node];
        int pos = -1;
        for (int i = 0; i < static_cast<int>(adj.size()); ++i) {
            if (adj[i] == parent) {
                pos = i;
            }
        }
        if (pos >= 0) {
            // Order of children is irrelevant, so swap-and-pop is enough.
            std::swap(adj[pos], adj.back());
            adj.pop_back();
        }
        for (int child : adj) {
            pending.emplace_back(child, node);
        }
    }
}

// Fills dp[u][*] from the already computed dp values of u's children.
static void combine_children(int u) {
    const std::vector<int>& kids = G[u];
    const int cnt = static_cast<int>(kids.size());
    if (cnt == 0) {
        dp[u][0] = 1;
        dp[u][1] = 0;
        dp[u][2] = 1;
        return;
    }

    // open_prefix[i] = product over the first i children of (dp[v][0] + dp[v][1]):
    // each such child is finished on its own, with or without its edge to u.
    std::vector<LL> open_prefix(cnt + 1);
    open_prefix[0] = 1;
    LL all_cut = 1;  // product of dp[v][0]: every child edge dropped
    for (int i = 0; i < cnt; ++i) {
        const int v = kids[i];
        const LL open = (dp[v][0] + dp[v][1]) % mod;
        open_prefix[i + 1] = open_prefix[i] * open % mod;
        all_cut = all_cut * dp[v][0] % mod;
    }

    // Pair u with child i (which must be in state 2); all other children
    // contribute (dp[v][0] + dp[v][1]). A running suffix product is used
    // instead of division because the factors may be 0 modulo `mod`.
    LL paired = 0;
    LL open_suffix = 1;
    for (int i = cnt - 1; i >= 0; --i) {
        const int v = kids[i];
        const LL ways = dp[v][2] * open_prefix[i] % mod * open_suffix % mod;
        paired = (paired + ways) % mod;
        open_suffix = open_suffix * ((dp[v][0] + dp[v][1]) % mod) % mod;
    }

    dp[u][0] = (all_cut + paired) % mod;
    dp[u][1] = paired;
    dp[u][2] = open_prefix[cnt];
}

// Computes dp for every vertex in the subtree of u. Requires go() to have
// run first so that G holds child lists.
void dfs(int u) {
    // Breadth-first order puts every parent before its children; walking it
    // backwards therefore handles children before their parent.
    std::vector<int> order;
    order.push_back(u);
    for (std::size_t head = 0; head < order.size(); ++head) {
        const std::vector<int>& kids = G[order[head]];
        for (int child : kids) {
            order.push_back(child);
        }
    }
    for (std::size_t i = order.size(); i-- > 0;) {
        combine_children(order[i]);
    }
}

int main() {
    if (std::scanf("%d", &n) != 1 || n < 1 || n >= N) {
        std::fprintf(stderr, "invalid vertex count\n");
        return 1;
    }
    for (int i = 1; i <= n; ++i) {
        dsu_parent[i] = i;
    }
    for (int i = 1; i < n; ++i) {
        int u = 0;
        int v = 0;
        if (std::scanf("%d%d", &u, &v) != 2 || u < 1 || u > n || v < 1 || v > n) {
            std::fprintf(stderr, "invalid edge\n");
            return 1;
        }
        // n - 1 edges that never close a cycle form a tree; anything else
        // would make the traversals below ill-defined.
        const int root_u = find_root(u);
        const int root_v = find_root(v);
        if (root_u == root_v) {
            std::fprintf(stderr, "input is not a tree\n");
            return 1;
        }
        dsu_parent[root_u] = root_v;
        G[u].push_back(v);
        G[v].push_back(u);
    }

    go(1, 0);
    dfs(1);
    // Output format kept exactly as before: the value followed by one space.
    std::printf("%lld ", dp[1][0]);
    return 0;
}