清真化简题
不妨考虑枚举每一个元素统计它出现的系数*它的权值。
(a[x])和(a[n-x+1])的出现次数是一样的,因此整个序列只剩下了一半。
先考虑(a[x](xle (n+1)/2))在长度(y)会出现多少次。
经过讨论,发现是(min(x,y,n-y+1))
那么假设(yle (n+1)/2),大于的部分是一样的。
最后就是要求这样一个东西((r<=(n+1)/2)):
(sum_{i=1}^{(n+1)/2} a[i]*sum_{j=1}^{r} min(i,j))
拆一下式子,发现维护(a[i]*i^{0..2})的区间和就好了。
Code:
#include<bits/stdc++.h>
#define fo(i, x, y) for(int i = x, _b = y; i <= _b; i ++)
#define ff(i, x, y) for(int i = x, _b = y; i < _b; i ++)
#define fd(i, x, y) for(int i = x, _b = y; i >= _b; i --)
#define ll long long
#define pp printf
#define hh pp("
")
using namespace std;
const int mo = 1e9 + 7;
ll ksm(ll x, ll y) {
ll s = 1;
for(; y; y /= 2, x = x * x % mo)
if(y & 1) s = s * x % mo;
return s;
}
const ll ni2 = ksm(2, mo - 2);
const int N = 2e5 + 5;
int n, m, n0, op, x, y, z;
#define i0 i + i
#define i1 i + i + 1
ll g[N * 4][3], t[N * 4][3], lz[N * 4];
void bt(int i, int x, int y) {
if(x == y) {
g[i][0] = 1;
g[i][1] = x;
g[i][2] = (ll) x * x % mo;
return;
}
int m = x + y >> 1;
bt(i0, x, m); bt(i1, m + 1, y);
fo(j, 0, 2) g[i][j] = (g[i0][j] + g[i1][j]) % mo;
}
int pl, pr, px;
void jia(int i, int px) {
fo(j, 0, 2) t[i][j] = (t[i][j] + g[i][j] * px) % mo;
lz[i] = (lz[i] + px) % mo;
}
void down(int i) {
if(lz[i]) jia(i0, lz[i]), jia(i1, lz[i]), lz[i] = 0;
}
void add(int i, int x, int y) {
if(y < pl || x > pr) return;
if(x >= pl && y <= pr) {
jia(i, px);
return;
}
int m = x + y >> 1; down(i);
add(i0, x, m); add(i1, m + 1, y);
fo(j, 0, 2) t[i][j] = (t[i0][j] + t[i1][j]) % mo;
}
ll py[3];
void ft(int i, int x, int y) {
if(y < pl || x > pr) return;
if(x >= pl && y <= pr) {
fo(j, 0, 2) py[j] = (py[j] + t[i][j]) % mo;
return;
}
int m = x + y >> 1; down(i);
ft(i0, x, m); ft(i1, m + 1, y);
}
void xiu(int x, int y, int z) {
if(x <= n0) {
pl = x, pr = min(y, n0), px = z;
add(1, 1, n0);
}
if(y > n0) {
pl = n - y + 1, pr = n - max(x, n0 + 1) + 1; px = z;
add(1, 1, n0);
}
}
ll calc(int r) {
ll ans = 0;
pl = 1, pr = r; fo(j, 0, 2) py[j] = 0;
ft(1, 1, n0);
ans = (ans + (py[2] + py[1]) * ni2 % mo - py[2] + py[1] * r + mo) % mo;
pl = r + 1, pr = n0; fo(j, 0, 2) py[j] = 0;
ft(1, 1, n0);
ans = (ans + py[0] * ((ll) r * (r + 1) / 2 % mo)) % mo;
return ans;
}
ll qry(int r) {
if(r <= n0) return calc(r);
return (calc(n0) + (n % 2 == 1 ? calc(n0 - 1) : calc(n0)) - calc(n - r) + mo) % mo;
}
int main() {
scanf("%d %d", &n, &m);
n0 = (n + 1) / 2;
bt(1, 1, n0);
fo(i, 1, n) {
scanf("%d", &x);
xiu(i, i, x);
}
fo(ii, 1, m) {
scanf("%d %d %d", &op, &x, &y);
if(op == 1) {
scanf("%d", &z);
if(x > y) swap(x, y);
xiu(x, y, z);
} else {
ll s = (qry(y) - qry(x - 1) + mo) % mo;
pp("%lld
", s);
}
}
}