• 「十二省联考 2019」希望(长链剖分优化dp)


    Solution

    • 题意简述:选出 \(k\) 个树上连通块,使得存在一个点 \(u\) 满足:
      1.\(u\) 在这 \(k\) 个连通块的交集之中。
      2.对于这 \(k\) 个连通块中的任意一点 \(v\),都有:\(dist(v,u)≤L\)

    1.容斥

    • 显然对于每一个连通块集合,满足条件的点 \(u\) 构成的也是一个连通块,记这个连通块为 \(S\)
    • 我们钦定根节点为1。记 \(u\) 的父亲为 \(fa[u]\) 。记所有合法方案中,\(S\) 包含 \(u\) 的方案数为 \(A(u)\),包含边 \((u,fa[u])\) 的方案数为 \(B(u)\),那么 \(\sum_{i=1}^{n}A(i)-\sum_{i=2}^{n}B(i)\) 就是答案。
    • 考虑这样算为什么是对的:
    • 对于 \(S\) 中的每个点 \(u\)\(A(u)\) 中都包含此方案。对于 \(S\) 中的每条边 \((v,fa[v])\)\(B(v)\) 中都包含此方案。众所周知,树上连通块有一个绝妙的性质:点数 \(=\) 边数 \(+1\)。那么 \(u\) 的个数会比 \(v\) 的个数多 \(1\),也就是说此合法方案正好被统计一次。由此可得,每个合法方案都正好被统计一次,那么这样算就是对的。

    2.朴素dp

    • \(f[u][i]\) 表示满足以下条件的连通块 \(T\) 的个数 \(+1\)
      1.\(T\) 中必须包含 \(u\)
      2.\(T\) 中只能包含 \(u\) 子树的点。
      3.\(T\) 中任意一点 \(v\) 均满足 \(dist(u,v)\le i\)
      初值:\(f[u][0]=2\)
      \(dp\)式:\(f[u][i]=\prod_{v∈child[u]}f[v][i-1]+1\)

    • \(g[u][i]\) 表示满足以下条件的连通块 \(T\) 的个数:
      1.\(T\) 中必须包含 \(u\)
      2.\(T\) 中只能包含 \(u\)\(u\) 子树的点。
      3.\(T\) 中任意一点 \(v\) 均满足 \(dist(u,v)\le i\)
      初值:\(g[u][0]=1\)
      \(dp\)式:\(g[v][i]=g[u][i-1]*\prod_{x∈child[u]且x!=v}f[x][i-2]+1\)
      特殊地,\(i=1\) 时不用乘 \(\prod_{x∈child[u]且x!=v}f[x][i-2]\)

    • \(A(u)=[(f[u][L]-1)*(g[u][L])]^k\)

    • \(B(u)=[(f[u][L-1]-1)*(g[u][L]-1)]^k\)

    3.长链剖分优化\(f\)

    • 考虑如果 \(v\)\(u\) 的轻儿子怎么转移:
    • \(mx[v]\) 为点 \(v\) 子树里深度最大的点到 \(v\) 的距离。
    • 转移 \(f[u][i]=f[v][i-1]+1\) 的时候,\(i\) 最大会到 \(mx[u]\)。但是,为了保证时间复杂度为 \(\sum\) 链长,\(i\) 是不可以枚举到 \(mx[u]\) 的。
    • 怎么办呢?我们发现对于任意 \(i-1>mx[v]\),都有 \(f[v][i-1]=f[v][mx[v]]\)
    • 那么我们不用枚举到 \(mx[u]\)。我们只要枚举 \(i\)\(mx[v]+1\) 就可以了。
    • 至于 \(v\)\(f[u][mx[v]+2...mx[u]]\) 的贡献,我们考虑打一个标记 \(mul[u]\),表示 \(f[u][0...mx[u]]\) 都乘上 \(mul[u]\)
    • 显然要: \(mul[u]*=f[v][mx[v]]\)
    • 然后对于 \(f[u][0...mx[v]+1]\),暴力乘上 \(f[v][mx[v]]\) 的逆元即可。这样时间复杂度就是对的了。
    • 还有注意一点,\(f[v][mx[v]]\) 的逆元不可以用快速幂计算(太慢),要提前 \(O(n)\) 预处理所有的 \(f[v][mx[v]]\),然后 \(O(n)\) 求所有 \(f[v][mx[v]]\) 的逆元。不会线性求逆元的左转 Luogu5431
    • 然后不是递推式的末尾有个 \(+1\) 吗, 那么再记个 \(add[u]\)。现在 \(f[u][i]\) 的真实值就是 \(f[u][i]*mul[u]+add[u]\) 了。
    • 然后为了保证 \(f[u][i]*mul[u]+add[u]\)\(f[u][i]\) 的真实值,我们递推的时候不能直接 \(f[u][i]*=f[v][i-1]\),而是要这样:
    • \(ask0(u,i)\) 表示 \(f[u][i]\) 真实值。
    • \(ask1(u,res)\) 表示已知 \(f[u][i]\) 的真实值为 \(res\),求 \(f[u][i]\)
    • 显然:\(ask1(u,res)=(res-add[u])*imul[u]\),其中 \(imul[u]\)\(mul[u]\) 的逆元,维护 \(mul[u]\) 的同时也要维护 \(imul[u]\)
    • 转移:\(f[u][i]=ask1(ask0(u,i)*ask0(v,i-1),i)\)
    • 然后暴力乘逆元的时候也要类似地利用 \(ask1,ask0\)
    • 如果 \(v\) 是重儿子,直接把 \(f[v],imul[v],mul[v],add[v]\) 都传给 \(u\) 就行了。
    • 但是,如果 \(f[v][mx[v]]\mod 998244353 = 0\) 呢?
    • 那么此时相当于把 \(f[u][mx[v]+1...mx[u]]\) 都赋值成 \(0\)
    • 我们考虑多记两个标记 \(lim[u],zero[u]\),然后定义:
      \(i<lim[u],ask0(u,i)=f[u][i]*mul[u]+add[u]\)
      否则 \(ask0(u,i)=zero[u]*mul[u]+add[u]。\)
    • 那么如果 \(f[v][mx[v]]=0\),令 \(lim[u]=mx[v]+2,zero[u]=ask1(u,0)\)
    • 转移轻儿子 \(v\),枚举到 \(i\) 的时候,如果发现 \(lim[u]==i\),说明此时 \(lim[u]\) 的值需要增加。那么 \(lim[u]++,f[u][i]=zero[u]\) 即可。

    长链剖分优化f-参考程序

    int *f[e], tmp1[e * 10], *it = tmp1 + 2;
    
    namespace sf
    {
    	int mul[e], imul[e], add[e], lim[e], zero[e];
    	
    	inline void init(int u)
    	{
    		f[u] = it;
    		it += (mx[u] + 5) * 2;
    	}
    	
    	inline int ask0(int u, int i)
    	{
    		i = min(i, mx[u]);
    		if (lim[u] <= i) return ((ll)mul[u] * zero[u] + add[u]) % mod;
    		return ((ll)mul[u] * f[u][i] + add[u]) % mod;
    	}
    	
    	inline int ask1(int u, int res)
    	{	
    		return (ll)sub(res, add[u]) * imul[u] % mod;
    	}
    	
    	inline void dfs3(int u)
    	{
    		if (!son[u])
    		{
    			mul[u] = imul[u] = 1; add[u] = 2; lim[u] = n + 1;
    			f1[u] = sub(ask0(u, L), 1); f2[u] = sub(ask0(u, L - 1), 1);
    			return;
    		}
    		else
    		{
    			f[son[u]] = f[u] + 1; dfs3(son[u]); add[u] = add[son[u]]; 
    			mul[u] = mul[son[u]]; imul[u] = imul[son[u]]; lim[u] = lim[son[u]] + 1;
    			zero[u] = zero[son[u]]; f[u][0] = ask1(u, 1);
    		}
    		for (auto v : adj[u])
    		{
    			if (v == son[u]) continue;
    			init(v); dfs3(v);
    			for (int i = 0; i <= mx[v] + 1; i++)
    			{
    				if (lim[u] == i) f[u][lim[u]++] = zero[u];
    				f[u][i] = ask1(u, (ll)ask0(u, i) * (i ? ask0(v, i - 1) : 1) % mod);
    			} 
    			if (!p[v])
    			{
    				lim[u] = mx[v] + 2; zero[u] = ask1(u, 0);
    			}
    			else
    			{
    				mul[u] = (ll)mul[u] * p[v] % mod; add[u] = (ll)add[u] * p[v] % mod;
    				imul[u] = (ll)imul[u] * inv[v] % mod;
    				for (int i = 0; i <= mx[v] + 1; i++)
    				f[u][i] = ask1(u, (ll)ask0(u, i) * inv[v] % mod);
    			}
    		}
    		add[u] = plu(add[u], 1); 
    		f1[u] = sub(ask0(u, L), 1); f2[u] = sub(ask0(u, L - 1), 1);
    		// f1[u]=f[u][L]真实值-1,f2[u]=f[u][L-1]真实值-1
    	}
    }
    

    4.长链剖分优化g

    • 回顾 \(g\)\(dp\) 式:
      \(g[v][i]=g[u][i-1]*\prod_{x∈child[u]且x!=v}f[x][i-2]+1\)

    • 对于 \(g[v][i]=g[u][i-1]\) 这一部分,直接把 \(g\) 传给重儿子,然后轻儿子暴力转移即可。

    • 注意轻儿子只要转移到 \(g[v][max(L-mx[v],0)...L]\)

    • 然后 \(\prod f[x][i-2]\) 怎么办呢?

    • 对于重儿子依然可以暴力计算。和 \(f\) 一样,也要记 \(lim,add,imul,mul,zero\)

    • 对于轻儿子呢?我们发现它等于 \(\frac{f[u][i-1]-1}{f[v][i-2]}\),但要是 \(f[v][i-2] \mod 998244353 =0\) 呢?

    • 所以我们只能把它拆成一段前缀和一段后缀相乘的形式。

    • 先考虑前缀怎么办:

    • 我们把轻儿子\(mx[v]\) 升序排序。然后记 \(b[i]\) 表示:$$\prod_{x的dfs序≤v且x∈child[u]}f[x][i]$$

    • 也就是在枚举到 \(v\) 的时候把 \(f[v][0...mx[v]]\) 计入 \(b[0...mx[v]]\) 就行了。

    • 然后对于 \(b[mx[v]+1...∞]\),显然 \(v\) 对它们的贡献相同,那么额外记一个值就行了。

    • 大概就是这样:

    int now = 0, tot = 1;
    for (int i = 0; i <= mx[v]; i++)
    if (i > now) b[i] = (ll)tot * sf::ask0(v, i) % mod;
    else b[i] = (ll)b[i] * sf::ask0(v, i) % mod;
    now = mx[v]; 
    tot = (ll)tot * sf::ask0(v, mx[v]) % mod; //额外记一个 tot
    
    • 我们发现计算 \(f[u][i]\) 的时候,\(f[u][i]\) 的真实值要不断乘上 \(f[v][i-1]\)。也就是说,枚举到轻儿子 \(v\) 的时候,\(f[u][i]\) 的真实值是 $$\prod_{x的dfs序≤v且x∈child[u]}f[x][i-1]$$
    • 这就是一段前缀的形式了。如果我们计算 \(g\) 的时候,把子节点的 \(dfs\) 顺序全部反过来,那它就是一段后缀的形式了。
    • 那就是说,在计算 \(f\) 的时候,把轻儿子按 \(mx[v]\) 降序排序,然后计算 \(g\) 的时候反过来就行了。
    • 但是我们现在并没有存下对于每个 \(v\),上式的值。
    • 那么我们要做的就是:假设计算 \(f\) 的时候,儿子的 \(dfs\) 顺序为:\(v_1,v_2,v_3,...,v_m\)
    • 然后现在反过来,枚举到 \(v_m\) 的时候,我们要算出:\(\prod_{j=1}^{m-1}f[v_j][i]\),枚举到 \(v_{m-1}\) 的时候,我们要算出:\(\prod_{j=1}^{m-2}f[v_j][i]\)……
    • 然后把算出的这个和 \(b\) 相乘就可以得出 \(\prod f[x][i-2]\) 了。
    • 具体地,我们可以在计算 \(f\) 的时候,枚举到每个轻儿子 \(v_j\) 的时候,都记一下在计算 \(v_j\) 的贡献之前,\(f[u][i]\) 的真实值。也就是对于每个 \(v_j\) 都记下 \(\prod_{k=1}^{j-1}f[v][i-1]\)
    • 但是不能直接把真实值记下来。因为计算 \(f\) 的时候转移的是 \(f[v][0...mx[v]+1]\)。而计算 \(g\) 的时候要算的是 \(g[v][max(L-mx[v],0)...L]\)
    • 所以我们在计算 \(f\) ,枚举到轻儿子 \(v\) 的时候,如果 \(f[u][i],lim[u],zero[u],add[u],mul[u],imul[u]\) 中任意一个的值改变了,都要把改变之前的值记下来。
    • 最后 \(f[u][i]\) 的值有 \(+1\),看作是枚举到最后一个轻儿子时做的修改。
    • 具体实现时可以对每个 \(v\) 开一个栈 (\(list\)实现),计算 \(f\) 的时候记录下修改的地址和值,然后计算 \(g\) 的时候,按从栈顶到栈底的顺序还原这些元素的值。然后利用 \(sf::ask0\) 就能知道 \(\prod_{k=1}^{j-1}f[v][i-1]\) 的真实值了。

    Code

    #include <bits/stdc++.h>
    
    using namespace std;
    
    #define ll long long
    #define pb push_back
    
    template <class t>
    inline void read(t & res)
    {
    	char ch;
    	while (ch = getchar(), !isdigit(ch));
    	res = ch ^ 48;
    	while (ch = getchar(), isdigit(ch))
    	res = res * 10 + (ch ^ 48);
    }
    
    const int e = 1e6 + 5, mod = 998244353;
    vector<int>adj[e], h[e];
    int *f[e], tmp1[e * 10], *it = tmp1 + 2, ans, n, L, fa[e], k, mx[e], b[e], ret[e];
    int son[e], p[e], inv[e], pre[e], suf[e], f1[e], f2[e], *g[e], tmp2[e * 10];
    struct point
    {
    	int x, y;
    }a[e]; 
    
    inline int plu(int x, int y)
    {
    	x += y;
    	if (x >= mod) x -= mod;
    	return x;
    }
    
    inline int sub(int x, int y)
    {
    	x -= y;
    	if (x < 0) x += mod;
    	return x;
    }
    
    inline int ksm(int x, int y)
    {
    	int res = 1;
    	while (y)
    	{
    		if (y & 1) res = (ll)res * x % mod;
    		y >>= 1;
    		x = (ll)x * x % mod;
    	}
    	return res;
    }
    
    inline void dfs0(int u, int pa)
    {
    	fa[u] = pa;
    	for (auto v : h[u])
    	{
    		if (v == pa) continue;
    		dfs0(v, u);
    	}
    }
    
    inline bool cmp(int x, int y)
    {
    	return mx[x] > mx[y];
    }
    
    inline void dfs1(int u)
    {
    	for (auto v : adj[u])
    	{
    		dfs1(v);
    		mx[u] = max(mx[u], mx[v] + 1);
    		if (mx[v] > mx[son[u]]) son[u] = v;
    	}
    	sort(adj[u].begin(), adj[u].end(), cmp);
    }
    
    inline void dfs2(int u)
    {
    	p[u] = 1;
    	for (auto v : adj[u])
    	{
    		dfs2(v);
    		p[u] = (ll)p[u] * p[v] % mod;
    	}
    	p[u] = plu(p[u], 1);
    }
    
    inline void prepare() // 线性求逆元
    {
    	int i; pre[0] = 1;
    	for (i = 1; i <= n; i++) 
    	if (p[i]) pre[i] = (ll)pre[i - 1] * p[i] % mod;
    	else pre[i] = pre[i - 1];
    	suf[n] = ksm(pre[n], mod - 2);
    	suf[n + 1] = 0;
    	for (i = n - 1; i >= 0; i--) 
    	if (p[i + 1]) suf[i] = (ll)suf[i + 1] * p[i + 1] % mod;
    	else suf[i] = suf[i + 1];
    	for (i = 1; i <= n; i++) inv[i] = (ll)pre[i - 1] * suf[i] % mod; 
    }
    
    struct work
    {
    	struct node
    	{
    		int *w, v; // 记录修改的地址和修改之前的值
    	};
    	list<node>q;
    	
    	inline void ins(int &x)
    	{
    		q.pb((node){&x, x});
    	}
    	
    	inline void regain()
    	{
    		while (!q.empty()) *(q.back().w) = q.back().v, q.pop_back(); // 还原 
    	}
    }q[e];
    
    namespace sf
    {
    	int mul[e], imul[e], add[e], lim[e], zero[e];
    	
    	inline void init(int u) // 地址分配
    	{
    		f[u] = it;
    		it += (mx[u] + 5) * 2;
    	}
    	
    	inline int ask0(int u, int i)
    	{
    		i = min(i, mx[u]);
    		if (lim[u] <= i) return ((ll)mul[u] * zero[u] + add[u]) % mod;
    		return ((ll)mul[u] * f[u][i] + add[u]) % mod;
    	}
    	
    	inline int ask1(int u, int res)
    	{	
    		return (ll)sub(res, add[u]) * imul[u] % mod;
    	}
    	
    	inline void dfs3(int u)
    	{
    		if (!son[u])
    		{
    			mul[u] = imul[u] = 1; add[u] = 2; lim[u] = n + 1;
    			f1[u] = sub(ask0(u, L), 1); f2[u] = sub(ask0(u, L - 1), 1);
    			return;
    		}
    		else
    		{
    			f[son[u]] = f[u] + 1; dfs3(son[u]); add[u] = add[son[u]]; 
    			mul[u] = mul[son[u]]; imul[u] = imul[son[u]]; lim[u] = lim[son[u]] + 1;
    			zero[u] = zero[son[u]]; f[u][0] = ask1(u, 1);
    		}
    		int lst = 0;
    		for (auto v : adj[u])
    		{
    			if (v == son[u]) continue;
    			init(v); dfs3(v); lst = v;
    			for (int i = 0; i <= mx[v] + 1; i++)
    			{
    				if (lim[u] == i) 
    				q[v].ins(lim[u]), q[v].ins(f[u][i]), f[u][lim[u]++] = zero[u];
    				q[v].ins(f[u][i]);
    				f[u][i] = ask1(u, (ll)ask0(u, i) * (i ? ask0(v, i - 1) : 1) % mod);
    			} 
    			if (!p[v])
    			{
    				q[v].ins(lim[u]); q[v].ins(zero[u]);
    				lim[u] = mx[v] + 2; zero[u] = ask1(u, 0);
    			}
    			else
    			{
    				q[v].ins(mul[u]); q[v].ins(add[u]); q[v].ins(imul[u]);
    				mul[u] = (ll)mul[u] * p[v] % mod; add[u] = (ll)add[u] * p[v] % mod;
    				imul[u] = (ll)imul[u] * inv[v] % mod;
    				for (int i = 0; i <= mx[v] + 1; i++)
    				q[v].ins(f[u][i]), f[u][i] = ask1(u, (ll)ask0(u, i) * inv[v] % mod);
    			}
    		}
    		if (lst) q[lst].ins(add[u]);
    		add[u] = plu(add[u], 1); 
    		f1[u] = sub(ask0(u, L), 1); f2[u] = sub(ask0(u, L - 1), 1);
    	}
    }
    
    namespace sg
    {
    	int mul[e], imul[e], add[e], lim[e], zero[e];
    	
    	inline void init(int u)
    	{
    		it += mx[u] + 5;
    		g[u] = it - max(L - mx[u], 0);
    		it += mx[u] + 5;
    	}
    	
    	inline int ask0(int u, int i)
    	{
    		if (lim[u] <= i) return ((ll)mul[u] * zero[u] + add[u]) % mod;
    		else return ((ll)mul[u] * g[u][i] + add[u]) % mod;
    	}
    	
    	inline int ask1(int u, int res)
    	{	
    		return (ll)sub(res, add[u]) * imul[u] % mod;
    	}
    	
    	inline void dfs4(int u)
    	{
    		int gu = ask0(u, L), tot = 1, now = 0, x = son[u];
    		ans = plu(ans, ksm((ll)f1[u] * gu % mod, k));
    		ret[u] = ask0(u, L);
    		if (u != 1) 
    		{
    			gu = sub(gu, 1);
    			ans = sub(ans, ksm((ll)f2[u] * gu % mod, k));
    		}
    		if (!x) return; 
    		b[0] = 1; reverse(adj[u].begin(), adj[u].end());
    		for (auto v : adj[u])
    		{
    			if (v == x) continue;
    			q[v].regain(); // 还原
    			init(v); 
    			mul[v] = imul[v] = 1; 
    			lim[v] = n + 1;
    			for (int i = max(0, L - mx[v]); i <= L; i++)
    			{
    				g[v][i] = (ll)(i ? ask0(u, i - 1): 1) * 
    				(i - 2 > now ? tot : i >= 2 ? b[i - 2] : 1) 
    				% mod * (i ? sf::ask0(u, i - 1) : 1) % mod;	
    				if (i) g[v][i] = plu(g[v][i], 1);
    			}
    			for (int i = 0; i <= mx[v]; i++)
    			if (i > now) b[i] = (ll)tot * sf::ask0(v, i) % mod;
    			else b[i] = (ll)b[i] * sf::ask0(v, i) % mod;
    			now = mx[v]; 
    			tot = (ll)tot * sf::ask0(v, mx[v]) % mod;
    		}
    		add[x] = add[u]; mul[x] = mul[u]; imul[x] = imul[u];
    		lim[x] = lim[u] + 1; zero[x] = zero[u]; g[x] = g[u] - 1;
    		int st = max(L - mx[x], 0);
    		for (auto v : adj[u])
    		{
    			if (v == x) continue; int ed = min(mx[v] + 2, L);
    			for (int i = st; i <= ed; i++)
    			{
    				if (lim[x] == i) g[x][lim[x]++] = zero[x];
    				g[x][i] = ask1(x, (ll)ask0(x, i) * 
    				(i >= 2 ? sf::ask0(v, i - 2) : 1) % mod);
    			}
    			if (L <= mx[v] + 2) continue; 
    			if (!p[v]) lim[x] = max(mx[v] + 3, L - mx[x]), zero[x] = ask1(x, 0);
    			else
    			{
    				mul[x] = (ll)mul[x] * p[v] % mod; add[x] = (ll)add[x] * p[v] % mod;
    				imul[x] = (ll)imul[x] * inv[v] % mod;
    				for (int i = st; i <= ed; i++)
    				g[x][i] = ask1(x, (ll)ask0(x, i) * inv[v] % mod);
    			}
    		}
    		add[x] = plu(add[x], 1); 
    		if (L - mx[x] <= 0) g[x][0] = ask1(x, 1);
    		for (auto v : adj[u]) dfs4(v);
    	}	
    	
    	inline void begin()
    	{
    		mul[1] = imul[1] = add[1] = 1;
    		lim[1] = n + 1; init(1);
    	}
    }
    
    int main()
    {
    	int i, x, y, j;
    	read(n); read(L); read(k); mx[0] = -1;
    	for (i = 1; i < n; i++) 
    	{
    		read(x); read(y);
    		a[i].x = x; a[i].y = y;
    		h[x].pb(y); h[y].pb(x);
    	}
    	if (L == 0)
    	{
    		cout << n << endl;
    		return 0;
    	}
    	dfs0(1, 0);
    	for (i = 1; i < n; i++)
    	{
    		x = a[i].x; y = a[i].y;
    		if (fa[y] == x) adj[x].pb(y);
    		else adj[y].pb(x);
    	}
    	dfs1(1); dfs2(1); prepare(); 
    	sf::init(1); sf::dfs3(1);
    	it = tmp2 + 2;
    	sg::begin(); sg::dfs4(1);
    	cout << ans << endl;
    	fclose(stdin);
    	fclose(stdout);
    	return 0;
    }
    
  • 相关阅读:
    查看版本号以及如何升级
    http协商缓存VS强缓存
    「JOISC 2012」星座(凸包)
    「科技」求欧拉数单项
    「科技」在线 O(1) 逆元
    「JOISC 2017 Day 3」自然公园(交互)
    「IOI 2021」分糖果(线段树)
    「EOJ 317A」击鼓传花(类欧)
    「CF 1483E」Vabank(交互,构造)
    「NOIP 2020」微信步数(计数)
  • 原文地址:https://www.cnblogs.com/cyf32768/p/12196975.html
Copyright © 2020-2023  润新知