【题目描述】
你的面前有 (n) 个数排成一行,分别为 (a_1, a_2, dots, a_n)。你打算在每相邻的两个 (a_i) 和 (a_{i+1}) 间都插入一个加号、减号或者乘号。那么一共有 (3^{n-1}) 种可能的表达式。
你对所有可能的表达式的值的和非常感兴趣。但这毕竟太简单了,所以你还打算支持一个修改操作,可以修改某个 (a_i) 的值。
你能够编写一个程序对每个修改都输出修改完之后所有可能表达式的和吗?注意,修改是永久的,也就是说每次修改都是在上一次修改的基础上进行,而不是在最初的表达式上进行。
【输入格式】
第一行包含两个正整数 (n) 和 (Q),为数的个数和询问的个数。
第二行包含 (n) 个非负整数,依次表示 (a_1, a_2, dots, a_n) 。
接下来 (Q) 行,每行包含两个非负整数 (t) 和 (v),表示要将 (a_t) 修改为 (v),其中 (1 leq t leq n)。
保证对于 (1 leq j leq n, 1 leq i leq Q),都有 (a_j, v_i leq 10^4)。
【输出格式】
输出 (Q) 行。对于每个修改输出一行,包含一个整数,表示修改之后所有可能表达式的和,对 (10^9 + 7) 取模。
(n, Q le 100000)
此题中乘号相对于加减号来说其实算一个比较特殊的运算符,由于运算符优先级,乘法要优先计算,所以每个表达式其实都能看作是几段连续的乘积相加减。
举例来说(a_1*a_2-a_3*a_4*a_5+a_6),就是(a_1*a_2)减去(a_3*a_4*a_5)减去(a_6)(废话),即三个乘积相加减
所以我们可以将一个表达式的构成这样理解:先插入乘号,使得原来的(n)个元素变成(x)个乘积,不妨设这些乘积为(p_1,p_2,dots,p_x)。
然后在这(x)个乘积之间加上正负号
对于每种表达式,一定有另外一种表达式与它相互对应,它们的和为(2*p_1)。什么意思呢?比如表达式(a_1*a_2-a_3*a_4*a_5+a_6),中间的5个符号,乘号不变,加变减,减变加,变成(a_1*a_2+a_3*a_4*a_5-a_6),这两个表达式的和为(2*a_1*a_2),即(2*p_1)。
考虑(p_1)一定是一个(a_1*a_2*dots*a_k)的形式,即一个前缀积,而每种表达式产生的贡献只与(p_1)有关,所以可以来计算一下每个前缀积在多少个表达式中出现。
设这个前缀积是(a_1*a_2*dots*a_k),显然(a_k)和(a_{k+1})之间的符号只能填加减中的一种,后面的位置则是三种任选,总方案数是(f[k]=2*3^{n-k-1}),特别的,(f[n]=1)。
那么对于(a_1, a_2, dots, a_n),所有可能的表达式的和就是(sum_{k=1}^n f[k]*a_1*a_2*dots*a_k)
把这(n)个值插入线段树中,对于一个将(a_t)修改为(v)的修改操作,只需将[t,n]区间内的值除以原(a_t)(即乘以(a_t)的逆元)再乘上(v),得到新的前缀积。
时间复杂度(O((n+Q)log n))
【代码】
#include <bits/stdc++.h>
#define mod 1000000007
#define lson ind<<1
#define rson ind<<1|1
using namespace std;
typedef long long ll;
inline ll read() {
ll x = 0, f = 1; char ch = getchar();
for (; ch > '9' || ch < '0'; ch = getchar()) if (ch == '-') f = -1;
for (; ch <= '9' && ch >= '0'; ch = getchar()) x = (x<<1) + (x<<3) + (ch^'0');
return x * f;
}
ll n, q, a[100005], mul[100005], f[100005];
struct segtree{
ll l, r, sum, tag;
} tr[400005];
inline ll fpow(ll x, ll t) {
ll r = 1;
for (; t; t >>= 1, x = x * x % mod) if (t & 1) r = r * x % mod;
return r;
}
void build(ll ind, ll l, ll r) {
tr[ind].l = l; tr[ind].r = r; tr[ind].tag = 1;
if (l == r) {
tr[ind].sum = mul[l] * f[l] % mod;
return;
}
ll mid = (l + r) >> 1;
build(lson, l, mid); build(rson, mid+1, r);
tr[ind].sum = (tr[lson].sum + tr[rson].sum) % mod;
}
inline void pushdown(ll ind) {
if (tr[ind].tag == 1) return;
ll v = tr[ind].tag; tr[ind].tag = 1;
tr[lson].sum = tr[lson].sum * v % mod; tr[lson].tag = tr[lson].tag * v % mod;
tr[rson].sum = tr[rson].sum * v % mod; tr[rson].tag = tr[rson].tag * v % mod;
}
void update(ll ind, ll x, ll y, ll v) {
ll l = tr[ind].l, r = tr[ind].r;
if (x <= l && r <= y) {
tr[ind].sum = tr[ind].sum * v % mod;
tr[ind].tag = tr[ind].tag * v % mod;
return;
}
ll mid = (l + r) >> 1;
pushdown(ind);
if (x <= mid) update(lson, x, y, v);
if (mid < y) update(rson, x, y, v);
tr[ind].sum = (tr[lson].sum + tr[rson].sum) % mod;
}
int main() {
n = read(); q = read();
for (int i = 1; i <= n; i++) a[i] = read();
mul[0] = 1; f[n] = 1; f[n-1] = 2;
for (int i = 1; i <= n; i++) mul[i] = mul[i-1] * a[i] % mod;
for (int i = n-2; i; i--) f[i] = f[i+1] * 3 % mod;
build(1, 1, n);
for (int i = 1, p, v; i <= q; i++) {
p = read(); v = read();
update(1, p, n, v * fpow(a[p], mod-2) % mod);
a[p] = v;
printf("%lld
", tr[1].sum);
}
return 0;
}