Description
给定一个 (n) 个顶点,(m) 条边的无向联通图,点、边带权。
先有 (q) 次修改或询问,每个指令形如 ( ext{opt} x y):
- ( ext{opt}=1):将顶点 (x) 的点权修改为 (y);
- ( ext{opt}=2):查询顶点 (x, y) 间所有路径中路径上最大值中,最小的哪一个最大值(瓶颈路)。
- ( ext{opt}=3):查询顶点 (x) 可以结果边权 (le y) 的边能到达的所有点上有几种不同的点权。
Hint
- (1le nle 10^5, 1le mle 3 imes 10^5, 1le qle 2 imes 10^5)
- ( ext{点权、边权}in[0, 2^{31}))
Solution
首先对于 ( ext{opt}=2) 的操作,这是个经典问题,我们有很多解决思路。但是看到操作三就发现 Kruskal 重构树才是最好的选择。
我们假设没有操作一,那么操作二可以转化为重构树上两个结点的 LCA,操作三则是子树数颜色。
考虑到一颗子树的 dfs 序连续,那么这又可以转化为序列问题,于是成了区间数颜色。
然而在加上修改操作操作三就变的棘手了,或者树套树应该也能过但肯定不好写。
一看清一色待修莫队,感觉这个题可以不用这样麻烦。
之后在题解区发现了 mrsrz 的题解 的一只 (log) 处理方法,感觉很妙,于是学习一波。
对于一个在结点 (x) 刚插入的一种颜色 (c),它可以贡献的范围是 (x) 的一个深度最浅的祖先 (a) 满足以 (a) 为根的子树中原本不存在任何一个颜色 (c)。于是我们就可以在这上面做链加。大力树剖加树状数组是 (O(log^2n)) 一次的,但直接树上差分则可以做到 (O(log n))。具体地,我们在结点 (x) 的位置 (+1),然后在 (a) 的父亲上 (-1),因为它可以贡献到的最高的位置是 (a)。
然后就是如何找到这样一个 (a) 的问题。其实这个不难处理,这个 (a) 必然是重构树上 dfs 序与 (x) 相邻的两个结点(有可能一个)(y_1, y_2) 的两个 ( ext{LCA}(x, y_1), ext{LCA}(x, y_2)) 中,深度较深的那一个。如果对虚树比较熟那么这就很显然。
具体实现时,为了找到 dfs 序相邻的点,我们对每一种颜色开一个 std::set
,存这个颜色的所有结点并按 dfs 序排序。
这样总复杂度是 (O((n+m+q)log n)) 的。
Code
这个题非常码农,所以写的有点长。不过思路还是很清晰的。
/*
* Author : _Wallace_
* Source : https://www.cnblogs.com/-Wallace-/
* Problem : Luogu P5168 xtq玩魔塔
*/
#include <algorithm>
#include <cstdio>
#include <cctype>
#include <map>
#include <set>
#include <vector>
inline int read() {
int x(0), s(0); char c; while (!isgraph(c = getchar()));
if (x == '-') s = 1, c = getchar();
do x = (x << 1) + (x << 3) + c - 48; while (isdigit(c = getchar()));
return s ? -x : x;
}
const int N = 1e5 + 5;
const int M = 3e5 + 5;
const int Q = 2e5 + 5;
const int V = N << 1;
const int logN = 19;
int n, m, q, a[N];
struct Edge {
int u, v, w;
bool operator < (const Edge& rhs) const {
return w < rhs.w;
}
} e[M];
int uset[V];
int find(int x) {
return x == uset[x] ? x : uset[x] = find(uset[x]);
}
int vcnt;
int ch[V][2], fa[V][logN], val[V];
int timer(0);
int dfn[V], siz[V], dep[V];
void dfs(int x) {
dfn[x] = ++timer, siz[x] = 1, dep[x] = dep[fa[x][0]] + 1;
if (!ch[x][0] && !ch[x][1]) return;
dfs(ch[x][0]), dfs(ch[x][1]), siz[x] += siz[ch[x][0]] + siz[ch[x][1]];
}
int lca(int x, int y) {
if (dep[x] < dep[y]) std::swap(x, y);
for (int j = logN - 1; ~j; j--)
if (dep[fa[x][j]] >= dep[y]) x = fa[x][j];
if (x == y) return x;
for (int j = logN - 1; ~j; j--)
if (fa[x][j] != fa[y][j]) x = fa[x][j], y = fa[y][j];
return fa[x][0];
}
int getanc(int x, int y) {
for (int j = logN - 1; ~j; --j)
if (fa[x][j] && val[fa[x][j]] <= y) x = fa[x][j];
return x;
}
namespace bit {
int tr[V];
void add(int p, int v) {
for (; p <= vcnt; p += p & -p) tr[p] += v;
}
int get(int p) {
int v(0);
for (; p; p -= p & -p) v += tr[p];
return v;
}
}
struct cmp {
bool operator () (const int& a, const int& b) {
return dfn[a] < dfn[b];
}
};
int col_tot(0);
std::map<int, int> idx;
std::set<int, cmp> pos[V + Q];
int getIdx(int col) {
return idx.count(col) ? idx[col] : idx[col] = ++col_tot;
}
void update_col(int x, int c) {
bit::add(dfn[x], 1);
std::set<int>::iterator it = pos[c = getIdx(c)].insert(x).first;
if (pos[c].size() == 1u) return;
std::vector<int> adj; adj.reserve(2);
if (++it != pos[c].end()) adj.push_back(*it);
if (--it != pos[c].begin()) adj.push_back(*--it);
std::pair<int, int> y;
for (int i = 0; i < (int)adj.size(); i++) {
int l = lca(x, adj[i]);
y = std::max(y, std::make_pair(dep[l], l));
}
bit::add(dfn[y.second], -1);
}
void remove_col(int x, int c) {
bit::add(dfn[x], -1);
std::set<int>::iterator it = pos[c = getIdx(c)].find(x);
if (pos[c].size() == 1u) { pos[c].erase(it); return; }
std::vector<int> adj; adj.reserve(2);
if (++it != pos[c].end()) adj.push_back(*it);
if (--it != pos[c].begin()) adj.push_back(*--it), ++it;
pos[c].erase(it);
std::pair<int, int> y;
for (int i = 0; i < (int)adj.size(); i++) {
int l = lca(x, adj[i]);
y = std::max(y, std::make_pair(dep[l], l));
}
bit::add(dfn[y.second], 1);
}
int count(int x, int y) {
x = getanc(x, y);
return bit::get(dfn[x] + siz[x] - 1) - bit::get(dfn[x] - 1);
}
signed main() {
n = read(), m = read(), q = read();
for (int i = 1; i <= n; i++) a[i] = read();
for (int i = 1; i <= m; i++) e[i].u = read(), e[i].v = read(), e[i].w = read();
std::sort(e + 1, e + 1 + m), vcnt = n;
for (int i = 1; i <= n; i++) uset[i] = i;
for (int i = 1; i <= m && vcnt != n * 2 - 1; i++) {
int u = find(e[i].u), v = find(e[i].v);
if (u == v) continue;
val[++vcnt] = e[i].w, uset[u] = uset[v] = uset[vcnt] = vcnt;
fa[ch[vcnt][0] = u][0] = fa[ch[vcnt][1] = v][0] = vcnt;
}
for (int j = 1; j < logN; j++)
for (int i = 1; i <= vcnt; i++)
fa[i][j] = fa[fa[i][j - 1]][j - 1];
dfs(vcnt);
for (int i = 1; i <= n; i++)
update_col(i, a[i]);
while (q--) {
int opt = read(), x = read(), y = read();
if (opt == 1) remove_col(x, a[x]), update_col(x, a[x] = y);
if (opt == 2) printf("%d
", val[lca(x, y)]);
if (opt == 3) printf("%d
", count(x, y));
}
return 0;
}