Link:
Solution:
感觉NOIP T3也是有点东西的……
将该题转化为最大值最小问题后想到二分答案
接下来考虑$check$时如何贪心:
由于除了在根节点所有军队都只往上跳明显采取倍增的方式
记录所有能到达根节点的军队和根节点下所有未被封死的子树
将两个序列从小到大排序后贪心匹配即可判断
注意:在判断封死子树时不考虑能跳到根节点的军队
在匹配时用剩余量最小的军队来封死原子树即可
Code:
#include <bits/stdc++.h> using namespace std; #define X first #define Y second #define pb push_back typedef double db; typedef long long ll; typedef pair<int,int> P; const ll INF=5e14; const int MAXN=5e4+10; struct edge{int nxt,to,w;}e[MAXN<<2]; int n,m,x,y,w,pos[MAXN],head[MAXN],f[MAXN][30],dep[MAXN],tot; int vis[MAXN],cnt1,cnt2;ll d[MAXN][30]; struct Data{int pos;ll val;}a[MAXN],b[MAXN]; bool cmp(Data x,Data y){return x.val<y.val;} void add(int x,int y,int w) {e[++tot]=(edge){head[x],y,w};head[x]=tot;} void dfs(int x,int anc) { for(int i=1;(1<<i)<=dep[x];i++) f[x][i]=f[f[x][i-1]][i-1], d[x][i]=d[x][i-1]+d[f[x][i-1]][i-1]; for(int i=head[x];i;i=e[i].nxt) { if(e[i].to==anc) continue; f[e[i].to][0]=x; dep[e[i].to]=dep[x]+1; d[e[i].to][0]=e[i].w; dfs(e[i].to,x); } } bool find(int x,int anc) { bool f1=0,f2=0; if(vis[x]) return 1; for(int i=head[x];i;i=e[i].nxt) { if(e[i].to==anc) continue; f2=1; if(!find(e[i].to,x)) { f1=1; if(x==1) b[++cnt2]=(Data){e[i].to,e[i].w}; else return 0; } } return f2?(!f1):0; } bool check(ll x) { cnt1=cnt2=0; memset(vis,0,sizeof(vis)); for(int i=1;i<=m;i++) { int t=pos[i];ll cur=0; for(int j=20;j>=0;j--) if(f[t][j]>1&&cur+d[t][j]<=x) cur+=d[t][j],t=f[t][j]; if(f[t][0]==1&&cur+d[t][0]<=x) a[++cnt1]=(Data){t,x-cur-d[t][0]}; else vis[t]=1; } if(find(1,0)) return 1; int cur=1; memset(vis,0,sizeof(vis)); sort(a+1,a+cnt1+1,cmp); sort(b+1,b+cnt2+1,cmp); for(int i=1;i<=cnt2;i++) vis[b[i].pos]=1; for(int i=1;i<=cnt1;i++) { if(vis[a[i].pos]) vis[a[i].pos]=0; else { while(cur<=cnt2&&!vis[b[cur].pos]) cur++; if(cur>cnt2) return 1; if(a[i].val>=b[cur].val) vis[b[cur].pos]=0,cur++; if(cur>cnt2) return 1; } } while(cur<=cnt2&&!vis[b[cur].pos]) cur++; return cur>cnt2; } int main() { scanf("%d",&n); for(int i=1;i<n;i++) scanf("%d%d%d",&x,&y,&w),add(x,y,w),add(y,x,w); dfs(1,0);scanf("%d",&m); for(int i=1;i<=m;i++) scanf("%d",&pos[i]); ll l=0,r=INF; while(l<=r) { ll mid=(l+r)>>1; if(check(mid)) r=mid-1; else l=mid+1; } printf("%lld",l==INF?-1:l); return 0; }