P8349 [SDOI/SXOI2022] 整数序列 解题报告:
题意
给定序列 \(a,b\),\(q\) 次询问两种颜色 \((x,y)\),定义一个区间的权值为其中 \(a\) 值为 \(x\) 或 \(y\) 的 \(b\) 值之和,求 \(x\) 数量等于 \(y\) 数量的子区间权值最大值。
\(1\leqslant n\leqslant 3\times 10^5,1\leqslant q\leqslant 10^6,\text{7s}\)。
分析
轻视了这道题啊。
考虑根号分治,令阈值为 \(S\)。
小集合对小集合直接把位置归并成一个序列暴力扫一遍即可,大集合对大集合可以直接暴力+记忆化,这里的复杂度是 \(O(qS+\frac{n^2}{S})\)。
难点主要是小集合对大集合(令 \(x\) 为小集合,\(y\) 为大集合),我们发现我们将 \(x\) 和 \(y\) 对应位置取出来之后,对于一个 \(y\) 连续段,只要其长度大到没有合法子区间能跨越它,那么它的长度就是无关紧要的了,可以缩到恰好满足条件的长度。
更具体地说,我们维护一个 \(y\) 坐标集合,对于每个长度为 \(x\) 的连续段,都向前找 \(x+1\) 个没加入集合的 \(y\) 加入集合,向后找 \(x+1\) 个没加入集合的 \(y\) 加入集合,那么保留集合内的 \(y\) 与 \(x\) 进行运算也能保证答案的正确性。
而这样得到的集合大小是不超过 \(2S\) 的,那么我们得到了一个 \(O(qS\log n+\frac{n^2}{S})\) 的做法,令 \(S=\frac{n}{\sqrt{q\log n}}\) 就是 \(O(n\sqrt{q\log n})\) 了。
实际上这个 \(\log\) 可以去掉。我们离线,把询问挂在对应大集合上。
扫两遍,分开维护向左取的 \(y\) 和向右取的 \(y\),每次用一个栈维护若干个连续的连续段集合,判断新加入的集合能否和栈顶合并显然可以 \(O(1)\)。最后我们将每个连续段内部的 \(y\) 以及其向左(或向右)还能扩展的 \(y\) 加入集合即可。
这样的复杂度就是 \(O(qS+\frac{n^2}{S})\) 了,令 \(S=\frac{n}{\sqrt q}\) 即可得到 \(O(n\sqrt q)\) 的复杂度。
代码
#include<stdio.h>
#include<string.h>
#include<vector>
#include<algorithm>
#include<map>
char buf[1<<23],*p1=buf,*p2=buf,obuf[1<<23],*O=obuf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<23,stdin),p1==p2)?EOF:*p1++)
using namespace std;
const int maxn=300005,maxq=1000005,S=300;
int n,m,ps,top,tmps,stp;
int a[maxn],b[maxn],stk[maxn][3],pos[maxn],tmp[maxn],tmpp[maxn],c[maxn],xx[maxq],yy[maxq],cnt[maxn],vis[maxn];
long long inf;
long long ans[maxq],mn[maxn<<1];
vector<int>v[maxn],q[maxn];
map<pair<int,int>,long long>mp;
long long solve(int x,int y){
for(int i=1;i<=ps;i++)
c[i]=c[i-1]+(a[pos[i]]==x? 1:-1);
long long res=-1e18,s=0;
mn[ps]=0;
for(int i=1;i<=ps;i++)
s+=b[pos[i]],res=max(res,s-mn[c[i]+ps]),mn[c[i]+ps]=min(mn[c[i]+ps],s);
for(int i=0;i<=ps;i++)
mn[c[i]+ps]=inf;
return res;
}
void read(int &x){
x=0;
char c=getchar();
int f=0;
for(;c<'0'||c>'9';c=getchar())
f|=(c=='-');
for(;c>='0'&&c<='9';c=getchar())
x=x*10+c-48;
if(f==1)
x=-x;
}
void print(long long x){
if(x<0)
putchar('-'),x=-x;
if(x>9)
print(x/10);
putchar(x%10+48);
}
int main(){
// freopen("tmp.in","r",stdin);
// freopen("tmp.out","w",stdout);
memset(mn,127,sizeof(mn)),inf=mn[0];
read(n),read(m);
for(int i=1;i<=n;i++)
read(a[i]),v[a[i]].push_back(i);
for(int i=1;i<=n;i++)
read(b[i]);
for(int t=1,x,y;t<=m;t++){
read(x),read(y),xx[t]=x,yy[t]=y;
if((v[x].size()>S)+(v[y].size()>S)==1){
if(v[x].size()>S)
q[x].push_back(t);
else q[y].push_back(t);
continue;
}
if(mp.count(make_pair(min(x,y),max(x,y)))){
ans[t]=mp[make_pair(min(x,y),max(x,y))];
continue;
}
int i=0,j=0;
ps=0;
while(i<v[x].size()||j<v[y].size()){
if(i<v[x].size()&&(j==v[y].size()||v[x][i]<v[y][j]))
pos[++ps]=v[x][i],i++;
else pos[++ps]=v[y][j],j++;
}
ans[t]=mp[make_pair(min(x,y),max(x,y))]=solve(x,y);
}
for(int t=1;t<=n;t++)
if(v[t].size()>S&&q[t].size()){
for(int i=1;i<=n;i++)
cnt[i]=cnt[i-1]+(a[i]==t);
for(int p=0;p<q[t].size();p++){
int k=q[t][p],r=xx[k]+yy[k]-t;
ps=top=tmps=0,stp++;
for(int i=0;i<v[r].size();i++)
pos[++ps]=v[r][i];
for(int i=0;i<v[r].size();i++){
int j=i;
while(j+1<v[r].size()&&cnt[v[r][j+1]-1]==cnt[v[r][j]])
j++;
int p=v[r][i],q=v[r][j],c=j-i+2;
while(top){
int w=cnt[p-1]-cnt[stk[top][2]];
if(w<=c)
p=stk[top][1],c=c+stk[top][3]-w,top--;
else break;
}
top++,stk[top][0]=c,stk[top][1]=p,stk[top][2]=q,i=j;
}
int lst=0;
for(int i=1;i<=top;i++){
int c=stk[i][0],p=stk[i][1],q=stk[i][2];
lst=max(lst,cnt[p]-c);
while(lst<=cnt[q]-1){
if(vis[v[t][lst]]!=stp)
tmp[++tmps]=v[t][lst],vis[v[t][lst]]=stp;
lst++;
}
}
merge(pos+1,pos+1+ps,tmp+1,tmp+1+tmps,tmpp+1);
int tmpps=ps+tmps;
tmps=top=0;
for(int i=v[r].size()-1;i>=0;i--){
int j=i;
while(j>=1&&cnt[v[r][j-1]]==cnt[v[r][j]-1])
j--;
int p=v[r][j],q=v[r][i],c=i-j+2;
while(top){
int w=cnt[stk[top][1]-1]-cnt[q];
if(w<=c)
q=stk[top][2],c=c+stk[top][3]-w,top--;
else break;
}
top++,stk[top][0]=c,stk[top][1]=p,stk[top][2]=q,i=j;
}
lst=cnt[n]-1;
for(int i=1;i<=top;i++){
int c=stk[i][0],p=stk[i][1],q=stk[i][2];
lst=min(lst,cnt[q]+c-1);
while(lst>=cnt[p]){
if(vis[v[t][lst]]!=stp)
tmp[++tmps]=v[t][lst],vis[v[t][lst]]=stp;
lst--;
}
}
reverse(tmp+1,tmp+1+tmps),merge(tmpp+1,tmpp+1+tmpps,tmp+1,tmp+1+tmps,pos+1);
ps=tmpps+tmps,ans[k]=solve(t,r);
}
}
for(int i=1;i<=m;i++)
print(ans[i]),putchar('\n');
return 0;
}