题目:https://www.luogu.org/problemnew/show/P2221
似乎按点来算贡献很方便,但我抱住一篇没有这样的题解磕了两天...
以下转载:
题意:维护一段数列 支持区间加和求区间所有子区间的和的和
一看就知道要用线段树 于是用sum表示区间所有子区间的和的和 但是知道两个区间的sum并不能求出这两个区间并起来之后的sum 这时我们可以手玩一下
sum(1 2 3 4)=(1)+(2)+(1 2)+(3)+(4)+(3 4)+(2 3)+(1 2 3)+(2 3 4)+(1 2 3 4)竖着写在纸上 就会发现除了sum(1 2)和sum(3 4)外 所有包含左区间右端点的区间和(rsum)恰被算了len(右区间)次 同理包含右区间左端点的区间和(lsum)被算了len(左区间)次
于是 我们得到了sum(a b)=sum(a)+sum(b)+rsum(a)·len(b)+lsum(b)·len(a)
那么如何维护lsum和rsum呢 显然lsum(a b)=lsum(a)+lsum(b)+区间和(a)*len(b) rsum同理
至此维护操作已经没问题了 pushdown操作就简单很多了 当一个区间整体加a时 它的sum增加了a(n+2(n-1)+3(n-2)+...+n)也就是a((1+...+n)n-(1^2+...+(n-1)^2)-(1+...(n-1))) lsum rsum 同理 最后注意取GCD
看样子很精妙,似乎也挺好写的;
然而因为还是理解不到位,犯了众多细节的错误,调了两天,令人疯狂:
1.线段树上的每个点是车站之间的一段,所以读入的n和r都要-1;
2.需要另外记录纯粹的区间和(s),用于计算sum,lsum,rsum等等,而不是直接用sum一概计算;
3.采用那样的build写法,竟然一度忘记在pushup里更新len,调了半天;
*4.注意这个背景下的线段树与平常的线段树写法有所不同:
平常的线段树在修改、查询时都会传下去一个目标区间,一般来说这个目标区间不用更改,判断只需看是否在其内部就可以;
但这道题的计算是需要用到目标区间的,所以在把操作下放到子区间时,其目标区间也必须相应有所修改,否则在计算时就会大错特错;
也正因此,必须进行判断:目标区间是完全在左儿子,还是完全在右儿子,还是在中间;不同的情况中目标区间的修改也有所不同。
虽然这种写法似乎不是很优秀,但在线段树中痛苦挣扎的过程中,我认识到就算是线段树也不能不求甚解,靠习惯打板子;必须根据题目灵活调整,否则就会出现错误。
错误版本1:
#include<iostream> #include<cstdio> #include<cstring> using namespace std; int const maxn=4e5+5; int n,m; struct N{ int sum,lsum,rsum,len,lzy; int s;//! }t[maxn]; int gcd(int x,int y){return y?gcd(y,x%y):x;} void yf(int x,int nn)//nn { int y=(1+nn)*nn/2; int a=x,b=y; if(a<b)swap(a,b); int c=gcd(a,b); printf("%d/%d ",x/c,y/c); } int cala(int x)//add { int nn=t[x].len; int n=nn-1; return (1+nn)*nn/2*nn-(1+n)*n/2-n*(n+1)*(n*2+1)/6;// } int calb(int x)//lsum,rsum add { int nn=t[x].len; return (1+nn)*nn/2; } void pushup(int x) { int ls=(x<<1),rs=(x<<1|1); t[x].len=t[x<<1].len+t[x<<1|1].len; t[x].s=t[x<<1].s+t[x<<1|1].s; t[x].sum=t[ls].sum+t[rs].sum+t[ls].rsum*t[rs].len+t[rs].sum*t[ls].len; t[x].lsum=t[ls].lsum+t[rs].lsum+t[ls].s*t[rs].len;//s t[x].rsum=t[ls].rsum+t[rs].rsum+t[rs].s*t[ls].len; } void build(int l,int r,int x) { if(l==r) { t[x].len=1; return; } int mid=((l+r)>>1); build(l,mid,x<<1); build(mid+1,r,x<<1|1); pushup(x); } void pushdown(int x) { if(t[x].lzy) { int v=t[x].lzy;t[x].lzy=0; int ls=(x<<1),rs=(x<<1|1); t[ls].s+=t[ls].len*v; t[ls].sum+=cala(ls)*v; t[ls].lsum+=calb(ls)*v; t[ls].rsum+=calb(ls)*v; t[ls].lzy+=v; t[rs].s+=t[rs].len*v; t[rs].sum+=cala(rs)*v; t[rs].lsum+=calb(rs)*v; t[rs].rsum+=calb(rs)*v; t[rs].lzy+=v; } } void add(int l,int r,int L,int R,int v,int x) { if(l>=L&&r<=R) { t[x].s+=t[x].len*v; t[x].sum+=cala(x)*v; t[x].lsum+=calb(x)*v; t[x].rsum+=calb(x)*v; t[x].lzy+=v; return; } pushdown(x); int mid=((l+r)>>1); if(L<=mid)add(l,mid,L,R,v,x<<1); if(R>mid)add(mid+1,r,L,R,v,x<<1|1); pushup(x); } int qs(int l,int r,int L,int R,int x) { if(l>=L&&r<=R)return t[x].s; pushdown(x); int ret=0; int mid=((l+r)>>1); if(L<=mid)ret+=qs(l,mid,L,R,x<<1); if(R>mid)ret+=qs(mid+1,r,L,R,x<<1|1); return ret; } int qls(int l,int r,int L,int R,int x) { if(l>=L&&r<=R)return t[x].lsum; pushdown(x); int mid=((l+r)>>1); if(L>mid)return qls(mid+1,r,L,R,x<<1|1); if(R<=mid)return qls(l,mid,L,R,x<<1); return qls(mid+1,r,L,R,x<<1|1)+qls(l,mid,L,R,x<<1)+qs(l,mid,L,R,x<<1)*t[x<<1|1].len; } int qrs(int l,int r,int L,int R,int x) { if(l>=L&&r<=R)return t[x].rsum; pushdown(x); int mid=((l+r)>>1); if(L>mid)return qrs(mid+1,r,L,R,x<<1|1); if(R<=mid)return qrs(l,mid,L,R,x<<1); return qrs(mid+1,r,L,R,x<<1|1)+qrs(l,mid,L,R,x<<1)+qs(mid+1,r,L,R,x<<1|1)*t[x<<1].len; } int query(int l,int r,int L,int R,int x) { if(l>=L&&r<=R)return t[x].sum; pushdown(x); int mid=((l+r)>>1); if(L>mid)return query(mid+1,r,L,R,x<<1|1); if(R<=mid)return query(l,mid,L,R,x<<1); return query(mid+1,r,L,R,x<<1|1)+query(l,mid,L,R,x<<1) // +qrs(l,mid,L,R,x<<1)*t[x<<1|1].len+qls(mid+1,r,L,R,x<<1|1)*t[x<<1].len; +qrs(l,mid,L,R,x<<1)*(R-mid)+qls(mid+1,r,L,R,x<<1|1)*(mid-L+1); } int main() { scanf("%d%d",&n,&m); n--;// build(1,n,1); for(int i=1,l,r,v;i<=m;i++) { char dc; cin>>dc; scanf("%d%d",&l,&r);//车站而非间距 r--; if(dc=='C') { scanf("%d",&v); add(1,n,l,r,v,1); } if(dc=='Q') { int k=query(1,n,l,r,1); yf(k,r-l+1); } } return 0; }
错误版本2:
#include<iostream> #include<cstdio> #include<cstring> using namespace std; typedef long long ll; ll const maxn=5e5+5; ll n,m; struct N{ ll sum,lsum,rsum,len,lzy; ll s;//! }t[maxn]; ll gcd(ll x,ll y){return y?gcd(y,x%y):x;} void yf(ll x,ll nn)//nn { ll y=(1+nn)*nn/2; ll a=x,b=y; if(a<b)swap(a,b); ll c=gcd(a,b); printf("%lld/%lld ",x/c,y/c); } ll cala(ll x)//add { ll nn=t[x].len; ll n=nn-1; return (1+nn)*nn/2*nn-(1+n)*n/2-n*(n+1)*(n*2+1)/6;// } ll calb(ll x)//lsum,rsum add { ll nn=t[x].len; return (1+nn)*nn/2; } void pushup(ll x) { ll ls=(x<<1),rs=(x<<1|1); // t[x].len=t[x<<1].len+t[x<<1|1].len; t[x].s=t[x<<1].s+t[x<<1|1].s; t[x].sum=t[ls].sum+t[rs].sum+t[ls].rsum*t[rs].len+t[rs].lsum*t[ls].len; t[x].lsum=t[ls].lsum+t[rs].lsum+t[ls].s*t[rs].len;//s t[x].rsum=t[ls].rsum+t[rs].rsum+t[rs].s*t[ls].len; } void build(ll l,ll r,ll x) { t[x].len=r-l+1; if(l==r) { // t[x].len=1; return; } ll mid=((l+r)>>1); build(l,mid,x<<1); build(mid+1,r,x<<1|1); pushup(x); } void pushdown(ll x) { if(t[x].lzy) { ll v=t[x].lzy;t[x].lzy=0; ll ls=(x<<1),rs=(x<<1|1); t[ls].s+=t[ls].len*v; t[ls].sum+=cala(ls)*v; t[ls].lsum+=calb(ls)*v; t[ls].rsum+=calb(ls)*v; t[ls].lzy+=v; t[rs].s+=t[rs].len*v; t[rs].sum+=cala(rs)*v; t[rs].lsum+=calb(rs)*v; t[rs].rsum+=calb(rs)*v; t[rs].lzy+=v; } } void add(ll l,ll r,ll L,ll R,ll v,ll x) { if(l>=L&&r<=R)// { t[x].s+=t[x].len*v; t[x].sum+=cala(x)*v; t[x].lsum+=calb(x)*v; t[x].rsum+=calb(x)*v; t[x].lzy+=v; return; } pushdown(x); ll mid=((l+r)>>1); if(R<=mid)add(l,mid,L,R,v,x<<1); else if(L>mid)add(mid+1,r,L,R,v,x<<1|1); else { add(l,mid,L,R,v,x<<1); add(mid+1,r,L,R,v,x<<1|1); } pushup(x); } ll qs(ll l,ll r,ll L,ll R,ll x) { if(l>=L&&r<=R)return t[x].s;// pushdown(x); ll mid=((l+r)>>1); if(R<=mid)return qs(l,mid,L,R,x<<1); if(L>mid)return qs(mid+1,r,L,R,x<<1|1); return qs(l,mid,L,R,x<<1)+qs(mid+1,r,L,R,x<<1|1); } ll qls(ll l,ll r,ll L,ll R,ll x) { if(l>=L&&r<=R)return t[x].lsum;// pushdown(x); ll mid=((l+r)>>1); if(L>mid)return qls(mid+1,r,L,R,x<<1|1); if(R<=mid)return qls(l,mid,L,R,x<<1); return qls(mid+1,r,L,R,x<<1|1)+qls(l,mid,L,R,x<<1)+qs(l,mid,L,R,x<<1)*(R-mid); } ll qrs(ll l,ll r,ll L,ll R,ll x) { if(l>=L&&r<=R)return t[x].rsum;// pushdown(x); ll mid=((l+r)>>1); if(L>mid)return qrs(mid+1,r,L,R,x<<1|1); if(R<=mid)return qrs(l,mid,L,R,x<<1); return qrs(mid+1,r,L,R,x<<1|1)+qrs(l,mid,L,R,x<<1)+qs(mid+1,r,L,R,x<<1|1)*(mid-L+1); } ll query(ll l,ll r,ll L,ll R,ll x) { // cout<<l<<" "<<r<<endl; // printf("*%lld %lld ",L,R); if(l>=L&&r<=R)return t[x].sum;// pushdown(x); ll mid=((l+r)>>1); if(L>mid)return query(mid+1,r,L,R,x<<1|1); if(R<=mid)return query(l,mid,L,R,x<<1); return query(mid+1,r,L,R,x<<1|1)+query(l,mid,L,R,x<<1) +qrs(l,mid,L,R,x<<1)*(R-mid)+qls(mid+1,r,L,R,x<<1|1)*(mid-L+1); } int main() { scanf("%lld%lld",&n,&m); n--;// build(1,n,1); for(ll i=1,l,r,v;i<=m;i++) { char dc; cin>>dc; scanf("%lld%lld",&l,&r);//车站而非间距 r--; if(dc=='C') { scanf("%lld",&v); add(1,n,l,r,v,1); } if(dc=='Q') { ll k=query(1,n,l,r,1); yf(k,r-l+1); } } return 0; }
代码如下:
#include<iostream> #include<cstdio> #include<cstring> using namespace std; typedef long long ll; ll const maxn=5e5+5; ll n,m; struct N{ ll sum,lsum,rsum,len,lzy; ll s;//! }t[maxn]; ll gcd(ll x,ll y){return y?gcd(y,x%y):x;} void yf(ll x,ll nn)//nn { ll y=(1+nn)*nn/2; ll a=x,b=y; if(a<b)swap(a,b); ll c=gcd(a,b); printf("%lld/%lld ",x/c,y/c); } ll cala(ll x)//add { ll nn=t[x].len; ll n=nn-1; return (1+nn)*nn/2*nn-(1+n)*n/2-n*(n+1)*(n*2+1)/6;// } ll calb(ll x)//lsum,rsum add { ll nn=t[x].len; return (1+nn)*nn/2; } void pushup(ll x) { ll ls=(x<<1),rs=(x<<1|1); t[x].len=t[x<<1].len+t[x<<1|1].len; t[x].s=t[x<<1].s+t[x<<1|1].s; t[x].sum=t[ls].sum+t[rs].sum+t[ls].rsum*t[rs].len+t[rs].lsum*t[ls].len; t[x].lsum=t[ls].lsum+t[rs].lsum+t[ls].s*t[rs].len;//s t[x].rsum=t[ls].rsum+t[rs].rsum+t[rs].s*t[ls].len; } void build(ll l,ll r,ll x) { // t[x].len=r-l+1; if(l==r) { t[x].len=1; return; } ll mid=((l+r)>>1); build(l,mid,x<<1); build(mid+1,r,x<<1|1); pushup(x); } void pushdown(ll x) { if(t[x].lzy) { ll v=t[x].lzy;t[x].lzy=0; ll ls=(x<<1),rs=(x<<1|1); t[ls].s+=t[ls].len*v; t[ls].sum+=cala(ls)*v; t[ls].lsum+=calb(ls)*v; t[ls].rsum+=calb(ls)*v; t[ls].lzy+=v; t[rs].s+=t[rs].len*v; t[rs].sum+=cala(rs)*v; t[rs].lsum+=calb(rs)*v; t[rs].rsum+=calb(rs)*v; t[rs].lzy+=v; } } void add(ll l,ll r,ll L,ll R,ll v,ll x) { if(l>=L&&r<=R)// { t[x].s+=t[x].len*v; t[x].sum+=cala(x)*v; t[x].lsum+=calb(x)*v; t[x].rsum+=calb(x)*v; t[x].lzy+=v; return; } pushdown(x); ll mid=((l+r)>>1); if(R<=mid)add(l,mid,L,R,v,x<<1); else if(L>mid)add(mid+1,r,L,R,v,x<<1|1); else { add(l,mid,L,mid,v,x<<1); add(mid+1,r,mid+1,R,v,x<<1|1); } pushup(x); } ll qs(ll l,ll r,ll L,ll R,ll x) { if(l==L&&r==R)return t[x].s;// pushdown(x); ll mid=((l+r)>>1); if(R<=mid)return qs(l,mid,L,R,x<<1); if(L>mid)return qs(mid+1,r,L,R,x<<1|1); return qs(l,mid,L,mid,x<<1)+qs(mid+1,r,mid+1,R,x<<1|1); } ll qls(ll l,ll r,ll L,ll R,ll x) { if(l==L&&r==R)return t[x].lsum;// pushdown(x); ll mid=((l+r)>>1); if(L>mid)return qls(mid+1,r,L,R,x<<1|1); if(R<=mid)return qls(l,mid,L,R,x<<1); return qls(mid+1,r,mid+1,R,x<<1|1)+qls(l,mid,L,mid,x<<1) +qs(l,mid,L,mid,x<<1)*(R-mid); } ll qrs(ll l,ll r,ll L,ll R,ll x) { if(l==L&&r==R)return t[x].rsum;// pushdown(x); ll mid=((l+r)>>1); if(L>mid)return qrs(mid+1,r,L,R,x<<1|1); if(R<=mid)return qrs(l,mid,L,R,x<<1); return qrs(mid+1,r,mid+1,R,x<<1|1)+qrs(l,mid,L,mid,x<<1) +qs(mid+1,r,mid+1,R,x<<1|1)*(mid-L+1); } ll query(ll l,ll r,ll L,ll R,ll x)//注意L,R要纳入计算 { // cout<<l<<" "<<r<<endl; // printf("*%lld %lld ",L,R); if(l==L&&r==R)return t[x].sum;// pushdown(x); ll mid=((l+r)>>1); if(L>mid)return query(mid+1,r,L,R,x<<1|1); if(R<=mid)return query(l,mid,L,R,x<<1); return query(mid+1,r,mid+1,R,x<<1|1)+query(l,mid,L,mid,x<<1) +qrs(l,mid,L,mid,x<<1)*(R-mid)+qls(mid+1,r,mid+1,R,x<<1|1)*(mid-L+1);//! } int main() { scanf("%lld%lld",&n,&m); n--;// build(1,n,1); for(ll i=1,l,r,v;i<=m;i++) { char dc; cin>>dc; scanf("%lld%lld",&l,&r);//车站而非间距 r--; if(dc=='C') { scanf("%lld",&v); add(1,n,l,r,v,1); } if(dc=='Q') { ll k=query(1,n,l,r,1); yf(k,r-l+1); } } return 0; }