Description
有一些高度为 h 的树在数轴上。每次选择剩下的树中最左边或是最右边的树推倒(各 50% 概率),往左倒有 p 的概率,往右倒 1-p。
一棵树倒了,如果挨到的另一棵树与该数的距离严格小于h,那么它也会往同方向倒。
问所有树都被推倒后的期望覆盖长度?
要注意的一点是每棵树占的是一个点,相邻点之间只有一段距离!
Solution
我们定义 (mathtt{f[x][y][l][r]}) 为:边界为 x 与 y 的推倒期望长度。
关于 l 与 r 则有:
- (mathtt{l==0}):x 向左推倒没有限制。(所谓限制就是 (mathtt{x-1}) 那一位向右推倒并与 x 向左推倒有相交的区域)
- (mathtt{l==1}):x 向左推倒有限制。
- (mathtt{r==0}):y 向右推倒有限制。
- (mathtt{r==1}):y 向右推倒没有限制。
我们再记录 (mathtt{L[i]}) 与 (mathtt{R[i]}) 分别表示 i 向左推倒能殃及到第几号树,向右推倒能殃及第几号树,这个 (mathtt{O(n)}) 就可实现。
然后记忆化搜索就行了。这个时间复杂度应该是 (mathtt{O(n^2)}) 的?(左边每向右推进一棵树,右边粗略有 (mathtt{O(n)}) 的树可枚举)
Code
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 2005;
int n, h, pos[N], L[N], R[N];
double f[N][N][2][2], p;
int read() {
int x = 0, f = 1; char s;
while((s = getchar()) < '0' || s > '9') if(s == '-') f = -1;
while(s >= '0' && s <= '9') {x = (x << 1) + (x << 3) + (s ^ 48); s = getchar();}
return x * f;
}
int check(const int op, const int x, const int dir) {
if(! op) {
if(! dir) return min(pos[x] - pos[x - 1], h);
return min(max(pos[x] - pos[x - 1] - h, 0), h);
}
else {
if(! dir) return min(max(pos[x + 1] - pos[x] - h, 0), h);
return min(pos[x + 1] - pos[x], h);
}
}
double dfs(const int x, const int y, const int l, const int r) {
if(x > y) return 0;
if(f[x][y][l][r]) return f[x][y][l][r];
double &ans = f[x][y][l][r];
ans += 0.5 * p * (dfs(x + 1, y, 0, r) + check(0, x, l));
if(R[x] + 1 <= y) ans += 0.5 * (1 - p) * (dfs(R[x] + 1, y, 1, r) + pos[R[x]] - pos[x] + h);
else ans += 0.5 * (1 - p) * (pos[y] - pos[x] + check(1, y, r));
ans += 0.5 * (1 - p) * (dfs(x, y - 1, l, 1) + check(1, y, r));
if(x <= L[y] - 1) ans += 0.5 * p * (dfs(x, L[y] - 1, l, 0) + pos[y] - pos[L[y]] + h);
else ans += 0.5 * p * (pos[y] - pos[x] + check(0, x, l));
return ans;
}
int main() {
n = read(), h = read(), scanf("%lf", &p);
for(int i = 1; i <= n; ++ i) pos[i] = read();
sort(pos + 1, pos + n + 1);
pos[0] = pos[1] - h; pos[n + 1] = pos[n] + h;
L[1] = 1; R[n] = n;
for(int i = 2; i <= n; ++ i)
if(pos[i] - pos[i - 1] < h) L[i] = L[i - 1];
else L[i] = i;
for(int i = n - 1; i >= 1; -- i)
if(pos[i + 1] - pos[i] < h) R[i] = R[i + 1];
else R[i] = i;
printf("%.10f
", dfs(1, n, 0, 1));
return 0;
}