【传送门:BZOJ3252】
简要题意:
给出一棵树,树上的每个节点都有权值,现在要遍历这棵树的k条链,权值为链上的节点权值和,每个节点的权值只有在第一次被遍历的时候才能用,也就是每个节点遍历两遍只能得到一次的权值,求出最大能得到的权值和
题解:
哇塞,直接就给一棵树,而且输入还直接告诉你两个点的父子关系,爽!
然后想想,每次遍历,肯定遍历一条链,那么我们先用DFS序来重新编号,然后记录每一个点的子树的第一个新编号和最后一个新编号
然后用sum[i]表示第i个点到根节点的权值和
想到了,线段树维护贪心,因为每次取的链一定是当前权值和最大的(可以自己出一出小数据证一下)
先将sum数组放进线段树的每个对应新点处
然后维护区间最大值
那么我们每次得到的tr[1].mx就一定是权值和最大的链的权值和,然后我们还要对当前这个最大值所在的位置进行记录
因为每个点经过一次之后就要把自身的值变为0,所以每次找到一个链之后就要将这条链从下向上遍历一遍,然后再进行处理
显然将链遍历一遍的复杂度太高了,那么我们就记录每个点是否被遍历过,如果这个点被遍历过,那么就不用向上遍历了,因为一旦这个点被遍历过,那么上面的点一定也被遍历过
然后每个点的改变,只会对这个点的子树有影响
所以当这个点被遍历过,就将这个点的子树里的所有点减去这个点值,这样就能维护线段树的值了
总的来说就是单点询问,区间修改线段树
注意加long long,因为lazy没搞好,WA了3次
参考代码:
#include<cstdio> #include<cstdlib> #include<algorithm> #include<cmath> #include<cstring> using namespace std; typedef long long LL; LL s[210000]; struct node { int x,y,next; }a[210000];int len,last[210000]; void ins(int x,int y) { len++; a[len].x=x;a[len].y=y; a[len].next=last[x];last[x]=len; } int fa[210000],dfn[210000],tot,to[210000]; int L[210000],R[210000]; LL sum[210000]; void dfs(int x) { dfn[x]=++tot;to[tot]=x; L[dfn[x]]=tot; sum[x]=sum[fa[x]]+s[x]; for(int k=last[x];k;k=a[k].next) { int y=a[k].y; dfs(y); } R[dfn[x]]=tot; } struct trnode { int l,r,lc,rc,x;LL mx,lazy; }tr[410000];int trlen; void bt(int l,int r) { trlen++;int now=trlen; tr[now].l=l;tr[now].r=r;tr[now].mx=tr[now].lazy=0; tr[now].lc=tr[now].rc=-1;tr[now].x=-1; if(l==r) tr[now].mx=sum[to[l]],tr[now].x=l; else { int mid=(l+r)/2; tr[now].lc=trlen+1;bt(l,mid); tr[now].rc=trlen+1;bt(mid+1,r); tr[now].mx=max(tr[tr[now].lc].mx,tr[tr[now].rc].mx); if(tr[now].mx==tr[tr[now].lc].mx) tr[now].x=tr[tr[now].lc].x; else tr[now].x=tr[tr[now].rc].x; } } void update(int now) { int lc=tr[now].lc,rc=tr[now].rc; if(lc!=-1) tr[lc].mx-=tr[now].lazy,tr[lc].lazy+=tr[now].lazy; if(rc!=-1) tr[rc].mx-=tr[now].lazy,tr[rc].lazy+=tr[now].lazy; tr[now].lazy=0; } void change(int now,int l,int r,LL c) { if(tr[now].l==l&&tr[now].r==r) { tr[now].lazy+=c; tr[now].mx-=c; return ; } int mid=(tr[now].l+tr[now].r)/2,lc=tr[now].lc,rc=tr[now].rc; if(tr[now].lazy>0) update(now); if(r<=mid) change(lc,l,r,c); else if(l>mid) change(rc,l,r,c); else change(lc,l,mid,c),change(rc,mid+1,r,c); tr[now].mx=max(tr[lc].mx,tr[rc].mx); if(tr[now].mx==tr[lc].mx) tr[now].x=tr[lc].x; else tr[now].x=tr[rc].x; } bool v[210000]; int main() { int n,k; scanf("%d%d",&n,&k); for(int i=1;i<=n;i++) scanf("%lld",&s[i]); len=0;memset(last,0,sizeof(last)); for(int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); fa[y]=x; ins(x,y); } tot=0;sum[1]=s[1];dfs(1); trlen=0;bt(1,tot); LL ans=0;fa[1]=0; memset(v,true,sizeof(v)); for(int i=1;i<=k;i++) { if(tr[1].mx==0) break; ans+=tr[1].mx; int x=to[tr[1].x]; while(x!=0&&v[x]==true) { change(1,L[dfn[x]],R[dfn[x]],s[x]); v[x]=false; x=fa[x]; } } printf("%lld ",ans); return 0; }