题解:
首先,贪心地考虑会发现:我们每次一定会选当前权值和最大的那条从根到叶子的路径
然后在于怎么维护这个最大值
我们发现每个修改实际上是对沿途所有点的子树的修改
所以把叶子按 DFS 序编号后,每棵子树对应一段连续区间,用线段树维护区间修改和全局最大值就可以了。
另外注意不同路径之间会有重复部分(已经取过的点不能重复计入),但重复部分一定是包含关系(总是从根开始的一段),所以比较好处理。
代码:
#include <algorithm>
#include <cstdio>

// Greedy selection of k root-to-leaf paths in a rooted tree (root = node 1),
// where every node's value is counted at most once.  Each round takes the leaf
// whose remaining root-to-leaf gain is largest, adds that gain to the answer,
// and then removes the newly taken nodes' values from every leaf below them.
// Leaves are numbered in DFS order so each subtree is a contiguous range, and
// a segment tree keeps the maximum remaining gain.

typedef long long ll;

const ll INF = 1000000000;
const ll N = 300000;   // capacity of all per-node arrays (nodes are 1-based)
const ll N2 = N * 4;   // capacity of the segment-tree arrays

// ---------------------------------------------------------------- input ----

char ss[1 << 24], *A = ss, *B = ss;

// Returns the next input byte, refilling the buffer from stdin when it is
// exhausted, or EOF once no more data is available.  The return type is int
// (not char) so that EOF can never be confused with a data byte.
int gc() {
    if (A == B) {
        A = ss;
        B = ss + std::fread(ss, 1, sizeof ss, stdin);
        if (A == B) return EOF;
    }
    return static_cast<unsigned char>(*A++);
}

// Reads one decimal integer into x, skipping any non-digit bytes before it.
// A '-' seen anywhere in the skipped prefix makes the result negative.
// If the input ends before a digit is found, x is set to 0 (the previous
// version looped forever in that situation).
template <class T>
void read(T &x) {
    ll sign = 1;
    int c = gc();
    while (c < '0' || c > '9') {
        if (c == EOF) {
            x = 0;
            return;
        }
        if (c == '-') sign = -1;
        c = gc();
    }
    x = 0;
    while ('0' <= c && c <= '9') {
        x = x * 10 + (c - '0');
        c = gc();
    }
    x *= sign;
}

// ----------------------------------------------------------------- tree ----

ll n, k, l, cnt, ans;
ll v[N];      // v[i]: value of node i
ll head[N];   // head[i]: index of the first outgoing edge of i, 0 if none
ll fa[N];     // fa[i]: parent of i, 0 for the root
ll sum[N];    // sum[i]: total value on the path root..i (sum[0] stays 0)
ll minn[N];   // minn[i] / maxn[i]: first / last leaf position inside the
ll maxn[N];   //                    subtree of i
ll cur[N];    // DFS cursor: next unvisited outgoing edge of each node
ll stk[N];    // explicit DFS stack

// Edge list stored as singly linked lists: a = next edge index, b = child.
struct re {
    ll a, b;
} a[N];

// Adds the directed edge x -> y at the front of x's adjacency list.
void arr(ll x, ll y) {
    a[++l].a = head[x];
    a[l].b = y;
    head[x] = l;
}

// --------------------------------------------------------- segment tree ----
// For the best leaf inside each segment:
//   data1 = node id of that leaf,
//   data2 = deepest ancestor (or the leaf itself) whose value is already
//           taken; 0 while nothing on its path has been taken,
//   data3 = remaining gain of that leaf, i.e. sum[leaf] - sum[data2].
// lazy1 is a pending "assign data2" tag (0 = none; valid because node ids
// start at 1) and lazy2 a pending "add to data3" tag.

ll data1[N2], data2[N2], data3[N2], lazy1[N2], lazy2[N2];
ll real2[N2];   // real2[p]: node id of the leaf at position p, 0 if unused

// Pushes the pending tags of node x down to its two children.  The two tags
// are handled independently so that neither relies on the other being set.
void down(ll x) {
    const ll lc = x * 2, rc = lc + 1;
    if (lazy2[x]) {
        lazy2[lc] += lazy2[x];
        data3[lc] += lazy2[x];
        lazy2[rc] += lazy2[x];
        data3[rc] += lazy2[x];
        lazy2[x] = 0;
    }
    if (lazy1[x]) {
        lazy1[lc] = lazy1[rc] = lazy1[x];
        data2[lc] = data2[rc] = lazy1[x];
        lazy1[x] = 0;
    }
}

// Recomputes node x from its children; on equal gains the left child wins.
void updata(ll x) {
    const ll best = (data3[x * 2] < data3[x * 2 + 1]) ? x * 2 + 1 : x * 2;
    data1[x] = data1[best];
    data2[x] = data2[best];
    data3[x] = data3[best];
}

// Builds the tree over positions [h, t].  Position p starts with the gain
// sum[real2[p]]; positions that hold no leaf have real2[p] == 0 and gain 0.
void build(ll x, ll h, ll t) {
    if (h == t) {
        data1[x] = real2[h];
        data3[x] = sum[real2[h]];
        return;
    }
    const ll m = (h + t) / 2;
    build(x * 2, h, m);
    build(x * 2 + 1, m + 1, t);
    updata(x);
}

// Range update on [h1, t1] inside node x covering [h, t]:
// sets data2 to k1 and adds k2 to data3 for every position in the range.
void change(ll x, ll h, ll t, ll h1, ll t1, ll k1, ll k2) {
    if (h1 <= h && t <= t1) {
        lazy1[x] = k1;
        data2[x] = k1;
        lazy2[x] += k2;
        data3[x] += k2;
        return;
    }
    down(x);
    const ll m = (h + t) / 2;
    if (h1 <= m) change(x * 2, h, m, h1, t1, k1, k2);
    if (m < t1) change(x * 2 + 1, m + 1, t, h1, t1, k1, k2);
    updata(x);
}

// ------------------------------------------------------------------ dfs ----

// Walks the tree from `root`, visiting children in adjacency-list order, and
// fills sum[], the leaf numbering (real2[], cnt) and minn[]/maxn[].
// Uses an explicit stack so that the depth of the tree cannot overflow the
// call stack.  Expects minn[1..n] == INF and maxn[1..n] == 0 on entry.
void dfs(ll root) {
    for (ll i = 1; i <= n; ++i) cur[i] = head[i];

    ll top = 0;
    sum[root] = v[root];
    stk[top++] = root;

    while (top > 0) {
        const ll x = stk[top - 1];

        if (!head[x]) {
            // Leaf: it is on top of the stack exactly once.
            minn[x] = maxn[x] = ++cnt;
            real2[cnt] = x;
        }

        if (cur[x]) {
            // Descend into the next unvisited child.
            const ll child = a[cur[x]].b;
            cur[x] = a[cur[x]].a;
            sum[child] = sum[x] + v[child];
            stk[top++] = child;
            continue;
        }

        // Subtree of x is finished: fold its leaf range into the parent's.
        --top;
        if (top > 0) {
            const ll parent = stk[top - 1];
            minn[parent] = std::min(minn[parent], minn[x]);
            maxn[parent] = std::max(maxn[parent], maxn[x]);
        }
    }
}

// ----------------------------------------------------------------- main ----

// Prints msg to stderr and yields the process exit status for a failure.
static int fail(const char *msg) {
    std::fputs(msg, stderr);
    return 1;
}

// Input: n k, then v[1..n], then n-1 edges "x y" meaning x is the parent of y.
// Output: the total gain of the k greedily chosen paths.
int main() {
    // NOTE(review): input and output are redirected to the files 1.in / 1.out,
    // as in the original program; drop these two calls to use stdin/stdout.
    if (std::freopen("1.in", "r", stdin) == NULL) return fail("cannot open 1.in\n");
    if (std::freopen("1.out", "w", stdout) == NULL) return fail("cannot open 1.out\n");

    read(n);
    read(k);
    // Every per-node array has N slots and is indexed 1..n.
    if (n < 1 || n >= N) return fail("n is out of range\n");

    for (ll i = 1; i <= n; ++i) read(v[i]);
    for (ll i = 1; i <= n; ++i) minn[i] = INF;
    for (ll i = 1; i < n; ++i) {
        ll x, y;
        read(x);
        read(y);
        if (x < 1 || x > n || y < 1 || y > n) return fail("edge endpoint is out of range\n");
        arr(x, y);
        fa[y] = x;
    }

    dfs(1);
    // The tree spans positions 1..n although only 1..cnt hold leaves; the
    // remaining positions keep gain 0.
    build(1, 1, n);

    for (ll i = 1; i <= k; ++i) {
        ll x = data1[1];         // best leaf
        const ll y = data2[1];   // deepest node on its path that is already taken
        ans += data3[1];

        // Climb from the leaf up to (but excluding) y.  At each node x every
        // leaf in x's subtree that was not handled one level below loses the
        // values of the nodes between y (exclusive) and x (inclusive), and x
        // becomes its deepest taken ancestor.  [kk1+1, kk2-1] is the leaf
        // range already handled by the child we came from.
        ll kk1 = minn[x] - 1, kk2 = minn[x];
        while (x != y) {
            const ll delta = -(sum[x] - sum[y]);
            if (kk1 >= minn[x]) change(1, 1, n, minn[x], kk1, x, delta);
            if (kk2 <= maxn[x]) change(1, 1, n, kk2, maxn[x], x, delta);
            kk1 = minn[x] - 1;
            kk2 = maxn[x] + 1;
            x = fa[x];
        }
    }

    std::printf("%lld\n", ans);
    return 0;
}