这道题由于问最大值最小,所以很容易想到二分,但怎么验证并且如何实现是这道题的难点;
首先我们考虑,对于一个军队,尽可能的往根节点走(但一定不到)是最优的;
判断一个军队最远走到哪可以树上倍增来实现;
但是,这并没有结束,因为可能这颗子树的军队会去另一个军队;
我们先找出所有以根节点的子节点为根的子树中,是否有到叶子节点的路径还未被驻扎,并记录下还有路径未被驻扎的这些子树的根节点;
若该节点上停留有军队,则剩余时间最小的军队驻扎在该节点一定是最优的。
这样处理过这些节点后,把剩下的节点按照到根节点的距离从小到大排序。
对于现在闲置的军队和需要被驻扎的节点,让剩余时间小的军队优先驻扎在距离根节点近的节点,这样可以保证决策最优
#include <bits/stdc++.h> using namespace std; const int MXR=5e4+2; int n,m,t,tot=0,atot=0,btot=0,ctot=0; int d[MXR],query[MXR],f[MXR][20]; int ver[2*MXR],edge[2*MXR],MXRext[2*MXR],head[MXR],dist[MXR][20]; pair<long long,int> h[MXR]; void add(int x,int y,int z){ ver[++tot]=y,edge[tot]=z,MXRext[tot]=head[x],head[x]=tot; } void bfs() { queue<int> q; q.push(1); d[1]=1; while(q.size()){ int x=q.front();q.pop(); for(int i=head[x];i;i=MXRext[i]){ int y=ver[i]; if(d[y]) continue; d[y]=d[x]+1; f[y][0]=x,dist[y][0]=edge[i]; for(int j=1;j<=t;j++){ f[y][j]=f[f[y][j-1]][j-1]; dist[y][j]=dist[y][j-1]+dist[f[y][j-1]][j-1]; } q.push(y); } } } bool ok,sta[MXR],need[MXR]; long long ans,tim[MXR],ned[MXR]; int dfs(register int x) { bool pson=0; if(sta[x]) return 1; for(int i=head[x];i;i=MXRext[i]){ int y=ver[i]; if(d[y]<d[x]) continue; pson=1; if(!dfs(y)) return 0; } if(!pson) return 0; return 1; } template<class nT> inline void read(nT&x) { char c;while(c=getchar(),!isdigit(c)); x=c^48;while(c=getchar(),isdigit(c)) x=x*10+c-48; } bool check(long long lim) { memset(sta,0,sizeof(sta)); memset(tim,0,sizeof(tim)); memset(ned,0,sizeof(ned)); memset(h,0,sizeof(h)); memset(need,0,sizeof(need)); atot=0,btot=0,ctot=0; for(int i=1;i<=m;i++){ long long x=query[i],cnt=0; for(int j=t;j>=0;j--) if(f[x][j]>1 && cnt+dist[x][j]<=lim){ cnt+=dist[x][j]; x=f[x][j]; } if(f[x][0]==1 && cnt+dist[x][0]<=lim) h[++ctot]=make_pair(lim-cnt-dist[x][0],x); else sta[x]=1; } for(int i=head[1];i;i=MXRext[i]) if(!dfs(ver[i])) need[ver[i]]=1; sort(h+1,h+ctot+1); for(int i=1;i<=ctot;i++){ if(need[h[i].second] && h[i].first<dist[h[i].second][0]) need[h[i].second]=0; else tim[++atot]=h[i].first; } for(int i=head[1];i;i=MXRext[i]) if(need[ver[i]]) ned[++btot]=dist[ver[i]][0]; if(atot<btot) return 0; sort(tim+1,tim+atot+1),sort(ned+1,ned+btot+1); int i=1,j=1; while(i<=btot && j<=atot) if(tim[j]>=ned[i]){ i++,j++; } else j++; if(i>btot)return 1; return 0; } int main() { long long l=0,r=0,mid; cin>>n; t=log2(n)+1; for(int i=1;i<=n-1;i++){ int x,y,z; read(x); read(y); read(z); add(x,y,z),add(y,x,z); r+=z; } bfs(); cin>>m; for(int i=1;i<=m;i++) read(query[i]); while(l<=r){ mid=(l+r)>>1; if(check(mid)){ r=mid-1; ans=mid; ok=1; } else l=mid+1; } if(!ok) cout<<-1; else cout<<ans; return 0; }