一、题目:
二、思路:
先来看一波官方题解。
在这里对两个转移方程做一点解释。
先看
[dp1(H)=2^xprod_{i=1}^kdp1(c_i)
]
这个转移方程的意思是说,直方图的最下面 (x) 行每行都有两种选择,要么是 (mathtt{RBRB}),要么是 (mathtt{BRBR})。所以等式右边有一个因子 (2^x),又根据乘法原理,再乘上 (prodlimits_{i=1}^k dp1(c_i)) 即可。
再来看第二个转移方程
[dp2(H)=2^wprodlimits_{i=1}^k(dp1(c_i)+dp2(c_i))+(2^x-2)prodlimits_{i=1}^kdp1(c_i)
]
根据加法原理,(dp2(H)) 的计数可以分为两部分。
-
直方图的最下面 (x) 行每列都是红蓝交替染色。那么每列都有两种选择,所以是 (2^w)。剩下的部分,对于每个 (c_i) 又可以根据加法原理分成两部分。
- (c_i) 下面那一行是红蓝交替染色。那么 (c_i) 染色的方案数为 (2 imes dp1(c_i))。
- (c_i) 下面那一行不是红蓝交替染色。那么 (c_i) 染色的方案数就是 (dp2(c_i)-dp1(c_i))。
因此剩下部分的计数就是 (prodlimits_{i=1}^k dp1(c_i)+dp2(c_i))。
-
直方图最下面 (x) 行每列都不是红蓝交替染色。我们发现只要第一列填好,剩下列的染色也就确定了。因此最下面 (x) 行的染色仅取决于第一列的染色。所以是 (2^x-2)(除去两种红蓝交替的情况)。剩下的部分,对于每个 (c_i),它的第一行必然是红蓝交替的,所以方案数为 (dp1(c_i));因此剩下部分的总方案数就是 (prodlimits_{i=1}^kdp1(c_i))。
三、代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
#define FILEIN(s) freopen(s".in", "r", stdin);
#define FILEOUT(s) freopen(s".out", "w", stdout)
#define mem(s, v) memset(s, v, sizeof s)
inline int read(void) {
int x = 0, f = 1; char ch = getchar();
while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
return f * x;
}
const int maxn = 105, mod = 1e9 + 7;
int n, h[maxn];
inline void chkmin(int &x, int y) {
x = y < x ? y : x;
}
inline long long power(long long a, long long b) {
long long res = 1;
for (; b; b >>= 1) {
if (b & 1) res = res * a % mod;
a = a * a % mod;
}
return res;
}
void solve(int l, int r, long long &res1, long long &res2) {
vector<int>vec;
vector<long long>dp1, dp2;
int mn = 0x3f3f3f3f;
for (int i = l; i <= r; ++ i) {
chkmin(mn, h[i]);
}
for (int i = l; i <= r; ++ i) {
if (h[i] == mn) vec.push_back(i);
}
for (int i = l; i <= r; ++ i) h[i] -= mn;
long long tmp1, tmp2;
int w = vec.size();
for (int i = l; i <= r; ++ i) {
if (h[i] > 0) {
int j = i;
for (; j <= r && h[j] > 0; ++ j);
-- j;
solve(i, j, tmp1, tmp2);
dp1.push_back(tmp1); dp2.push_back(tmp2);
i = j + 1;
}
}
if (dp1.empty()) {
res1 = power(2, mn);
res2 = (power(2, w) + power(2, mn) - 2) % mod;
return;
}
tmp1 = 1;
for (auto &p : dp1) (tmp1 *= p) %= mod;
res1 = power(2, mn) * tmp1 % mod;
res2 = (power(2, mn) - 2) * tmp1 % mod; if (res2 < 0) res2 += mod;
tmp2 = 1;
for (int i = 0; i < (int)dp1.size(); ++ i) (tmp2 *= (dp1[i] + dp2[i]) % mod) %= mod;
(tmp2 *= power(2, w)) %= mod;
(res2 += tmp2) %= mod;
}
int main() {
n = read();
for (int i = 1; i <= n; ++ i) {
h[i] = read();
}
long long ans1, ans2;
solve(1, n, ans1, ans2);
printf("%lld
", ans2);
return 0;
}