题意
给出 (n) 个区间 ([l_i,r_i]),问满足 (a_i in [l_i,r_i]) 的一个数列 (a) 的最长上升子序列长度
(n le 3 imes 10^5,1 le l_i,r_i le 10^9)
传送门
思路
第一次考试时很快想出了简单暴力原型 (dp[i]) 表示长度为 (i) 的子序列末尾的最小值,暴力转移:
- 当 (dp[i-1]<l) 时,(dp[i]=min(dp[i],l))
- 当 (l le dp[i-1] <r) 时,(dp[i]=min(dp[i],dp[i-1]+1))
- 当 (r<dp[i-1]) 时,不用更新
考虑如何优化。第二次见到它,还是没有任何思路。
有一个很重要且显然的性质,(dp) 数组一定是单调递增的,观察第二个操作,一定是会更新的,所以说第二个操作就从取小变成了赋值。还有一点,因为各不相同,所以答案数组中的元素个数,就是最后的答案,然后就可以用平衡树优化了。
第一种转移只会修改最后一个 (dp[j−1]<l) 的位置。因为下一个一定(≥l),不满足转移条件。且前面的值本来就都(<l),前面不用更新。
第二个操作就是一个区间平移和区间加法。转移后最后一个 (dp[j]<r) 后的值就没了(被覆盖了),只要把原先的 (dp[nxt(j)]) 删掉。然后在那个直接赋值的位置的地方直接插入一个 (l) 的值,后面的就自动往后退了一位。
最后输出总大小即可。
#include <bits/stdc++.h>
const int N=300005;
int tag[N],ch[N][2],v[N],f[N],n,l,r,cnt,s[N],tp,rt,ans;
void pushdown(int x){
int t=tag[x];
if (t){
if (ch[x][0]) tag[ch[x][0]]+=t,v[ch[x][0]]+=t;
if (ch[x][1]) tag[ch[x][1]]+=t,v[ch[x][1]]+=t;
tag[x]=0;
}
}
void rotate(int x,int &k){
int y=f[x],z=f[y],kind=ch[y][1]==x;
if (y==k) k=x;
else if (ch[z][0]==y) ch[z][0]=x;
else ch[z][1]=x;
ch[y][kind]=ch[x][kind^1],f[ch[x][kind^1]]=y;
ch[x][kind^1]=y,f[y]=x,f[x]=z;
}
int splay(int x,int &k){
s[tp=1]=x;
for (int i=x;f[i];i=f[i]) s[++tp]=f[i];
while (tp) pushdown(s[tp--]);
while (x!=k){
int y=f[x],z=f[y];
if (y!=k){
if ((ch[y][0]==x)^(ch[z][0]==y)) rotate(x,k);
else rotate(y,k);
}
rotate(x,k);
}
}
int find(int x){
int u=rt,ret=0;
while (u){
pushdown(u);
if (v[u]<x) ret=u,u=ch[u][1];
else u=ch[u][0];
}
return ret;
}
int getpre(int x){
splay(x,rt);
int u=ch[x][0];
while (ch[u][1]) u=ch[u][1];
return u;
}
int getnxt(int x){
splay(x,rt);
int u=ch[x][1];
while (ch[u][0]) u=ch[u][0];
return u;
}
void del(int x){
int pre=getpre(x),nxt=getnxt(x);
splay(pre,rt),splay(nxt,ch[rt][1]);
ch[nxt][0]=0,f[x]=0;
}
void ins(int x){
int u=rt,fa=0;
while (u){
pushdown(u);
fa=u;
if (v[u]<=x) u=ch[u][1];
else u=ch[u][0];
}
u=++cnt;
if (fa) ch[fa][x>=v[fa]]=u;
f[u]=fa,v[u]=x;
if (u==1) rt=u;
splay(u,rt);
}
void solve(int L,int R){
int l=find(L),r=find(R),nt=getnxt(r);
if (l!=r){
splay(l,rt),splay(nt,ch[l][1]);
int t=ch[nt][0];
tag[t]+=1,v[t]+=1;
}
if (nt!=2) ans--,del(nt);
ans++,ins(L);
}
int main(){
scanf("%d",&n);
ins(-1);ins(1e9+1);
for (int i=1;i<=n;i++){
scanf("%d%d",&l,&r);
solve(l,r);
}
printf("%d
",ans);
return 0;
}