前言
菜就多练练。
题目
讲解
直接考虑 dp。
令 \(dp_{x,i}\) 表示 \(x\) 子树已经考虑完了,目前(下端在子树内)没处理的链的最深的上端深度为 \(i\) 。
对于其儿子 \(v\),有转移:
\[dp_{x,i}\leftarrow \sum_{j=0}^{depth_x} dp_{x,i}\times dp_{v,j}+\sum_{j=0}^{i} dp_{x,i}\times dp_{v,j}+\sum_{j=0}^{i-1}dp_{v,i}\times dp_{x,j}
\]
显然可以前缀和优化:
\[dp_{x,i}\leftarrow dp_{x,i}\times (pre_{v,depth_x}+pre_{v,i})+pre_{x,i-1}\times dp_{v,i}
\]
某个大佬说过(忘了是谁了),树上跟深度有关的东西都可以线段树合并,这道题也不例外。
对于 \(pre_{v,depth_x}\) 我们可以提前查好,而其它的可以在合并时先递归左子树,后递归右子树,在合并过程中可以顺便统计前缀和。
时空复杂度 \(O(n\log_2n)\)。
代码
洛谷rk2
//12252024832524
#include <bits/stdc++.h>
#define TT template<typename T>
using namespace std;
typedef long long LL;
const int MAXN = 500005;
const int MOD = 998244353;
int n;
LL Read()
{
LL x = 0,f = 1;char c = getchar();
while(c > '9' || c < '0'){if(c == '-')f = -1;c = getchar();}
while(c >= '0' && c <= '9'){x = (x*10) + (c^48);c = getchar();}
return x * f;
}
TT void Put1(T x)
{
if(x > 9) Put1(x/10);
putchar(x%10^48);
}
TT void Put(T x,char c = -1)
{
if(x < 0) putchar('-'),x = -x;
Put1(x); if(c >= 0) putchar(c);
}
TT T Max(T x,T y){return x > y ? x : y;}
TT T Min(T x,T y){return x < y ? x : y;}
TT T Abs(T x){return x < 0 ? -x : x;}
int head[MAXN],etot;
struct edge{
int v,nxt;
}e[MAXN<<1];
void Add_Edge(int u,int v){
e[++etot] = edge{v,head[u]};
head[u] = etot;
}
void Add_Double_Edge(int u,int v) {
Add_Edge(u,v);
Add_Edge(v,u);
}
#define lc (t[x].ch[0])
#define rc (t[x].ch[1])
int rt[MAXN],tot;
struct node{
int ch[2],s,mul;
}t[MAXN*40];
void calc(int x,int val){
if(!x) return;
t[x].mul = 1ll * t[x].mul * val % MOD;
t[x].s = 1ll * t[x].s * val % MOD;
}
void down(int x){
if(t[x].mul != 1){
calc(lc,t[x].mul);
calc(rc,t[x].mul);
t[x].mul = 1;
}
}
void Add(int &x,int l,int r,int pos){
x = ++tot; t[x].s = t[x].mul = 1;
if(l == r) return;
int mid = (l+r) >> 1; down(x);
if(pos <= mid) Add(lc,l,mid,pos);
else Add(rc,mid+1,r,pos);
}
/*
s1 <- s1 + dp[y][i]
dp'[x][i] <- dp[x][i]*s1 + dp[y][i]*s2
s2 <- s2 + dp[x][i]
*/
int mge(int x,int y,int l,int r,int &s1,int &s2){//新写法 get!
if(!x && !y) return 0;
if(!x || !y){
if(y){
s1 = (s1 + t[y].s) % MOD;
calc(y,s2);
return y;
}
else{
s2 = (s2 + t[x].s) % MOD;
calc(x,s1);
return x;
}
}
if(l == r){
s1 = (s1 + t[y].s) % MOD;
int tmp2 = t[x].s;
t[x].s = (1ll * t[x].s * s1 + 1ll * t[y].s * s2) % MOD;
s2 = (s2 + tmp2) % MOD;
return x;
}
down(x); down(y);
int mid = (l+r) >> 1;
lc = mge(lc,t[y].ch[0],l,mid,s1,s2);
rc = mge(rc,t[y].ch[1],mid+1,r,s1,s2);
t[x].s = (t[lc].s + t[rc].s) % MOD;
return x;
}
int Query(int x,int l,int r,int qr){
if(!x) return 0;
if(r <= qr) return t[x].s;
down(x);
int mid = (l+r) >> 1,ret = 0;
if(mid+1 <= qr) ret += Query(rc,mid+1,r,qr);
ret += Query(lc,l,mid,qr);
return ret % MOD;
}
int d[MAXN];
void dfs1(int x,int fa){
d[x] = d[fa] + 1;
for(int i = head[x],v; i ;i = e[i].nxt)
if((v = e[i].v) ^ fa) dfs1(v,x);
}
int MAX[MAXN];
void dfs2(int x){
Add(rt[x],0,n,MAX[x]);
for(int i = head[x],v; i ;i = e[i].nxt)
if(d[v = e[i].v] > d[x]){
dfs2(v);
int s1 = Query(rt[v],0,n,d[x]),s2 = 0;
rt[x] = mge(rt[x],rt[v],0,n,s1,s2);
}
}
int main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
n = Read();
for(int i = 1;i < n;++ i) Add_Double_Edge(Read(),Read());
dfs1(1,0);
for(int m = Read(); m ;-- m){
int u = Read(),v = Read();
MAX[v] = Max(MAX[v],d[u]);
}
dfs2(1);
Put(Query(rt[1],0,n,0),'\n');
return 0;
}