Description
小凸和小方相约玩密室逃脱,这个密室是一棵有 (n) 个节点的完全二叉树,每个节点有一个灯泡。点亮所有灯泡即可逃出密室。每个灯泡有个权值 (Ai) ,每条边也有个权值 (bi) 。点亮第 1 个灯泡不需要花费,之后每点亮 1 个新的灯泡 (V) 的花费,等于上一个被点亮的灯泡 (U) 到这个点 (V) 的距离 (Du,v),乘以这个点的权值 (Av) 。在点灯的过程中,要保证任意时刻所有被点亮的灯泡必须连通,在点亮一个灯泡后必须先点亮其子树所有灯泡才能点亮其他灯泡。
请告诉他们,逃出密室的最少花费是多少。
Input
第1行包含1个数 (n) ,代表节点的个数
第2行包含 (n) 个数,代表每个节点的权值 (ai) 。( (i=1,2,…,n) )
第3行包含 (n-1) 个数,代表每条边的权值 (bi) ,第 (i) 号边是由第 ((i+1)/2) 号点连向第 (i+1) 号点的边。( (i=l,2,...N-1) )
Output
输出包含1个数,代表最少的花费。
Sample Input
3
5 1 2
2 1
Sample Output
5
HINT
对于 (100\%) 的数据,(1 leq N leq 2 imes 10^5),(1<Ai,Bi leq 10^5)
想法
明显的树形 (DP) 。
但注意题中的2个坑点 !!!!!!
第一个点亮的节点不一定是1号点!
“完全二叉树”的意思的第 (i) 个点的父亲是 (i/2) ,但不保证所有非叶子节点都有两个孩子!
先假设第一个点亮的点是 1 。
那么树形 (dp) 的状态为:
(dp[i][j]) 表示当前 (i) 已被点亮,开始点亮以 (i) 为根的子树,将其全点亮后,最后一个点跑到 (j) 去点亮 (j) 的最少花费。
转移也挺显然的,考虑左右子哪个先点亮就行了,记忆化搜索。
由于在这种情况下,对每个 (i) ,有用的 (dp[i][j]) 中的 (j) 为其所有祖先的另一个孩子,不超过 (O(logn)) 个,所以总状态数 (O(nlogn)) ,不会超时。
交一发,(WA) 了。
于是开始换根。
对于先点亮的那个点,还是要先把它的子树点亮,然后再点亮它的父节点。
这时对每个 (i) ,有用的 (dp[i][j]) 中的 (j) 除了所有祖先的另一个孩子,还有它所有的祖先,但还是 (O(logn)) 级别,总复杂度 (O(nlogn))。
记忆化搜索,然后超时了 (qwq)
那就不记忆化了(用 (map) 常数过大【捂脸】)
重新设状态——
(f[i][j]) 表示 (dp[i][y]) ,其中 (y) 为 (i) 的第 (j+1) 个祖先。
(g[i][j]) 表示 (dp[i][z]) ,其中 (z) 为 (i) 的第 (j+1) 个祖先的另一个孩子。
(O(nlogn)) 时间能把这些值都算出来,然后再换根。
交一发, (WA) 了。
发现不一定每个非叶子节点都有2个孩子,于是又改了改细节。终于 (A) 掉了!
代码
细节极多 【害怕】
#include<cstdio>
#include<iostream>
#include<algorithm>
using namespace std;
int read(){
int x=0;
char ch=getchar();
while(!isdigit(ch)) ch=getchar();
while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
return x;
}
const int N = 200005;
typedef long long ll;
int n,a[N],b[N];
ll f[N][20],g[N][20];
ll ans;
void dfs(int x,ll cur){ //换根
int l=x*2,r=x*2+1;
if(x!=1){
ll now;
if(r<=n) now=min(1ll*a[l]*b[l]+g[l][0]+f[r][1],1ll*a[r]*b[r]+g[r][0]+f[l][1]);
else if(l==n) now=1ll*a[l]*b[l]+f[l][1];
else now=f[x][0];
ans=min(ans,now+cur);
}
if(l>n) return;
if(l==n) dfs(l,cur+1ll*b[x]*a[x/2]);
else{
dfs(l,cur+1ll*a[r]*b[r]+f[r][1]);
dfs(r,cur+1ll*a[l]*b[l]+f[l][1]);
}
}
int main()
{
n=read();
for(int i=1;i<=n;i++) a[i]=read();
for(int i=2;i<=n;i++) b[i]=read();
for(int i=n;i>0;i--){
if(i*2>n){
int x=i/2,last=(i&1) ? i-1 : i+1;
ll s=b[i];
for(int j=0;x>=0;j++,x/=2){
f[i][j]=s*a[x];
g[i][j]=1ll*(s+b[last])*a[last];
s+=b[x]; last=(x&1) ? x-1 : x+1;
if(x==0) break;
}
continue;
}
else if(i*2==n){
for(int j=0,x=i/2;x>=0;j++,x/=2){
f[i][j]=1ll*a[n]*b[n]+f[n][j+1];
g[i][j]=1ll*a[n]*b[n]+g[n][j+1];
if(x==0) break;
}
continue;
}
int l=i*2,r=l+1;
for(int j=0,x=i/2;x>=0;j++,x/=2){
f[i][j]=min(1ll*a[l]*b[l]+g[l][0]+f[r][j+1],1ll*a[r]*b[r]+g[r][0]+f[l][j+1]);
g[i][j]=min(1ll*a[l]*b[l]+g[l][0]+g[r][j+1],1ll*a[r]*b[r]+g[r][0]+g[l][j+1]);
if(x==0) break;
}
}
ans=f[1][0];
dfs(1,0);
printf("%lld
",ans);
return 0;
}