WA的一声哭出来了。。。
显然是二分时间,显然是越靠近首都,军队的作用越大。
有几个细节
第一是军队向上跳的时候用倍增。
第二是如果这个军队有能力跳到根节点,记录一下它能继续走的路,然后用它去搞别的子树。
判断是否可行的时候把不能被覆盖的子树记下来,用之前记录下来跳过首都的去尝试能不能覆盖到该子树的根节点。
同时我们要记录能覆盖到根节点下的第一层的节点的最小路程,让它在下面不能被覆盖的时候驻守这里。
#include <iostream> #include <algorithm> #include <cstdio> #include <cstring> #define int long long using namespace std; const int N=50005; int f[N][18],dis[N][18],n,m,head[N],ecnt,nxt[N<<1],to[N<<1],val[N<<1],cnta,cntb,rstmn[N],jd[N],rst[N]; inline void add(int bg,int ed,int v) { nxt[++ecnt]=head[bg]; to[ecnt]=ed; val[ecnt]=v; head[bg]=ecnt; } void dfs(int x) { for(int i=1;i<=17;i++) f[x][i]=f[f[x][i-1]][i-1], dis[x][i]=dis[x][i-1]+dis[f[x][i-1]][i-1]; for(int i=head[x];i;i=nxt[i]) if(to[i]!=f[x][0]) f[to[i]][0]=x,dis[to[i]][0]=val[i],dfs(to[i]); } struct Node{int id,rest;}a[N],b[N]; inline bool cmp1(Node x,Node y) {return x.rest>y.rest;} bool vis[N],used[N]; inline bool ok(int x,int fa) { if(vis[x]) return 1; bool flag=1,haveson=0; for(int i=head[x];i;i=nxt[i]) { if(to[i]==fa) continue; haveson=1; if(!ok(to[i],x)) { if(x==1) flag=0,b[++cntb]=(Node){to[i],val[i]}; else return 0; } } if(!haveson) return 0; return flag; } inline bool ck(long long tim) { cnta=cntb=0;memset(vis,0,sizeof vis);memset(rstmn,0,sizeof rstmn);memset(rst,0,sizeof rst); for(int i=1;i<=m;i++) { int x=jd[i];long long sum=0; for(int j=17;~j;j--) if(sum+dis[x][j]<=tim&&f[x][j]>1) sum+=dis[x][j],x=f[x][j]; if(f[x][0]==1&&tim-sum-dis[x][0]>=0) { a[++cnta]=(Node){i,tim-sum-dis[x][0]}; if(!rst[x]||rstmn[x]>a[cnta].rest) rstmn[x]=a[cnta].rest,rst[x]=i; } else vis[x]=1; } if(ok(1,0)) return 1; sort(a+1,a+cnta+1,cmp1),sort(b+1,b+1+cntb,cmp1); memset(used,0,sizeof used); used[0]=1; for(int i=1,j=1;i<=cntb;i++) { if(!used[rst[b[i].id]]) {used[rst[b[i].id]]=1;continue;} while(j<=cnta&&(a[j].rest<b[i].rest||used[a[j].id]))j++; if(j>cnta) return 0;used[a[j].id]=1; } return 1; } signed main() { scanf("%lld",&n); for(int i=1,x,y,z;i<n;i++) {scanf("%lld%lld%lld",&x,&y,&z),add(x,y,z),add(y,x,z);} scanf("%lld",&m); for(int i=1;i<=m;i++) scanf("%lld",&jd[i]); dfs(1); long long l=1,r=0x3f3f3f3f3f3f3f3f,ans=r; while(l<=r) { long long mid=l+r>>1; if(ck(mid)) {r=mid-1;ans=mid;} else l=mid+1; } if(ans!=0x3f3f3f3f3f3f3f3f) cout<<ans;else cout<<-1; }