辣鸡抄题解选手。
发现对于询问,每棵树只要把询问的两个点长在该长的位置了,它多长了些什么点,包括不该长的点都无所谓。那么让每棵树先长完了再询问就好。、
也就是说问的是有n棵树,每次让所有树长出一个节点,一开始都从1开始长,然后会修改一个区间的树的长的位置。
先让所有节点长出来,
给每个更换生长点的操作建一个权值为0的虚点,所有节点就先往它之前的最后一个虚点上长。然后把虚点按前后一个个串起来。
离线所有操作,然后从第左到右处理每棵树。
每个虚点生效的是一段区间l~r,在l的时候虚点开始生效,就把它切下来接到它该接的实点上,在r+1的时候虚点失效,就把它切下来接回它的上一个虚点上,说明这个虚点上面长的点真正该长的地方是它的上一个虚点接的地方,或者上一个点也失效了接到前面去,就是前面某个地方,总会到达一个它该长的地方。
然后求两点距离时不能直接求,需要通过lca,因为存在虚点,若两个点的lca是一个虚点,他们真正的lca其实是虚点的某个祖先,直接算路径就会算少,而求lca可以避免这个问题。
方法是access(x),再access(y),最后一次虚边边实边的x就是lca。
//Achen
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<vector>
#include<cstdio>
#include<queue>
#include<cmath>
const int N=4e5+7;
typedef long long LL;
using namespace std;
int n,m,v[N],ch[N][2],p[N],sum[N],sz[N],ans[N];
template<typename T>void read(T &x) {
char ch=getchar(); x=0; T f=1;
while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
if(ch=='-') f=-1,ch=getchar();
for(;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0'; x*=f;
}
#define lc ch[x][0]
#define rc ch[x][1]
int isroot(int x) { return (ch[p[x]][0]!=x&&ch[p[x]][1]!=x); }
void update(int x) { sum[x]=v[x]+sum[lc]+sum[rc]; }
void rotate(int x) {
int y=p[x],z=p[y],l=(x==ch[y][1]),r=l^1;
if(!isroot(y)) ch[z][y==ch[z][1]]=x; p[x]=z;
ch[y][l]=ch[x][r]; p[ch[x][r]]=y;
ch[x][r]=y; p[y]=x;
update(y); update(x);
}
void splay(int x) {
for(;!isroot(x);rotate(x)) {
int y=p[x],z=p[y];
if(!isroot(y))
((x==ch[y][1])^(y==ch[z][1]))?rotate(x):rotate(y);
}
}
int access(int x) {
int t=0;
for(;x;x=p[t=x]) {
splay(x);
rc=t;
update(x);
} return t;
}
void lik(int x,int y) {
splay(x);
p[x]=y;
}
void cut(int x) {
access(x);
splay(x);
p[lc]=0; lc=0;
update(x);
}
int qry(int x,int y) {
int rs=0;
access(x); splay(x); rs+=sum[x];
int t=access(y); splay(y); rs+=sum[y];
access(t); splay(t); rs-=2*sum[t];
return rs;
}
struct node {
int pos,op,x,y;
node(){}
node(int pos,int op,int x,int y):pos(pos),op(op),x(x),y(y){}
friend bool operator <(const node&A,const node&B) {
return A.pos<B.pos||(A.pos==B.pos&&A.op<B.op);
}
}q[N];
int tot,last,id[N],idd,cnt,L[N],R[N],qs;
void newnode(int x) {
tot++; v[tot]=x; sum[tot]=x;
}
int main() {
#ifdef DEBUG
freopen(".in","r",stdin);
freopen(".out","w",stdout);
#endif
read(n); read(m);
newnode(1); idd=1; L[1]=1; R[1]=n; id[1]=1;
newnode(0); last=2; lik(2,1);
for(int ti=1;ti<=m;ti++) {
int o,l,r,x,u,v;
read(o);
if(o==0) {
read(l); read(r);
newnode(1);
L[++idd]=l; R[idd]=r; id[idd]=tot;
q[++cnt]=node(1,ti-m,tot,last);
}
else if(o==1) {
read(l); read(r); read(x);
l=max(l,L[x]); r=min(r,R[x]);
if(l<=r) {
newnode(0); lik(tot,last);
q[++cnt]=node(l,ti-m,tot,id[x]);
q[++cnt]=node(r+1,ti-m,tot,last);
last=tot;
}
}
else {
read(x); read(u); read(v);
q[++cnt]=node(x,++qs,id[u],id[v]);
}
}
sort(q+1,q+cnt+1);
for(int i=1,j=1;i<=n;i++) {
for(;j<=cnt&&q[j].pos==i;j++) {
if(q[j].op<=0) {
cut(q[j].x);
lik(q[j].x,q[j].y);
}
else ans[q[j].op]=qry(q[j].x,q[j].y);
}
}
for(int i=1;i<=qs;i++) printf("%d
",ans[i]);
return 0;
}
/*
5 5
0 1 5
1 2 4 2
0 1 4
2 1 1 3
2 2 1 3
1 3
0 1 1
0 1 1
2 1 2 3
*/