真就讲完套路啥都简单……
题面
S 有一棵 (n) 个点的有根树,每个点有权值 (w_i) ,初始每个结点上都没有石子。
S 准备了一些石子,并把它们拿在手中。她可以进行以下两种操作任意多次:
- 从手中取 (w_i) 个石子放在结点 (i) 上,进行该操作要求结点 (i) 的所有孩子 (j) 上都有 (w_j) 个石子。
- 将结点 (i) 上的所有石子收回手中。
T 想知道对于每个 (i) ,为了在结点 (i) 上放 (w_i) 个石子,S 至少需要准备多少石子。
思路
显然,当你往一个节点放石子的时候,每个儿子节点都是放满的。
那么每一次操作可以用一个二元组表示: ((w_i-sum w_{son},sum w_{son})) ,表示这一次操作完成后手中石子的增量,和这次操作中石子数达到的最大值。
由于 “父亲受多个儿子限制” 很难处理,考虑把整个操作序列倒过来处理,那么一个点的限制就只有它的父亲。二元组变成了 ((sum w_{son}-w_i,sum w_{son})) ,需要最小化这个序列的历史最值。
考虑贪心,对每个操作求出它的优先度。
对于 (x<0) 的情况,显然放在前面更优(使前缀更小),按 (y) 升序;
如果 (xge 0) ,试比较 ((x,y)) 和 ((x',y')) 的优劣:(max(y,y'+x)<max( y',x'+y))
((y'+x)< (x'+y) => x-y>x'-y')
如此就得到了全局最优的序列。
现在,每次找最优的一个,如果这个点的父亲还没处理,那么就和父亲合并即可(一旦做了父亲就做,因为优先级高)。注意到每棵子树的最优序列是全局的子序列,线段树合并得到答案。
代码
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e5+10;
struct node //二元组
{
ll x,y; int head,tail;
node operator + ( const node &tmp ) const { return (node){x+tmp.x,max(y,x+tmp.y),head,tmp.tail}; }
bool operator < ( const node &tmp ) const
{
int t1=(x>=0),t2=(tmp.x>=0);
if( t1!=t2 ) return x<tmp.x;
if ( !t1 )
if ( y!=tmp.y ) return y<tmp.y;
else return head<tmp.head;
if ( (y-x)!=(tmp.y-tmp.x) ) return (y-x)>(tmp.y-tmp.x);
return head<tmp.head;
}
}lis[N];
struct SegmentTree
{
node x; int l,r;
}tr[N*40];
int f[N],n,id[N],fa[N],tr2[N],tot,nxt[N];
ll w[N],sum[N],ans[N],val[N];
bool vis[N];
vector<int> ve[N];
set<node> s;
int find( int x ) { return x==f[x] ? x : f[x]=find(f[x]); }
int merge( int u,int v )
{
if ( !u || !v ) return u | v;
tr[u].l=merge( tr[u].l,tr[v].l ); tr[u].r=merge( tr[u].r,tr[v].r );
tr[u].x=tr[tr[u].l].x+tr[tr[u].r].x;
return u;
}
void insert( int &cnt,int l,int r,int x,int id )
{
cnt=++tot;
if ( l==r ) { tr[cnt].x=(node){val[id],sum[id],l,l}; return; }
int mid=(l+r)>>1;
if ( x<=mid ) insert( tr[cnt].l,l,mid,x,id );
else insert( tr[cnt].r,mid+1,r,x,id );
tr[cnt].x=tr[tr[cnt].l].x+tr[tr[cnt].r].x;
}
void dfs( int u,int fa )
{
insert( tr2[u],1,n,id[u],u );
for ( vector<int> :: iterator it =ve[u].begin(); it!=ve[u].end(); it++ )
{
int v=*it; dfs( v,u ); tr2[u]=merge( tr2[u],tr2[v] );
}
ans[u]=tr[tr2[u]].x.y+w[u];
}
int main()
{
int cas; scanf( "%d%d",&cas,&n );
for ( int i=2; i<=n; i++ )
scanf( "%d",&fa[i] ),ve[fa[i]].push_back(i);
for ( int i=1; i<=n; i++ )
scanf( "%lld",&w[i] ),sum[fa[i]]+=w[i],f[i]=i;
for ( int i=1; i<=n; i++ )
{
val[i]=sum[i]-w[i]; lis[i]=(node){val[i],sum[i],i,i};
s.insert(lis[i]);
}
vis[0]=1; int now=0;
while ( !s.empty() )
{
set<node>:: iterator it=s.begin(); node x=*it; s.erase(it);
if ( vis[fa[x.head]] ) //父亲节点已经处理过了
{
nxt[now]=x.head;
while ( now!=x.tail ) vis[now]=1,now=nxt[now];
vis[now]=1;
}
else //否则和父亲合并
{
int pf=find( fa[x.head] ); s.erase( lis[pf] );
nxt[lis[pf].tail]=lis[x.head].head; lis[pf]=lis[pf]+lis[x.head];
f[find(x.head)]=pf; s.insert( lis[pf] );
}
}
int cnt=0; now=0;
for ( int i=1; i<=n; i++ )
cnt++,now=nxt[now],id[now]=i;
dfs( 1,0 );
for ( int i=1; i<=n; i++ )
printf( "%lld ",ans[i] );
}