思路：二分答案 + 树上（边）差分。
详细题解待补充。
// Binary search on the answer + tree edge-difference (树上差分).
//
// Input: n m, then n-1 weighted tree edges "a b t", then m paths "u v".
// Output: the smallest value `ans` such that, after setting the weight of ONE
// tree edge to 0, every path length is <= ans.
// (Appears to be NOIP2015 "运输计划" — confirm against the original statement.)
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;

const int MAXlog = 20;      // 2^20 > MAXN, enough levels for binary lifting
const int MAXM = 711001;    // capacity for paths (1-based)
const int MAXN = 400001;    // capacity for tree nodes (1-based, node 0 = sentinel)

int dep[MAXN], jump[MAXN][MAXlog + 2];   // depth and binary-lifting ancestors
int lcav[MAXM];                           // LCA of each path's endpoints
int cf[MAXN];                             // difference array; cf[x] refers to edge (x, parent(x))
int fw[MAXN];                             // weight of edge (x, parent(x)); 0 for the root
int lent[MAXN];                           // weighted distance from the root
int uline[MAXM], vline[MAXM], wline[MAXM];  // path endpoints and path lengths

// Adjacency list. The "next" array is named `nxt`: with `using namespace std`
// a global called `next` is ambiguous with std::next and does not compile.
int cnt = 0, u[MAXN * 2], v[MAXN * 2], w[MAXN * 2], first[MAXN], nxt[MAXN * 2];

int ord[MAXN], ordn = 0;  // preorder of the tree from the root (parents before children)
int n, m;

// Store path i = (ux, vx); lcav[i] and lent[] must already be filled in.
void addline(int ux, int vx, int i) {
    uline[i] = ux;
    vline[i] = vx;
    wline[i] = lent[ux] + lent[vx] - 2 * lent[lcav[i]];
}

// Append the directed edge ux -> vx with weight wx to the adjacency list.
void addedge(int ux, int vx, int wx) {
    cnt++;
    u[cnt] = ux;
    v[cnt] = vx;
    w[cnt] = wx;
    nxt[cnt] = first[ux];
    first[ux] = cnt;
}

// Iterative replacement for the recursive dfs/dfs2 pair (a deep chain would
// overflow the call stack). It fills ord[] in preorder and, for every node,
// fills dep, jump, fw and lent. When a node is popped its parent was popped
// earlier, so the parent's jump table and lent are already complete.
void build(int root) {
    static int stk[MAXN];
    int top = 0;
    dep[0] = -1;          // sentinel so the root gets depth 0
    jump[root][0] = 0;
    fw[root] = 0;
    stk[top++] = root;
    while (top > 0) {
        int x = stk[--top];
        int f = jump[x][0];
        ord[ordn++] = x;
        dep[x] = dep[f] + 1;
        lent[x] = lent[f] + fw[x];
        for (int i = 1; i <= MAXlog; i++)
            jump[x][i] = jump[jump[x][i - 1]][i - 1];
        for (int i = first[x]; i; i = nxt[i]) {
            if (v[i] == f) continue;
            jump[v[i]][0] = x;
            fw[v[i]] = w[i];
            stk[top++] = v[i];
        }
    }
}

// Lowest common ancestor via binary lifting.
int lca(int x, int y) {
    if (dep[y] > dep[x]) swap(x, y);
    for (int i = MAXlog; i >= 0; i--)
        if (dep[x] - (1 << i) >= dep[y]) x = jump[x][i];
    if (x == y) return x;
    for (int i = MAXlog; i >= 0; i--)
        if (jump[x][i] != jump[y][i]) {
            x = jump[x][i];
            y = jump[y][i];
        }
    return jump[x][0];
}

// Edge-difference update: after prefix-summing subtrees, every edge on path l
// gets +1 and every other edge gets 0.
void runcf(int l) {
    cf[uline[l]]++;
    cf[vline[l]]++;
    cf[lcav[l]] -= 2;
}

int maxlen = 0;  // heaviest edge covered by all `num` paths (set by calcf)

// Subtree-sum the difference array in reverse preorder (children before
// parents). Track the heaviest edge whose count equals num.
void calcf(int num) {
    for (int k = ordn - 1; k >= 0; k--) {
        int x = ord[k];
        if (cf[x] == num && fw[x] > maxlen) maxlen = fw[x];
        cf[jump[x][0]] += cf[x];  // the root pushes into sentinel slot 0, which is harmless
    }
}

// Can every path be made <= ans by zeroing one edge?
// Consider only the paths longer than ans. The edge must lie on all of them,
// and the best choice is the heaviest such edge, so it is enough to check
// that the longest path minus that edge's weight is <= ans.
bool check(int ans) {
    int inq = 0;
    int maxchain = 0;
    maxlen = 0;
    memset(cf, 0, sizeof(cf[0]) * (n + 1));
    // Iterate over the m paths; the original looped to `cnt` (the edge count),
    // which skipped paths when m > 2(n-1).
    for (int i = 1; i <= m; i++)
        if (wline[i] > ans) {
            runcf(i);
            inq++;
            maxchain = max(maxchain, wline[i]);
        }
    calcf(inq);
    return maxchain - maxlen <= ans;
}

int main() {
    if (scanf("%d %d", &n, &m) != 2) return 0;
    int maxt = 0;  // heaviest edge in the whole tree
    for (int i = 1; i <= n - 1; i++) {
        int a, b, t;
        scanf("%d %d %d", &a, &b, &t);
        addedge(a, b, t);
        addedge(b, a, t);
        maxt = max(maxt, t);
    }
    build(1);

    int maxline = 0;  // longest path before any edge is zeroed
    for (int i = 1; i <= m; i++) {
        int um, vm;
        scanf("%d %d", &um, &vm);
        lcav[i] = lca(um, vm);
        addline(um, vm, i);
        maxline = max(maxline, wline[i]);
    }

    // Zeroing one edge shortens the longest path by at most maxt, so the
    // answer lies in [max(0, maxline - maxt), maxline]. check(maxline) always
    // holds, because no path is longer than maxline.
    int l = max(0, maxline - maxt), r = maxline;
    while (l < r) {
        int mid = l + (r - l) / 2;
        if (check(mid)) r = mid;
        else l = mid + 1;
    }
    printf("%d\n", r);
    return 0;
}