本质是维护斜率递增序列。
用分块的方法就是把序列分成sqrt(n)块,每个块分别用一个vector维护递增序列。查询的时候遍历所有的块,同时维护当前最大斜率,二分找到每个块中比当前最大斜率大的那个点。修改的时候只需要修改点所在的那个块即可。复杂度$O(msqrt nlogn)$
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 typedef double db; 5 const int N=1e5+10; 6 int n,m,h[N],in[N],L[N],R[N],sqrtn,n2; 7 vector<db> v[1000]; 8 9 int main() { 10 scanf("%d%d",&n,&m),sqrtn=sqrt(n+0.5); 11 for(int i=1; i<=n; ++i)L[i]=-1; 12 for(int i=1; i<=n; ++i) { 13 in[i]=i/sqrtn; 14 if(!~L[in[i]])L[in[i]]=i; 15 R[in[i]]=i; 16 n2=max(n2,in[i]); 17 } 18 while(m--) { 19 int x,y; 20 scanf("%d%d",&x,&y); 21 h[x]=y; 22 v[in[x]].clear(); 23 for(int i=L[in[x]]; i<=R[in[x]]; ++i) { 24 db t=(db)h[i]/i; 25 if(v[in[x]].size()&&v[in[x]].back()>t)continue; 26 v[in[x]].push_back(t); 27 } 28 db mx=0; 29 int ans=0; 30 for(int i=0; i<=n2; ++i) { 31 int j=upper_bound(v[i].begin(),v[i].end(),mx)-v[i].begin(); 32 ans+=v[i].size()-j; 33 if(v[i].size())mx=max(mx,v[i].back()); 34 } 35 printf("%d ",ans); 36 } 37 return 0; 38 }
用线段树的方法是利用“每个区间的递增序列长度为左区间的递增序列长度加上右区间中比左区间最大值大的递增序列长度”的性质来维护递增序列长度。左区间的递增序列长度可以直接加上,右区间的与左区间的最大值有关,姑且用$qry(mx[ls],rs)$来表示($mx[ls]$代表左区间最大值),则有区间合并公式:$cnt[u]=cnt[ls]+qry(mx[ls],rs)$。
然后就是$qry(mx[ls],rs)$的计算问题了。设$qry(x,u)$为区间u中比x大的部分的递增序列长度,则分两种情况讨论:
1)u的左区间最大值大于x,此时u的递增序列长度中位于右区间中的部分全部包含,因此$qry(x,u)=cnt[u]-cnt[ls]+qry(x,ls)$(注意不是$cnt[rs]+qry(x,rs)$,因为右区间中的部分序列会被左区间挡住)。
2)u的左区间最大值小于等于x,此时u的递增序列长度中位于左区间中的部分全部不包含,因此$qry(x,u)=qry(x,rs)$。
综上,只需要维护每个区间的最大值和递增序列长度,即可在$O(log^2n)$的时间内完成一次修改操作,而查询操作是$O(1)$的,因此总时间复杂度为$O(mlog^2n)$。
再一次体会到了区间分治的威力。
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 typedef double db; 5 const int N=1e5+10; 6 int cnt[N<<2],n,m; 7 db mx[N<<2]; 8 #define mid ((l+r)>>1) 9 #define ls (u<<1) 10 #define rs (u<<1|1) 11 int qry(db x,int u,int l,int r) { 12 if(mx[u]<=x)return 0; 13 if(l==r)return cnt[u]; 14 return mx[ls]>x?cnt[u]-cnt[ls]+qry(x,ls,l,mid):qry(x,rs,mid+1,r); 15 } 16 void pu(int u,int l,int r) { 17 mx[u]=max(mx[ls],mx[rs]); 18 cnt[u]=cnt[ls]+qry(mx[ls],rs,mid+1,r); 19 } 20 void upd(int p,db x,int u=1,int l=1,int r=n) { 21 if(l==r) {cnt[u]=1,mx[u]=x; return;} 22 p<=mid?upd(p,x,ls,l,mid):upd(p,x,rs,mid+1,r); 23 pu(u,l,r); 24 } 25 int main() { 26 scanf("%d%d",&n,&m); 27 while(m--) { 28 int x,y; 29 scanf("%d%d",&x,&y); 30 upd(x,(db)y/x); 31 printf("%d ",cnt[1]); 32 } 33 return 0; 34 }