可持久化并查集
题目链接:ybt金牌导航4-6-4 / luogu P3402
题目大意
要你支持可持久化的并查集。
即可以退回到第 k 次操作后的并查集。
思路
你考虑并查集的过程。
首先是合并,其实就是找到它们所在的集合,然后把一个集合的父亲连向另一个。
即 (fa) 数组发生了改变,而我们要维护的既是 (fa) 数组的可持久化。
那其实就是要维护可持续化数组,那主席树来维护即可。
然后你也可以看出,那你就不能路径压缩。
那复杂都就会不优,我们考虑按秩合并,即把小的连向大的。(有点启发式合并的感觉)
那尽可能的让大小一样,链的深度就是 (logn)。
那 (find) 函数还是一样,你就不断往上跳,直到它的父亲是它自己。
找它的父亲要在主席树上找,(logn),然后跳是跳链深度次数,(logn),所以复杂度是 (O(mlog^2n))。
代码
#include<cstdio>
#include<algorithm>
using namespace std;
int n, m, op, x, y, rt[200001], tot;
int fa[200001 << 5], ls[200001 << 5], rs[200001 << 5], deg[200001 << 5];
void build(int &now, int l, int r) {
now = ++tot;
if (l == r) {
fa[now] = l;//一开始独立,自己父亲是自己
return ;
}
int mid = (l + r) >> 1;
build(ls[now], l, mid);
build(rs[now], mid + 1, r);
}
int query(int now, int l, int r, int pl) {
if (l == r) return now;
int mid = (l + r) >> 1;
if (pl <= mid) return query(ls[now], l, mid, pl);
else return query(rs[now], mid + 1, r, pl);
}
int find(int root, int pl) {
int now = query(root, 1, n, pl);
if (fa[now] == pl) return now;
return find(root, fa[now]);
}
int merge(int bef, int l, int r, int X, int Y) {
int now = ++tot;//记得新开点
ls[now] = ls[bef];
rs[now] = rs[bef];
if (l == r) {
fa[now] = Y;//合并
deg[now] = deg[bef];
return now;
}
int mid = (l + r) >> 1;
if (X <= mid) ls[now] = merge(ls[bef], l, mid, X, Y);
else rs[now] = merge(rs[bef], mid + 1, r, X, Y);
return now;
}
void adddeg(int now, int l, int r, int pl) {
if (l == r) {
deg[now]++;
return ;
}
int mid = (l + r) >> 1;
if (pl <= mid) adddeg(ls[now], l, mid, pl);
else adddeg(rs[now], mid + 1, r, pl);
}
int main() {
scanf("%d %d", &n, &m);
build(rt[0], 1, n);
for (int i = 1; i <= m; i++) {
scanf("%d", &op);
if (op == 1) {
rt[i] = rt[i - 1];
scanf("%d %d", &x, &y);
int X = find(rt[i], x), Y = find(rt[i], y);
if (fa[X] == fa[Y]) continue;
if (deg[X] > deg[Y]) swap(X, Y);//让深度小的连向深度大的
rt[i] = merge(rt[i - 1], 1, n, fa[X], fa[Y]);
if (deg[X] == deg[Y]) adddeg(rt[i], 1, n, fa[Y]);//如果两个深度都一样,就一定要加深度了
continue;
}
if (op == 2) {
scanf("%d", &x);
rt[i] = rt[x];
continue;
}
if (op == 3) {
rt[i] = rt[i - 1];
scanf("%d %d", &x, &y);
int X = find(rt[i], x), Y = find(rt[i], y);
if (fa[X] == fa[Y]) printf("1
");
else printf("0
");
continue;
}
}
return 0;
}