题目链接
写在前面:
显而易见的线段树开根号模板,网上的题解比比皆是,但大多数的思路是(sqrt{0}=0,sqrt{1}=1),除此之外的都是一个一个开根号,而一个元素最多开6次,所以时间最多可以卡到其它线段树的6倍,而我用的开根号则会尝试不去一个一个地去开根号.
思路:
维护每个线段树节点的区间最大值和最小值,如果这个区间最大值开根后要减去的数和最小值一样,那么用加法操作使这段区间减去这个数,否则继续向下二分,因为原来的思路也要维护一个最大值,所以原思路的优化在这里也可以用上.
具体优化的地方:
例如要开方区间[10,10,10],原思路要对3+3=6个数开方,而此思路只要对2个数开方(10和3),所以理论上此思路更快(尤其是在大数据下),但似乎是因为常数和多了个add的问题,实际效果并不比原思路快:
接下来是贴代码了(注释部分为调试功能):
(为了突出被执行语段,注释与背景主题近色,如需查看,请选中语段或者放入IDE)
#include <bits/stdc++.h>
using namespace std;
long long n,m;
long long a[100005];
struct node{
long long l;
long long r;
long long sum;
long long minn;
long long maxx;
long long lz;
}tr[400005];
void pushup(long long u){
tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
tr[u].maxx=max(tr[u<<1].maxx,tr[u<<1|1].maxx);
tr[u].minn=min(tr[u<<1].minn,tr[u<<1|1].minn);
}
void pushdown(long long u){
tr[u<<1].sum+=(tr[u<<1].r-tr[u<<1].l+1)*tr[u].lz;
tr[u<<1].lz+=tr[u].lz;
tr[u<<1].maxx+=tr[u].lz;
tr[u<<1].minn+=tr[u].lz;
tr[u<<1|1].sum+=(tr[u<<1|1].r-tr[u<<1|1].l+1)*tr[u].lz;
tr[u<<1|1].lz+=tr[u].lz;
tr[u<<1|1].maxx+=tr[u].lz;
tr[u<<1|1].minn+=tr[u].lz;
tr[u].lz=0;
}
void biuld(long long u,long long l,long long r){
tr[u].l=l;
tr[u].r=r;
if(l==r){
tr[u].maxx=a[l];
tr[u].minn=a[l];
tr[u].sum=a[l];
tr[u].lz=0;
return ;
}
long long mid=(tr[u].l+tr[u].r)>>1;
biuld(u<<1,l,mid);
biuld(u<<1|1,mid+1,r);
pushup(u);
}
void add(long long u,long long l,long long r,long long k){
if(l<=tr[u].l&&tr[u].r<=r){
tr[u].lz+=k;
tr[u].sum+=(tr[u].r-tr[u].l+1)*k;
tr[u].maxx+=k;
tr[u].minn+=k;
return ;
}
pushdown(u);
long long mid=(tr[u].l+tr[u].r)>>1;
if(l<=mid)add(u<<1,l,r,k);
if(mid<r)add(u<<1|1,l,r,k);
pushup(u);
}
long long query(long long u,long long l,long long r){
// cout<<"query:"<<u<<' '<<tr[u].l<<' '<<tr[u].r<<endl;
if(l<=tr[u].l&&tr[u].r<=r){
return tr[u].sum;
}
pushdown(u);
long long mid=(tr[u].l+tr[u].r)>>1,res=0;
if(l<=mid)res+=query(u<<1,l,r);
if(mid<r)res+=query(u<<1|1,l,r);
return res;
}
void sqroot(long long u,long long l,long long r){
if(tr[u].maxx<=1)return ;
// cout<<"sqrt:"<<u<<' '<<tr[u].l<<' '<<tr[u].r<<' '<<tr[u].sum<<endl;
if(floor(sqrt(tr[u].minn))-(long long)tr[u].minn==floor(sqrt(tr[u].maxx))-(long long)tr[u].maxx&&l<=tr[u].l&&tr[u].r<=r){
// cout<<floor(sqrt(tr[u].minn))-tr[u].minn<<' '<<tr[u].maxx<<' '<<tr[u].minn<<endl;
add(u,tr[u].l,tr[u].r,floor(sqrt(tr[u].minn))-tr[u].minn);
return ;
}
pushdown(u);
long long mid=(tr[u].l+tr[u].r)>>1;
if(l<=mid)sqroot(u<<1,l,r);
if(mid<r)sqroot(u<<1|1,l,r);
pushup(u);
}
int main() {
// freopen("in.txt","r",stdin);
// freopen("out.txt","w",stdout);
scanf("%lld",&n);
for(long long i=1;i<=n;i++)cin>>a[i];
biuld(1,1,n);
scanf("%lld",&m);
while(m--){
long long x,l,r;
scanf("%lld%lld%lld",&x,&l,&r);
if(l>r)swap(l,r);
if(x==1){
printf("%lld
",query(1,l,r));
}else{
sqroot(1,l,r);
/*
for(long long i=1;i<=n;i++)cout<<query(1,i,i)<<' ';
cout<<endl;
cout<<tr[93].maxx<<' '<<tr[93].minn<<' '<<tr[93].sum<<endl;
*/
}
}
return 0;
}
后记:
其实还有更好的优化思路,比如:如果某区间开方后的数一样,那么直接将整个区间重新赋值.这里不多赘述(主要是因为懒).