转载声明
解释来自:https://www.luogu.org/blog/running-coder/solution-p3373
进入正题:(以下内容均在sgt结构体内)
此题相对于模板一,加了个区间乘,于是在模板一的基础上需要多开个数组(记录乘法懒标记)、多写个函数(区间乘),还有要把懒标记下放函数做些修改。
变量定义:
sum[]:线段树节点对应区间的元素总和;
addv[]:线段树节点对应区间的所有元素待加的值(懒标记),初值全部设为0;
mulv[]:线段树节点对应区间的所有元素待乘的值(懒标记),初值全部设为1。
过程说明:
建树(Build):
同模板一。。。
懒标记下放(Push_down):
原理解释:
1.当对某区间执行加法操作时,由于加法优先级低,不会对乘法操作产生影响,故直接相加即可;
2.当对某区间执行乘法操作时,由于乘法优先级高,会对之前的加法操作产生影响,故需要在相乘时不仅对sum和mulv相乘,也需要对addv相乘;
3.由于上述原因,故需要先算乘法再算加法。
细节实现:
1.子树的sum、mulv、addv值分别乘上当前节点的mulv值;
2.当前节点的mulv值还原,即置为1;
3.子树的addv值加上当前节点的addv值;
4.子树的sum值加上(子树包含元素数量*当前节点的addv值);
5.当前节点的addv值还原,即置为0。
特别说明:
1.使用前判断,若当前节点的懒标记为空则不需执行此下放函数。虽然执行了也不会有影响,但浪费时间;
2.为尽量节省时间,要将判断放在此函数外而不是函数内。
区间加(Addall):
同模板一。。。
区间乘(Mulall):
若当前节点完全包含在待更新区间内,则直接修改当前节点的mulv、addv、sum值即可(参考下放函数说明);
否则执行与区间加类似的操作即可。
区间查询(Query):
同模板一。。。
提示:不要忘记取模。。。
个人代码
#include<cstdio>
using namespace std;
#define MAX 100000+99
#define ll long long
int n,m,p;
ll sumv[MAX<<2], addv[MAX<<2], mulv[MAX<<2];
ll a[MAX];
void push_up(int o) { sumv[o] = (sumv[o<<1] + sumv[o<<1|1] ) % p;}
void build(int o, int l, int r) {
addv[o] = 0; mulv[o] = 1;
if(l == r) { sumv[o] = a[l]; return;}
int mid = (l+r)>>1;
build(o<<1, l, mid);
build(o<<1|1, mid+1, r);
push_up(o);
}
//void pushtag()这里的pushtag不一样,我就不写了
void push_down(int o, int l, int r) {
if(mulv[o] != 1) {//先传mul
mulv[o<<1] = (mulv[o<<1] * mulv[o]) % p;
mulv[o<<1|1] = (mulv[o<<1|1] * mulv[o]) % p;
addv[o<<1] = (addv[o<<1] * mulv[o]) % p;
addv[o<<1|1] = (addv[o<<1|1] * mulv[o]) % p;
sumv[o<<1] = (sumv[o<<1] * mulv[o]) % p;
sumv[o<<1|1] = (sumv[o<<1|1] * mulv[o]) % p;
mulv[o] = 1;
}
if(addv[o]) {
addv[o<<1] = (addv[o<<1] + addv[o]) % p;
addv[o<<1|1] = (addv[o<<1|1] + addv[o]) % p;
int mid = (l + r) >> 1;
sumv[o<<1] = (sumv[o<<1] + addv[o]*(mid-l+1) ) % p;
sumv[o<<1|1] = (sumv[o<<1|1] + addv[o]*(r-mid) ) % p ;
addv[o] = 0;
}
}
void optadd(int o, int l, int r, int ql, int qr, int v) {
if(ql <= l && r <= qr) {
addv[o] = (ll)(addv[o] + v) % p;
sumv[o] = (ll)(sumv[o] + v*(r-l+1) ) % p;
return ;
}
push_down(o,l,r);
int mid = (l + r) >> 1;
if(ql <= mid) optadd(o<<1, l, mid, ql, qr, v);
if(mid < qr) optadd(o<<1|1, mid+1, r, ql, qr, v);
push_up(o);
}
void optmul(int o, int l, int r, int ql, int qr, int v) {
if(ql <= l && r <= qr) {
mulv[o] = (ll)(mulv[o] * v) % p;
addv[o] = (ll)(addv[o] * v) % p;
sumv[o] = (ll)(sumv[o] * v) % p;
return ;
}
push_down(o,l,r);
int mid = (l + r) >> 1;
if(ql <= mid) optmul(o<<1, l, mid, ql, qr, v);
if(mid < qr) optmul(o<<1|1, mid+1, r, ql, qr, v);
push_up(o);
}
ll query(int o, int l, int r, int ql, int qr) {
if(ql <= l && r <= qr) return (ll)sumv[o]%p;
push_down(o,l,r);
ll ans = 0;
int mid = (l + r) >> 1;
if(ql <= mid) ans = (ans + query(o<<1, l, mid, ql, qr) ) % p;
if(mid < qr) ans = (ans + query(o<<1|1, mid+1, r, ql, qr) ) % p;
return ans%p;
}
int main() {
scanf("%d%d%d",&n,&m,&p);
for(int i = 1; i <= n; i++) scanf("%lld",&a[i]);
build(1,1,n);
int tmp, l, r, v;
for(int i = 1; i <= m; i++) {
scanf("%d", &tmp);
if(tmp == 1) { scanf("%d%d%d", &l, &r, &v); optmul(1, 1, n, l, r, v); }
else if(tmp == 2) { scanf("%d%d%d", &l, &r, &v); optadd(1, 1, n, l, r, v); }
else if(tmp == 3) { scanf("%d%d", &l, &r); printf("%lld
", query(1, 1, n, l, r)); }
}
return 0;
}
/*
5 5 38
1 5 4 2 3
2 1 4 1
3 2 5
1 2 4 2
2 3 5 5
3 1 4
*/