【算法】线段树||树状数组&&并查集
【题解】修改必须暴力单点修改,然后利用标记区间查询。
优化:一个数经过不断开方很快就会变成1,所以维护区间最大值。
修改时访问到的子树最大值<=1时,该区间就不必修改。
#include<cstdio> #include<cmath> #include<algorithm> using namespace std; const int maxn=100010; struct treess{int k,l,r;long long maxs,sum;}t[maxn*3]; int n,m;long long a[maxn]; void build(int k,int l,int r) { t[k].l=l;t[k].r=r; if(l==r)t[k].maxs=t[k].sum=a[l]; else { int mid=(l+r)>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r); t[k].maxs=max(t[k<<1].maxs,t[k<<1|1].maxs);//printf("k=%d maxs=%d",k,t[k].maxs); t[k].sum=t[k<<1].sum+t[k<<1|1].sum; } } void update(int k,int l,int r) { int left=t[k].l,right=t[k].r; if(t[k].maxs<=1)return; if(left==right)a[left]=floor(sqrt(a[left])),t[k].maxs=t[k].sum=a[left]; else { int mid=(left+right)>>1; if(l<=mid)update(k<<1,l,r); if(r>mid)update(k<<1|1,l,r); t[k].maxs=max(t[k<<1].maxs,t[k<<1|1].maxs); t[k].sum=t[k<<1].sum+t[k<<1|1].sum; } } long long ask(int k,int l,int r) { int left=t[k].l,right=t[k].r; if(l<=left&&right<=r)return t[k].sum; int mid=(left+right)>>1;long long ans=0; if(l<=mid)ans=ask(k<<1,l,r); if(r>mid)ans+=ask(k<<1|1,l,r); return ans; } int main() { scanf("%d",&n); for(int i=1;i<=n;i++)scanf("%lld",&a[i]); scanf("%d",&m); build(1,1,n); for(int i=1;i<=m;i++) { int k,l,r; scanf("%d%d%d",&k,&l,&r); if(l>r)swap(l,r); if(k==0)update(1,l,r); else printf("%lld ",ask(1,l,r)); } return 0; }
并查集将ai=1的节点并起来,也就是fa[i]表示i后第一个ai≠1的节点,然后用树状数组单点修改区间维护前缀和。
带删除的寻数问题,都可以用类似的并查集套路解决。
#include<cstdio> #include<algorithm> #include<cstring> #include<cmath> #define lowbit(x) x&(-x) #define ll long long using namespace std; const ll maxn=100010; ll n,m,fa[maxn],a[maxn]; long long c[maxn]; void modify(ll x,ll k){for(ll i=x;i<=n;i+=lowbit(i))c[i]+=k;} long long ask(ll x){long long as=0;for(ll i=x;i>=1;i-=lowbit(i))as+=c[i];return as;} ll find(ll x){return fa[x]==x?x:fa[x]=find(fa[x]);} int main(){ scanf("%lld",&n); for(ll i=1;i<=n;i++)scanf("%lld",&a[i]),modify(i,a[i]); for(ll i=1;i<=n+1;i++)fa[i]=i; scanf("%lld",&m); for(ll i=1;i<=m;i++){ ll k,l,r; scanf("%lld%lld%lld",&k,&l,&r); if(l>r)swap(l,r); if(k==1)printf("%lld ",ask(r)-ask(l-1)); else{ for(ll j=l;j<=r;j++){ j=find(j); if(j<=r){ modify(j,(ll)sqrt(a[j])-a[j]); a[j]=(ll)sqrt(a[j]); if(a[j]<=1)fa[j]=find(j+1); } } } } return 0; }