题目描述
You've got an array consisting of $n$ integers: $a_1, a_2, \ldots, a_n$. Your task is to quickly process queries of two types:
-
Assign value $x$ to all elements from $l$ to $r$ inclusive. After such a query, the elements $a_l, a_{l+1}, \ldots, a_r$ all become equal to $x$.
-
Calculate and print the sum $\sum_{i=l}^{r} a_i \cdot (i-l+1)^k$, where $k$ doesn't exceed $5$. As the sum can be rather large, print it modulo $1000000007$ ($10^9+7$).
题目大意
一段序列 $a_1, a_2, \ldots, a_n$，
维护两种操作：
`= l r x` 表示将区间 $[l,r]$ 的值赋为 $x$；
`? l r k` 表示输出 $\sum_{i=l}^{r} a_i (i-l+1)^k \bmod (10^9+7)$。
思路
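用二项式定理展开一下：

$$
\begin{aligned}
\sum_{i=l}^{r} a_i (i-l+1)^k
&= \sum_{i=l}^{r} a_i \left[i + (1-l)\right]^k \\
&= \sum_{i=l}^{r} a_i \sum_{j=0}^{k} \binom{k}{j} i^j (1-l)^{k-j} \\
&= \sum_{j=0}^{k} \binom{k}{j} (1-l)^{k-j} \sum_{i=l}^{r} a_i i^j
\end{aligned}
$$

所以只需用线段树维护 $\sum a_i i^j$（$j \in [0,5]$）就好了。区间赋值为 $x$ 时，$\sum_{i=l}^{r} x\, i^j = x\left(S_j(r) - S_j(l-1)\right)$，其中 $S_j(n) = \sum_{i=1}^{n} i^j$ 有闭式（即代码中的 `powerkth`）。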
#include <cstdio>
// Binomial coefficients: c[k][j] = C(k, j) for 0 <= j <= k <= 5; entries with j > k are 0.
const int c[6][6] = {
    { 1, 0,  0,  0, 0, 0 },
    { 1, 1,  0,  0, 0, 0 },
    { 1, 2,  1,  0, 0, 0 },
    { 1, 3,  3,  1, 0, 0 },
    { 1, 4,  6,  4, 1, 0 },
    { 1, 5, 10, 10, 5, 1 },
};
// Array-length bound used to size the segment-tree arrays (with some slack).
const int maxn = 100000 + 10;
// Modulus for every stored sum and answer; laz[] also uses this value to mean "no pending tag".
const int mod = 1000000007;
typedef long long ll;
// Array length and number of queries, as read from input.
int n, m;
// Per-node pending range-assignment value, or `mod` when nothing is pending.
int laz[maxn << 3];
// sum[node][k]: sum of a_i * i^k (mod `mod`) over the positions the node covers, k = 0..5.
ll sum[maxn << 3][6];
// Returns S_k(n) = 1^k + 2^k + ... + n^k reduced modulo `mod`, using closed forms.
// k = 0 (and any k outside 1..5) returns n itself, which is S_0(n).
// Constants are modular inverses mod 1e9+7: 250000002 = 1/4, 233333335 = 1/30, 83333334 = 1/12.
// k = 1 and k = 2 divide exactly in 64-bit arithmetic instead, so they assume n is small
// enough for the products not to overflow (the array bound in this file is ~1e5).
inline ll powerkth(ll n, int k) {
    switch (k) {
    case 1:
        // n(n+1)/2
        return n * (n + 1) / 2 % mod;
    case 2:
        // n(n+1)(2n+1)/6
        return n * (n + 1) * (2 * n + 1) / 6 % mod;
    case 3: {
        // n^2 (n+1)^2 / 4
        ll acc = n * n % mod;
        acc = acc * (n + 1) % mod;
        acc = acc * (n + 1) % mod;
        return acc * 250000002ll % mod;
    }
    case 4: {
        // n(n+1)(2n+1)(3n^2+3n-1) / 30
        ll acc = n * (n + 1) % mod;
        acc = acc * (2 * n + 1) % mod;
        const ll quad = 3 * n * n % mod + 3 * n % mod - 1;
        acc = acc * quad % mod;
        return acc * 233333335ll % mod;
    }
    case 5: {
        // n^2 (n+1)^2 (2n^2+2n-1) / 12
        ll acc = n * n % mod;
        acc = acc * (n + 1) % mod;
        acc = acc * (n + 1) % mod;
        const ll quad = 2 * n * n % mod + 2 * n % mod - 1;
        acc = acc * quad % mod;
        return acc * 83333334ll % mod;
    }
    default:
        return n;
    }
}
// Recomputes all six power sums of node `root` as the (mod `mod`) sum of its two children.
inline void pushup(int root) {
    const int lc = root << 1;
    const int rc = lc | 1;
    for (int k = 0; k < 6; ++k) {
        sum[root][k] = (sum[lc][k] + sum[rc][k]) % mod;
    }
}
inline void pushdown(int root,int l,int r) {
int mid = l+r>>1;
if (laz[root] ^ mod) {
laz[root<<1] = laz[root];
laz[root<<1|1] = laz[root];
for (int i = 0;i <= 5;i++) {
sum[root<<1][i] = laz[root]*((powerkth(mid,i)-powerkth(l-1,i)+mod)%mod)%mod;
sum[root<<1|1][i] = laz[root]*((powerkth(r,i)-powerkth(mid,i)+mod)%mod)%mod;
}
laz[root] = mod;
}
}
inline void build(int l,int r,int root) {
laz[root] = mod;
if (l == r) {
scanf("%lld",&sum[root][0]);
for (int i = 1;i <= 5;i++) sum[root][i] = sum[root][i-1]*l%mod;
return;
}
int mid = l+r>>1;
build(l,mid,root<<1);
build(mid+1,r,root<<1|1);
pushup(root);
}
inline void update(int l,int r,int ul,int ur,int root,ll x) {
if (l > ur || r < ul) return;
if (ul <= l && r <= ur) {
laz[root] = x;
for (int i = 0;i <= 5;i++) sum[root][i] = x*((powerkth(r,i)-powerkth(l-1,i)+mod)%mod)%mod;
return;
}
pushdown(root,l,r);
int mid = l+r>>1;
update(l,mid,ul,ur,root<<1,x);
update(mid+1,r,ul,ur,root<<1|1,x);
pushup(root);
}
// Returns sum of a_i * i^k (mod `mod`) over positions in [ql, qr] under node `root`,
// which covers [l, r]. Pending tags are pushed down on the way through partially covered nodes.
inline ll query(int l, int r, int ql, int qr, int root, int k) {
    if (qr < l || ql > r) return 0;
    if (l >= ql && r <= qr) return sum[root][k];
    pushdown(root, l, r);
    const int mid = (l + r) >> 1;
    const ll fromLeft = query(l, mid, ql, qr, root << 1, k);
    const ll fromRight = query(mid + 1, r, ql, qr, root << 1 | 1, k);
    return (fromLeft + fromRight) % mod;
}
int main() {
for (scanf("%d%d",&n,&m),build(1,n,1);m--;) {
char ch; int l,r,k;
scanf("%s%d%d%d",&ch,&l,&r,&k);
if (ch == '=') update(1,n,l,r,1,k);
else {
if (l == 1) { printf("%lld
",query(1,n,l,r,1,k)); continue; }
ll ans = 0;
for (int i = 0;i <= k;i++) {
ll tmp = 1;
for (int j = 1;j <= k-i;j++) tmp = (tmp*(1-l)%mod+mod)%mod;
(ans += tmp*c[k][i]%mod*query(1,n,l,r,1,i)%mod) %= mod;
}
printf("%lld
",ans);
}
}
return 0;
}