题目大意:
已知一个数列,你需要进行下面三种操作:
1.将某区间每一个数乘上x
2.将某区间每一个数加上x
3.求出某区间所有数的和(结果对P取模)
本线段树的标记是一个二元组(mul, add),表示将区间中的每一个数v变为v*mul+add。设区间长度为x,原来的区间和为sum,则打上该标记后区间和变为sum*mul+add*x。若在已有标记(mul, add)之后再叠加新标记(mul', add'),每个数将变为(v*mul+add)*mul'+add' = v*mul*mul'+add*mul'+add',叠加后的结果仍然可以用一个二元组表示。所以令mul = mul*mul', add = add*mul'+add'即可(注意顺序:旧的add也要乘上mul')。
注意:
- 尽管所有数据都对P取模了,但两个小于P的数相乘仍可能超出int的范围,所以程序中所有存放数值的变量、函数参数和返回值都要用long long。
- 实现ModPlus、ModMult(用宏或函数均可)时,只要保证参与运算的数已经小于P,以ModMult为例,就不要写成((x%P)*(y%P))%P,直接写成(x*y)%P即可;多余的取模运算会增大常数,可能导致超时(被卡常数)。
#include <cstdio>
#include <cstring>
#include <cassert>
using namespace std;

const int MAX_RANGE = 100010, MAX_NODE = MAX_RANGE * 4;
#define LOOP(i, n) for (int i = 1; i <= n; i++)

long long P, TotRange;
long long OrgData[MAX_RANGE];

// Segment tree over positions 1..TotRange supporting range multiply,
// range add and range sum, all modulo the global P.
//
// Invariant: every stored value (Sum[], Tag::mul, Tag::add) lies in [0, P).
// Inputs are normalised once on entry (SetEachNode / public Update), so the
// arithmetic helpers below never need to reduce their operands again.
struct RangeTree {
private:
    // (x + y) mod P; both operands must already be in [0, P).
    static long long ModPlus(long long x, long long y) { return (x + y) % P; }
    // (x * y) mod P; x must be in [0, P) and x * y must fit in long long.
    static long long ModMult(long long x, long long y) { return x * y % P; }
    // Maps any value, including negative ones, into [0, P).
    static long long Normalize(long long x) {
        x %= P;
        return x < 0 ? x + P : x;
    }

    // Lazy tag: every element v of the tagged segment becomes v * mul + add.
    struct Tag {
        long long add = 0, mul = 1;  // identity tag by default
        Tag() = default;
        Tag(long long m, long long a) : add(a), mul(m) {}
        // Composes tag x AFTER this one:
        // (v*mul + add)*x.mul + x.add = v*(mul*x.mul) + (add*x.mul + x.add).
        void Refresh(const Tag &x) {
            mul = ModMult(mul, x.mul);
            add = ModPlus(ModMult(add, x.mul), x.add);
        }
        void Clear() { add = 0; mul = 1; }
        // New sum of segment [l, r] whose current sum is `sum`, after this tag.
        long long GetSum(long long sum, int l, int r) const {
            return ModPlus(ModMult(sum, mul), ModMult(add, r - l + 1));
        }
    };

    Tag _tags[MAX_NODE];
    long long Sum[MAX_NODE];

    // Applies the pending tag of node `cur` (covering [l, r]) to both children.
    void PushDown(int cur, int l, int r) {
        Tag &t = _tags[cur];
        if (t.add == 0 && t.mul == 1) return;  // identity tag: nothing to push
        int mid = (l + r) / 2;
        Sum[cur * 2] = t.GetSum(Sum[cur * 2], l, mid);
        Sum[cur * 2 + 1] = t.GetSum(Sum[cur * 2 + 1], mid + 1, r);
        _tags[cur * 2].Refresh(t);
        _tags[cur * 2 + 1].Refresh(t);
        t.Clear();
    }

    void PullUp(int cur) { Sum[cur] = ModPlus(Sum[cur * 2], Sum[cur * 2 + 1]); }

    // op == 1: multiply [al, ar] by value; op == 2: add value to [al, ar].
    // Node `cur` covers [sl, sr]; value must already be in [0, P).
    void Update(int cur, int sl, int sr, int al, int ar, int op, long long value) {
        assert(al <= ar && sl <= sr && al <= sr && ar >= sl);
        if (al <= sl && sr <= ar) {
            if (op == 1) {
                Sum[cur] = ModMult(Sum[cur], value);
                _tags[cur].Refresh(Tag(value, 0));
            } else if (op == 2) {
                // length * value is computed in long long: it can exceed int range.
                Sum[cur] = ModPlus(Sum[cur], ModMult(value, sr - sl + 1));
                _tags[cur].Refresh(Tag(1, value));
            }
            return;
        }
        PushDown(cur, sl, sr);
        int mid = (sl + sr) / 2;
        if (al <= mid) Update(cur * 2, sl, mid, al, ar, op, value);
        if (ar > mid) Update(cur * 2 + 1, mid + 1, sr, al, ar, op, value);
        PullUp(cur);
    }

    // Sum (mod P) of the part of [al, ar] inside node `cur`, which covers [sl, sr].
    long long Query(int cur, int sl, int sr, int al, int ar) {
        assert(al <= ar && sl <= sr && al <= sr && ar >= sl);
        if (al <= sl && sr <= ar) return Sum[cur];
        PushDown(cur, sl, sr);  // children must be up to date before reading them
        int mid = (sl + sr) / 2;
        long long ans = 0;
        if (al <= mid) ans = ModPlus(ans, Query(cur * 2, sl, mid, al, ar));
        if (ar > mid) ans = ModPlus(ans, Query(cur * 2 + 1, mid + 1, sr, al, ar));
        return ans;
    }

    // Builds node `cur` covering [l, r] from a[l..r]; leaves are reduced mod P.
    void SetEachNode(long long *a, int cur, int l, int r) {
        _tags[cur].Clear();
        if (l == r) {
            Sum[cur] = Normalize(a[l]);
            return;
        }
        int mid = (l + r) / 2;
        SetEachNode(a, cur * 2, l, mid);
        SetEachNode(a, cur * 2 + 1, mid + 1, r);
        PullUp(cur);
    }

public:
    RangeTree() {}
    // Builds the tree from a[1..TotRange]; requires 1 <= TotRange < MAX_RANGE and P > 0.
    void SetEachNode(long long *a) {
        assert(P > 0 && TotRange >= 1 && TotRange < MAX_RANGE);
        SetEachNode(a, 1, 1, static_cast<int>(TotRange));
    }
    // op == 1: multiply [l, r] by value; op == 2: add value to [l, r].
    void Update(int l, int r, int op, long long value) {
        Update(1, 1, static_cast<int>(TotRange), l, r, op, Normalize(value));
    }
    // Sum of [l, r] modulo P.
    long long Query(int l, int r) { return Query(1, 1, static_cast<int>(TotRange), l, r); }
} g;

int main() {
    int opCnt, op, l, r;
    long long val;
    scanf("%lld%d%lld", &TotRange, &opCnt, &P);
    LOOP(i, TotRange) scanf("%lld", OrgData + i);
    g.SetEachNode(OrgData);
    while (opCnt--) {
        scanf("%d", &op);
        switch (op) {
        case 1:  // Mult
            scanf("%d%d%lld", &l, &r, &val);
            g.Update(l, r, 1, val);
            break;
        case 2:  // Plus
            scanf("%d%d%lld", &l, &r, &val);
            g.Update(l, r, 2, val);
            break;
        case 3:  // Query
            scanf("%d%d", &l, &r);
            printf("%lld ", g.Query(l, r));
            break;
        }
    }
    return 0;
}