拿到本题后,可以观察到一个性质,如果出现了 (c_i e c_{i + 1}) 那么我们一定可以确定一个位置的值,这启示着我们将 (c_i) 相同的部分单独拿出来考虑再将最后的答案合并。于是可以先思考一个更为特殊的问题,所有 (c_i) 都相同的答案。为了让所有区间都被满足填了一个 (c_i),可以令 (dp_i) 表示前 (i) 个区间都存在一个 (c_i) 的方案。但你发现这是不好转移的,因为你不知道这个区间的 (1) 填在哪里,而且这个 (1) 填的位置可能对后面的区间造成影响,因此这样是不行的。先解决不满足后效性的问题,为了让当前的选择不会对后面产生影响,我们只能每次填在当前区间的最前面,但你可以发现因为区间每次只向右挪了一位,所以每次填在区间最前面是可以覆盖前 (n - k + 1) 个位置的答案的。因此这个想法是很有希望的,所以我们可以令 (dp_i) 表示当前填的 (c_i) 最后在 (i) 的方案,那么就有转移((m = 1e9 - c_i + 1)):
于此同时你可以发现 (dp_{i - 1}) 的转移中也会包含同样的 (k - 1) 个 (dp) 值,即:
则可也类似错位相消的方法把中间的值消去:
拿 (dp_i) 与其做差可得:
这样转移就可以做到 (O(1)) 了,总体 (dp) 的复杂度就是 (O(n)) 的了。最后统计答案时候同样枚举最后一个填 (c_i) 的位置,显然只能在 ([n - k + 1, n]) 中,那么最后的答案就为:
可以发现这就是 (dp_{n + 1}),于是我们就在 (O(n)) 的时间复杂度内解决了这个特殊问题。
再回来考虑原来的问题,将原来的 (c) 序列划分成很多由相同 (c_i) 组成的段。一个直接的想法就是将满足这一段相同 (c_i) 的段的方案(假设这一段相同 (c_i) 的长度为 (len),即用 (c_i) 去填长度为 (len + k - 1) 的段的方案)算出来再将所有段的答案相乘。但是你会发现段与段之间是会互相影响的,也就是说两端之间需要填的位置会有重复,貌似还是不好去计数,但仔细分析一下会有影响的部分部分会发现(假设当前段开头位置为 (i),最后一个位置为 (j)):
-
假如 (c_{i - 1} > c_i),则 (i sim i + k - 2) 上填的数都要大于 (c_{i - 1}),我们何不将这部分的答案划给前一种 (c_{i - 1}) 来算呢?于此同时 (i + k - 1) 这个位置也以及确定必须填 (c_i),因此 (i) 能算的区间长度就要减去 (k)。
-
假如 (c_{j + 1} > c_i),则 (j + 1 sim j + k - 1) 上填的数都要大于 (c_{j + 1}) 同理我们可以将这一段划分给 (j + 1) 来算,因此 (i) 能算的区间长度又要减去 (k)。
那么这样将 (i) 能算的答案算出来再将所有块的答案相乘即可,复杂度 (O(n log n))((log) 来自于计算 ((m - 1) ^ k))。
#include<bits/stdc++.h>
using namespace std;
#define rep(i, l, r) for(int i = l; i <= r; ++i)
const int N = 100000 + 5;
const int inf = 1000000000;
const int Mod = 1000000000 + 7;
int n, k, P, ans, a[N], dp[N];
int read(){
char c; int x = 0, f = 1;
c = getchar();
while(c > '9' || c < '0'){ if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}
int Inc(int a, int b){ return (a += b) >= Mod ? a - Mod : a;}
int Dec(int a, int b){ return (a -= b) < 0 ? a + Mod : a;}
int Mul(int a, int b){ return 1ll * a * b % Mod;}
int fpow(int a, int b){ int ans = 1; for(; b; a = Mul(a, a), b >>= 1) if(b & 1) ans = Mul(ans, a); return ans;}
int solve(int n, int m){
rep(i, 1, n + 1) dp[i] = 0;
dp[0] = dp[1] = 1, P = fpow(m - 1, k);
rep(i, 2, n + 1){
dp[i] = Mul(m, dp[i - 1]);
if(i - k > 0) dp[i] = Dec(dp[i], Mul(dp[i - k - 1], P));
}
return dp[n + 1];
}
int main(){
n = read(), k = read(), ans = 1;
rep(i, 1, n - k + 1) a[i] = read();
rep(i, 1, n - k + 1){
int j = i, len = 0; while(j <= n && a[j] == a[i]) ++j;
--j, len = j - i + k;
if(a[i - 1] > a[i]) len -= k;
if(a[j + 1] > a[i]) len -= k;
ans = Mul(ans, solve(len, inf - a[i] + 1));
i = j;
}
printf("%d", ans);
return 0;
}