考虑树形dp
dp[ u ][ 0 ] 表示 u 这棵子树处理完, 不能向上延伸的方案数。
dp[ u ][ 1 ] 表示 u 这棵子树处理完, 必须向上延伸的方案数。
dp[ u ][ 2 ] 表示 u 这棵子树处理完, 可以向上延伸的方案数。
然后转移的时候细心一点就好了。
#include<bits/stdc++.h> using namespace std; const int N = (int)2e5 + 7; const int mod = 998244353; int n; vector<int> G[N]; int dp[N][3]; /** 0: can not up 1: must up 2: can up **/ void dfs(int u) { if((int)G[u].size() == 0) { dp[u][0] = 0; dp[u][1] = 0; dp[u][2] = 1; return; } for(auto &v : G[u]) { dfs(v); } int f[] = {1, 0, 0}; int g[] = {0, 0, 0}; for(auto &v : G[u]) { memcpy(g, f, sizeof(g)); f[0] = 1LL * g[0] * (dp[v][0] + dp[v][2]) % mod; f[1] = (1LL * g[1] * (dp[v][0] + dp[v][2]) % mod + 1LL * g[0] * (dp[v][1] + dp[v][2]) % mod) % mod; f[2] = (1LL * g[1] * (dp[v][1] + dp[v][2]) % mod + 1LL * g[2] * (dp[v][0] + dp[v][2]) % mod + 1LL * g[2] * dp[v][2] % mod + 1LL * g[2] * dp[v][1] % mod) % mod; } dp[u][0] = f[0]; dp[u][1] = f[1]; dp[u][2] = f[2]; } int main() { scanf("%d", &n); for(int i = 2; i <= n; i++) { int pa; scanf("%d", &pa); G[pa].push_back(i); } dfs(1); printf("%d ", (dp[1][0] + dp[1][2]) % mod); return 0; } /** **/