CF932F Escape Through Leaf
首先, $ O(n^2) $ dp 是很显然的,方程长这样:
[dp[u] = min{dp[v] + a_u imes b_v}
]
这个方程看起来就很斜率,当我们写成了斜率优化的形式大概是这样的:
[frac{dp[v]-dp[j]}{a_v-a_j} < -b_u
]
我们想通过这个式子做就必须维护动态凸包以及凸包的合并。这个东西是很恼火的,可能用 set 和 splay 啥的可以搞,可惜不大会。
这里就引入了一种科技,李超线段树。
李超线段树是一种维护线段的线段树,支持插入一个线段,询问 $ x_0 $ 上的最大/最小值。
不难发现这个 dp 方程就是个直线的形式,当我们计算完了 $ u $ 就把 $ b_ux+dp[u] $ 这个直线插入。查询就是查 $ a_u$ 的值。
李超树的实现是每个点存储一个直线,考虑我们插入一个线段到一个节点:
- 这里没有线段,直接放进去
- 这里有线段但是被这个线段完爆,把这个位置的线段替换掉 return
- 这里有线段并且完爆插入线段,直接return
- 否则,必然插入线段和节点线段有交,把较长一段放在这里,较短的递归到一个子树。因为较短的必然不超过节点长度的一半。
这题还需要写一个线段树合并,和普通的线段树合并也没啥区别,就是先递归合并,最后把需要合并进去的树的当前节点的线段插入当前树的这个节点。
最终复杂度,看起来很 $ O(nlog^2n) $ 但是有神仙证明了复杂度是 $ O(nlog n) $ 也不是很清楚了。
这个东西还是很好写的:
#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "vector"
using namespace std;
#define MAXN 100006
typedef long long ll;
#define min( a , b ) ( (a) < (b) ? (a) : (b) )
int n , m , L;
#define D 100006
int A[MAXN] , B[MAXN];
vector<int> G[MAXN];
struct line {
ll k , b;
ll re( ll x ) { return k * x + b; }
} f[MAXN] ;
int ls[MAXN << 4] , rs[MAXN << 4] , id[MAXN << 4] , cnt , rt[MAXN];
void ins( int& x , int l , int r , int d ) {
if( !x ) { x = ++ cnt , id[x] = d; return; }
int m = l + r >> 1;
if( f[id[x]].re( m ) > f[d].re( m ) ) swap( d , id[x] );
if( f[id[x]].re( l ) <= f[d].re( l ) && f[id[x]].re( r ) <= f[d].re( r ) ) return;
if( f[id[x]].re( l ) > f[d].re( l ) ) ins( ls[x] , l , m , d );
else ins( rs[x] , m + 1 , r , d );
}
long long que( int x , int l , int r , int p ) {
if( !x ) return 0x3f3f3f3f3f3f3f3f;
int m = l + r >> 1; ll re = f[id[x]].re( p );
if( p <= m ) return min( re , que( ls[x] , l , m , p ));
else return min( re , que( rs[x] , m + 1 , r , p ));
}
int merge( int x , int y , int l , int r ) {
// printf("%d %d %d %d
",x,y,l,r);
if( !x || !y ) return x + y;
ins( x , l , r , id[y] );
int m = l + r >> 1;
ls[x] = merge( ls[x] , ls[y] , l , m );
rs[x] = merge( rs[x] , rs[y] , m + 1 , r );
return x;
}
long long ans[MAXN];
void dfs( int u , int fa ) {
for( int v : G[u] ) if( v != fa ) {
dfs( v , u );
rt[u] = merge( rt[u] , rt[v] , 1 , D << 1 );
}
ans[u] = que( rt[u] , 1 , D << 1 , A[u] + D );
if( ans[u] > 1e18 ) ans[u] = 0;
f[u] = (line) { B[u] , ans[u] - 1ll * B[u] * D };
ins( rt[u] , 1 , D << 1 , u );
}
int main() {
cin >> n;
for( int i = 1 ; i <= n ; ++ i ) scanf("%d",&A[i]);
for( int i = 1 ; i <= n ; ++ i ) scanf("%d",&B[i]);
for( int i = 1 , u , v ; i < n ; ++ i ) {
scanf("%d%d",&u,&v) , G[u].push_back( v ) , G[v].push_back( u );
}
dfs( 1 , 1 );
for( int i = 1 ; i <= n ; ++ i ) printf("%lld ",ans[i]);
}