(考场时只想到暴力容斥,24分滚了)
题目:https://loj.ac/problem/3340#submit_code
DP方程怎么来就不写了,本文重点分析如何用线段树合并及正确性
DP方程为:令dp[u][h]表示u的子树中上端点已经处理好了,下段点在子树中的,上段点最深为h的方案数,特别的,如果h==0,则代表已经处理好了
每当遇到一个儿子v时都更新一遍dp数组
dp’[u][h] = dp[v][h] * sum[u][h-1] + dp[u][h] * (sum[v][h] + sum[v][dep[u]);
考虑线段树合并(以下子树)
我们可以便遍历边更新各式的值
我们开n颗线段树,下标 i 维护dp[u][i],维护区间和,维护区间乘积
不妨令S1 = sum[u][h-1],S2 = (sum[v][h] + sum[v][dep[u][)
在线段树合并的同时更新S1,S2
我们只要讨论4种情况
(I) (!U && !V)很显然这种情况可以直接返回
(II) (!U || !V)
即合并中有其中一颗树没有子树了,反映在DP方程其实就是这一整块都没值了,都不合法
这就意味着U或者V(没有值的那个)的sum在这段区间不会再变了,变成常量,于是就可以愉快地区间乘法了
(III)(l == r)
这就意味遍历到叶子节点,这是最简单的,直接套公式就可以了
(IV)
otherwise
即既有左子树又有右子树
那我们只需要遍历左子树完遍历右子树即可
为什么是对的呢
因为我们处理左子树的答案,S1,S2已经被左子树的前缀和更新了,在处理右子树的答案时,同时考虑了左边的贡献
类似于CDQ分治的思想
然后就做完了(本蒟觉得还是一道很妙的题)
代码如下
/*命运*/ #include<cstdio> #include<iostream> #include<cstring> #include<algorithm> using namespace std; #define ll long long #define mod 998244353 const int maxn = 5e5 + 10; int Add(int x,int y){ x += y; return (x >= mod)?x - mod:x; } int rt[maxn],h[maxn]; struct SegmentTree{ int lc,rc; ll lzy; ll sum; #define lc(p) t[p].lc #define rc(p) t[p].rc #define sum(p) t[p].sum #define lzy(p) t[p].lzy }t[maxn<<5]; int cnt = 0; void pushup(int p){ sum(p) = (sum(lc(p)) + sum(rc(p))) % mod; } void pushdown(int p){ if(lzy(p) == 1) return; lzy(lc(p)) = lzy(lc(p)) * lzy(p) % mod; sum(lc(p)) = lzy(p) * sum(lc(p)) % mod; lzy(rc(p)) = lzy(rc(p)) * lzy(p) % mod; sum(rc(p)) = lzy(p) * sum(rc(p)) % mod; lzy(p) = 1; } void Ins(int &p,int l,int r,int pos,int v){ if(!p) p = ++cnt; t[p].lzy = t[p].sum = v; if(l == r){ return; } int mid = (l + r) >> 1; if(pos <= mid) Ins(lc(p),l,mid,pos,v); else Ins(rc(p),mid+1,r,pos,v); //pushup(p); } ll query(int p,int l,int r,int a,int b){ if(a <= l && b >= r){ return sum(p); } pushdown(p); ll ans = 0; int mid = (l + r) >> 1; if(a <= mid) ans = Add(ans,query(lc(p),l,mid,a,b)); if(b > mid) ans = Add(ans,query(rc(p),mid+1,r,a,b)); return ans; } int merge(int u,int v,int l,int r,ll &S1,ll &S2){ if(!u && !v) return 0; if(!u || !v){ if(!u){ S2 = Add(S2,sum(v)); sum(v) = sum(v) * S1 % mod; lzy(v) = lzy(v) * S1 % mod; return v; } S1 = Add(S1,sum(u)); sum(u) = sum(u) * S2 % mod,lzy(u) = lzy(u) * S2 % mod; return u; } if(l == r){ S2 = Add(S2,sum(v)); ll idu = sum(u); sum(u) = Add(S1 * sum(v) % mod,S2 * sum(u) % mod); S1 = Add(S1,idu); return u; } pushdown(u),pushdown(v); int mid = (l + r) >> 1; lc(u) = merge(lc(u),lc(v),l,mid,S1,S2); rc(u) = merge(rc(u),rc(v),mid+1,r,S1,S2); pushup(u); return u; } int read(){ char c = getchar(); int x = 0; while(c < '0' || c > '9') c = getchar(); while(c >= '0' && c <= '9') x = x * 10 + c - 48,c = getchar(); return x; } struct Edge{ int nxt,point; }edge[maxn*2]; int tot = 0; int dep[maxn]; int n; int head[maxn]; void add(int x,int y){ edge[++tot].nxt = head[x]; edge[tot].point = y; head[x] = tot; } void Dfs(int x,int fa){ dep[x] = dep[fa] + 1; for(int i = head[x]; i ; i = edge[i].nxt){ int y = edge[i].point; if(y == fa) continue; Dfs(y,x); } } void TreeDP(int u,int fa){ Ins(rt[u],0,n,h[u],1); for(int i = head[u]; i ; i = edge[i].nxt){ int v = edge[i].point; if(v == fa) continue; TreeDP(v,u); ll S1 = 0,S2 = query(rt[v],0,n,0,dep[u]); rt[u] = merge(rt[u],rt[v],0,n,S1,S2); } } int main() { freopen("destiny.in","r",stdin) ; freopen("destiny.out","w",stdout); n = read(); for(int i = 1; i < n; ++i){ int x = read(),y = read(); add(x,y); add(y,x); } Dfs(1,0); int m = read(); for(int i = 1; i <= m; ++i){ int u = read(),v = read(); if(!h[v]) h[v] = dep[u]; else h[v] = max(h[v],dep[u]); } TreeDP(1,0); printf("%lld ",query(rt[1],0,n,0,0)); return 0; }