题目描述
虽不能至,心向往之。
$Treap=Tree+Heap$
椎$=$树$+$堆
小$pi$学习了计算机科学中的数据结构$Treap$。
小$pi$知道$Treap$指的是一种树。
小$pi$还知道$Treap$节点上有两个权值$k$和$w$,其中$k$满足二叉搜索树性质。$w$满足堆性质。
小$pi$还知道在$k$和$w$都各不相同的时候,$Treap$的形态是固定的。
但是小$pi$不知道这道题目的做法。
这道题目要求你维护一个大根堆$Treap$,要求支持$n$个操作:
$0 k w$:插入一个关键字为$k$,权值为$w$的点。
$1 k$:删除一个关键字为$k$的点。
$2 ku kv$:返回关键字分别为$k_u$和$k_v$两个节点的距离
保证任意时刻树中结点$key$和$weight$都是两两不同的。不会删除当前$Treap$中不存在的点
小$pi$找到了学信息学竞赛的你
输入格式
输入的第一行包含一个正整数$n$,表示操作数。
下面$n$行,为上述操作之一。
输出格式
对于每个$2$操作,返回对应答案。
样例
样例输入:
12
0 4 17
0 10 38
0 2 21
0 1 61
2 1 10
2 10 4
1 10
1 1
0 8 42
2 8 4
2 8 2
2 4 2
样例输出:
1
2
2
1
1
数据范围与提示
对于$20\%$的数据,$nleqslant 100$。
对于$40\%$的数据,$nleqslant 1,000$。
对于$60\%$的数据,$nleqslant 10,000$。
对于$100\%$的数据,$nleqslant 200,000,0leqslant k,w,k_u,k_vleqslant 2^{31}$。
题解
一句话题意:求$Treap$上两点的距离。
$Treap$毕竟还是一棵树,考虑对一棵树如何求两个点$(u,v)$之间的距离,无非就是$depth[u]+depth[v]-2 imes depth[lca]$。
那么对于这道题,就是要想办法求出$lca$和$depth$就好了。
因为$key$满足二叉搜索树性质,也就是中序遍历从小到大;那么不妨将其拍扁放在序列上,也就是说维护一个$key$从小到大的序列。
注意空间问题,因为这道题$key$值只与相对大小有关,而与具体值无关,所以可以离散化,并在序列中提前留好坑即可。
那么来考虑$lca$是什么。
画个图便可以知道,$k_u$和$k_v$的$lca$就是在这个序列上$k_u$和$k_v$之间的$weight$值的最大值,用线段树维护即可。
接着考虑$depth$。
对于点$u$,其在序列上左(右)边第一个$weight$比它大的一定是它的祖先;而一个点的$depth=$其祖先的个数,对于这道题也就是$u$左边和右边的上升序列长度之和(注意这里的上升序列是指碰见比它大的就选上,而不是求最长上升子序列)。
考虑平常维护上升序列可以用单调栈,这道题可以用线段树维护单调栈(然而考场并不会,花了一个小时$zzyy$了一下)。
知道了$lca$和$depth$,这道题也就解决了。
时间复杂度:$Theta(nlog^2 n)$。
期望得分:$100$分。
实际得分:$100$分。
代码时刻
#include<bits/stdc++.h>
#define L(x) x<<1
#define R(x) x<<1|1
using namespace std;
struct rec{int opt,k,w;}e[300000];
unordered_map<int,int>mp;
int n;
int a[300000];
int trmx[1000000],trl[1000000],trr[1000000],trns[1000000],pos[1000000],cnt;
int pushupl(int x,int l,int r,int w)
{
if(l==r)return trmx[x]>w;int mid=(l+r)>>1;
if(trmx[L(x)]<=w)return pushupl(R(x),mid+1,r,w);
return trl[x]-trl[L(x)]+pushupl(L(x),l,mid,w);
}
int pushupr(int x,int l,int r,int w)
{
if(l==r)return trmx[x]>w;int mid=(l+r)>>1;
if(trmx[R(x)]<=w)return pushupr(L(x),l,mid,w);
return trr[x]-trr[R(x)]+pushupr(R(x),mid+1,r,w);
}
void pushup(int x,int l,int r)
{
if(trmx[L(x)]<=trmx[R(x)]){trns[x]=trns[R(x)];trmx[x]=trmx[R(x)];}
else{trns[x]=trns[L(x)];trmx[x]=trmx[L(x)];}int mid=(l+r)>>1;
trl[x]=trl[L(x)]+pushupl(R(x),mid+1,r,trmx[L(x)]);
trr[x]=trr[R(x)]+pushupr(L(x),l,mid,trmx[R(x)]);
}
void add(int x,int l,int r,int k,int w)
{
if(l==r){trns[x]=k;trmx[x]=w;trl[x]=trr[x]=1;return;}
int mid=(l+r)>>1;
if(k<=mid)add(L(x),l,mid,k,w);
else add(R(x),mid+1,r,k,w);
pushup(x,l,r);
}
void del(int x,int l,int r,int k)
{
if(l==r){trns[x]=k;trmx[x]=-0x3f3f3f3f;trl[x]=trr[x]=0;return;}
int mid=(l+r)>>1;
if(k<=mid)del(L(x),l,mid,k);
else del(R(x),mid+1,r,k);
pushup(x,l,r);
}
pair<int,int> lca(int x,int l,int r,int res1,int res2)
{
if(res1<=l&&r<=res2)return make_pair(trmx[x],trns[x]);
int mid=(l+r)>>1;
if(res2<=mid)return lca(L(x),l,mid,res1,res2);
if(mid<res1)return lca(R(x),mid+1,r,res1,res2);
pair<int,int> flag1=lca(L(x),l,mid,res1,res2);
pair<int,int> flag2=lca(R(x),mid+1,r,res1,res2);
if(flag1<flag2)return flag2;return flag1;
}
pair<int,int> askl(int x,int l,int r,int k,int w)
{
if(!k)return make_pair(0,0);
if(r<=k)return make_pair(max(w,trmx[x]),pushupr(x,l,r,w));
int mid=(l+r)>>1;
if(k<=mid)return askl(L(x),l,mid,k,w);
pair<int,int> flag=askl(R(x),mid+1,r,k,w);
pair<int,int> res=askl(L(x),l,mid,k,max(w,flag.first));
res.second+=flag.second;
return res;
}
pair<int,int> askr(int x,int l,int r,int k,int w)
{
if(l>cnt)return make_pair(0,0);
if(k<=l)return make_pair(max(w,trmx[x]),pushupl(x,l,r,w));
int mid=(l+r)>>1;
if(mid<k)return askr(R(x),mid+1,r,k,w);
pair<int,int> flag=askr(L(x),l,mid,k,w);
pair<int,int> res=askr(R(x),mid+1,r,k,max(w,flag.first));
res.second+=flag.second;
return res;
}
int getdep(int x){return askl(1,1,cnt,x-1,pos[x]).second+askr(1,1,cnt,x+1,pos[x]).second;}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%d",&e[i].opt);
switch(e[i].opt)
{
case 0:scanf("%d%d",&e[i].k,&e[i].w);a[++a[0]]=e[i].k;break;
case 1:scanf("%d",&e[i].k);break;
case 2:scanf("%d%d",&e[i].k,&e[i].w);if(e[i].k>e[i].w)swap(e[i].k,e[i].w);break;
}
}
sort(a+1,a+a[0]+1);
for(int i=1;i<=a[0];i++){if(a[i]!=a[i-1])cnt++;mp[a[i]]=cnt;}
memset(trmx,-0x3f,sizeof(trmx));
for(int i=1;i<=n;i++)
switch(e[i].opt)
{
case 0:add(1,1,cnt,mp[e[i].k],e[i].w);pos[mp[e[i].k]]=e[i].w;break;
case 1:del(1,1,cnt,mp[e[i].k]);pos[mp[e[i].k]]=-0x3f3f3f3f;break;
case 2:printf("%d
",getdep(mp[e[i].k])+getdep(mp[e[i].w])-2*getdep(lca(1,1,cnt,mp[e[i].k],mp[e[i].w]).second));break;
}
return 0;
}
rp++