题面
题解
CF1327F AND Segments + 整体 dp。
首先预处理 \(\mathrm{pre}_i\),表示点 \(i\) 向上最深的 \(f(e) = 1\) 的边的深度的最小值(即该深度的下界)。
设 \(f_{i, j}\) 表示当前在点 \(i\),向上最深的 \(f(e) = 1\) 的边的深度为 \(j\) 的方案数。
枚举点 \(i\) 和每个儿子之间的边是否设成 \(1\),有:
\[
f_{i, j} = [\mathrm{pre}_i < j \leq \mathrm{dep}_i] \prod_{s \in \mathrm{son}(i)} \left(f_{s, j} + f_{s, \mathrm{dep}_s}\right)
\]
其中 \(f_{s, \mathrm{dep}_s}\) 表示将 \(s\) 和它父亲之间的边设成 \(1\) 的方案数。
考虑用线段树合并维护,那么只需要维护区间加、区间乘和区间赋值,维护一个标记 \((a, b)\) 使得 \(f \stackrel{(a, b)}{\longrightarrow} af + b\)。
手推一下标记如何合并即可。
代码
#include <cstdio>
#include <algorithm>
#include <vector>
// Reads one (optionally negative) decimal integer from stdin, skipping any
// leading characters that are neither '-' nor a digit.
// Returns 0 if end of input is reached before any digit is seen.
inline int read()
{
    int data = 0, w = 1;
    // int, not char: getchar() returns EOF (a negative int) at end of input,
    // which a plain char cannot represent reliably (e.g. when char is unsigned).
    int ch = getchar();
    // Skip separators; stopping at EOF prevents an infinite loop on
    // truncated input or trailing non-numeric text.
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    if (ch == '-') w = -1, ch = getchar();
    while (ch >= '0' && ch <= '9') data = data * 10 + (ch ^ 48), ch = getchar();
    return data * w;
}
// Array capacity (vertex bound plus slack) and the answer modulus.
constexpr int N = 500000 + 10;
constexpr int Mod = 998244353;

// Forward-star adjacency list: head[v] is the last edge added out of v and
// e[k].next links to the previous edge with the same tail.
struct edge
{
    int next, to;
};
edge e[N << 1];

int n, m;           // vertex count, constraint count
int head[N], e_num; // adjacency heads and number of stored edges
int dep[N];         // depth of each vertex (filled by dfs)
int pre[N];         // deepest constraint top-depth affecting each vertex
int rt[N];          // segment-tree root holding each vertex's dp values
// Appends the directed edge from -> to at the front of from's adjacency list.
inline void add_edge(int from, int to)
{
    ++e_num;
    e[e_num].next = head[from];
    e[e_num].to = to;
    head[from] = e_num;
}
// Fills dep[] for every vertex in the subtree rooted at x, where fa is x's
// parent. dep[0] is never written (stays 0), so a root called with fa = 0
// receives depth 1.
void dfs(int x, int fa)
{
    dep[x] = dep[fa] + 1;
    for (int i = head[x]; i != 0; i = e[i].next)
    {
        const int y = e[i].to;
        if (y == fa) continue;
        dfs(y, x);
    }
}
// Segment-tree node storage: ls/rs hold child indices (0 means "no child"),
// cur is the highest index ever allocated, pool[1..top] are recycled indices.
int ls[N << 6], rs[N << 6];
int cur;
int pool[N << 6], top;

// Lazy affine tag (a, b): maps a value f to a * f + b (mod Mod).
struct node
{
    int a, b;
};
node t[N << 6];
node I = {1, 0}; // identity tag
// Tag composition: the result applies x first and then y, i.e.
// f -> y.a * (x.a * f + x.b) + y.b  (mod Mod).
inline node operator * (const node &x, const node &y)
{
    const long long scale = 1ll * x.a * y.a % Mod;
    const long long shift = (1ll * x.b * y.a + y.b) % Mod;
    return node{(int) scale, (int) shift};
}
// Tag equality; yields 1 when both components match, 0 otherwise.
inline int operator == (const node &x, const node &y) { return (x.a == y.a) & (x.b == y.b); }
// Allocates a childless node carrying tag v. A recycled index from pool is
// preferred; otherwise a fresh index is taken.
int newNode(const node &v)
{
    int id;
    if (top) id = pool[top--];
    else id = ++cur;
    ls[id] = 0;
    rs[id] = 0;
    t[id] = v;
    return id;
}
void pushdown(int x)
{
if (t[x] == I) return;
if (!ls[x]) ls[x] = newNode(t[x]); else t[ls[x]] = t[ls[x]] * t[x];
if (!rs[x]) rs[x] = newNode(t[x]); else t[rs[x]] = t[rs[x]] * t[x];
t[x] = I;
}
// Composes tag v onto every position in [ql, qr] of the tree rooted at x,
// where x covers [l, r]. An empty range (ql > qr) is a no-op.
void Modify(int x, int ql, int qr, const node &v, int l = 1, int r = n)
{
    if (ql > qr) return;
    if (ql <= l && r <= qr)
    {
        t[x] = t[x] * v;
        return;
    }
    pushdown(x);
    const int mid = (l + r) >> 1;
    if (ql <= mid) Modify(ls[x], ql, qr, v, l, mid);
    if (qr > mid) Modify(rs[x], ql, qr, v, mid + 1, r);
}
// Returns the value stored at position p of the tree rooted at x, which
// covers [l, r]. It walks down one path, flushing tags along the way; the
// leaf value is the b component of its tag.
int Query(int x, int p, int l = 1, int r = n)
{
    while (l != r)
    {
        pushdown(x);
        const int mid = (l + r) >> 1;
        if (p <= mid)
        {
            x = ls[x];
            r = mid;
        }
        else
        {
            x = rs[x];
            l = mid + 1;
        }
    }
    return t[x].b;
}
// Merges the tree rooted at y into the tree rooted at x by pointwise
// multiplication of their values, and returns the resulting root.
// Both parameters are references to the callers' root/child slots. When x is
// childless the two slots are swapped, so afterwards y's slot may hold nodes
// that used to belong to x. The caller discards y's tree afterwards (dp()
// calls clear() on it).
int merge(int &x, int &y)
{
// A childless node represents a constant range; put the structured side
// (if any) in x so that the constant side ends up in y.
if (!ls[x] && !rs[x]) std::swap(x, y);
// y is constant over its range: every position under it reads t[y].b (its tag
// applied to 0), so scale x's whole range by that value.
if (!ls[y] && !rs[y])
return t[x] = t[x] * (node) {t[y].b, 0}, x;
// Both sides have children: flush pending tags and merge each half.
pushdown(x), pushdown(y);
ls[x] = merge(ls[x], ls[y]);
rs[x] = merge(rs[x], rs[y]);
return x;
}
// Recycles every node of the tree rooted at x into pool (post-order) and
// zeroes the slot that referenced it, including each node's child links.
void clear(int &x)
{
    if (x == 0) return;
    clear(ls[x]);
    clear(rs[x]);
    pool[++top] = x;
    x = 0;
}
// Computes the dp values of x into the tree rt[x]; position j holds f_{x,j}.
// The tree starts as 1 on (pre[x], dep[x]] and 0 elsewhere. Each child's
// tree is then multiplied in pointwise (and recycled). Finally, for
// non-root x, the option "edge from x to its parent is 1"
// (value f_{x,dep[x]}) is added to positions [1, dep[x] - 1].
void dp(int x, int fa)
{
    // Constraints covering the parent also cover x.
    pre[x] = std::max(pre[x], pre[fa]);
    rt[x] = ++cur;
    t[rt[x]] = node{0, 0};
    Modify(rt[x], pre[x] + 1, dep[x], node{0, 1});
    for (int i = head[x]; i; i = e[i].next)
    {
        const int son = e[i].to;
        if (son == fa) continue;
        dp(son, x);
        rt[x] = merge(rt[x], rt[son]);
        clear(rt[son]);
    }
    if (x != 1) Modify(rt[x], 1, dep[x] - 1, node{1, Query(rt[x], dep[x])});
}
int main()
{
n = read();
for (int i = 1, a, b; i < n; i++)
a = read(), b = read(), add_edge(a, b), add_edge(b, a);
m = read(), dfs(1, 0);
for (int i = 1, x, y; i <= m; i++)
x = read(), y = read(), pre[y] = std::max(pre[y], dep[x]);
dp(1, 0), printf("%d
", Query(rt[1], 1));
return 0;
}