题目大意:
给你一棵n个结点的带权树,有q组询问,问你从u到v的路径上最大值与最小值的差(最大值在最小值后面)。
思路:
首先考虑路径上合并两个子路径u->t和t->v时的情况。
假设我们已经知道了两个路径的最大值max,最小值min,以及路径上最大值与最小值的差d(最大值在最小值后面),
那么我们最大值和最小值可以直接合并,d=max(d1,d2,max2-max1)。
现在我们用倍增或者树链剖分维护这些东西,再跑一跑LCA即可。
然而我们发现往上跑和往下跑是不一样的,所以我们要维护两种差值up和down,一种是最大值在最小值上,一种是最小值在最大值上。
1 #include<cstdio> 2 #include<cctype> 3 #include<vector> 4 inline int getint() { 5 register char ch; 6 while(!isdigit(ch=getchar())); 7 register int x=ch^'0'; 8 while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0'); 9 return x; 10 } 11 const int inf=0x7fffffff; 12 const int N=50001,logN=16; 13 int w[N]; 14 std::vector<int> e[N]; 15 inline void add_edge(const int &u,const int &v) { 16 e[u].push_back(v); 17 e[v].push_back(u); 18 } 19 inline int log2(const float &x) { 20 return ((unsigned&)x>>23&255)-127; 21 } 22 int dep[N],anc[N][logN],max[N][logN],min[N][logN],up[N][logN],down[N][logN]; 23 void dfs(const int &x,const int &par) { 24 dep[x]=dep[par]+1; 25 anc[x][0]=par; 26 max[x][0]=std::max(w[x],w[par]); 27 min[x][0]=std::min(w[x],w[par]); 28 up[x][0]=std::max(w[par]-w[x],0); 29 down[x][0]=std::max(w[x]-w[par],0); 30 for(register int i=1;i<=log2(dep[x]);i++) { 31 anc[x][i]=anc[anc[x][i-1]][i-1]; 32 max[x][i]=std::max(max[x][i-1],max[anc[x][i-1]][i-1]); 33 min[x][i]=std::min(min[x][i-1],min[anc[x][i-1]][i-1]); 34 up[x][i]=std::max(std::max(up[x][i-1],up[anc[x][i-1]][i-1]),max[anc[x][i-1]][i-1]-min[x][i-1]); 35 down[x][i]=std::max(std::max(down[x][i-1],down[anc[x][i-1]][i-1]),max[x][i-1]-min[anc[x][i-1]][i-1]); 36 } 37 for(unsigned i=0;i<e[x].size();i++) { 38 const int &y=e[x][i]; 39 if(y==par) continue; 40 dfs(y,x); 41 } 42 } 43 inline int lca(int x,int y) { 44 if(dep[x]<dep[y]) std::swap(x,y); 45 for(register int i=log2(dep[x]);i>=0;i--) { 46 if(dep[anc[x][i]]>=dep[y]) { 47 x=anc[x][i]; 48 } 49 } 50 if(x==y) return x; 51 for(register int i=log2(dep[x]);i>=0;i--) { 52 if(anc[x][i]!=anc[y][i]) { 53 x=anc[x][i]; 54 y=anc[y][i]; 55 } 56 } 57 return anc[x][0]; 58 } 59 inline int solve(int x,int y) { 60 const int t=lca(x,y); 61 int pmaxup=0,pminup=inf,pmaxdown=0,pmindown=inf,pup=0,pdown=0; 62 for(register int i=log2(dep[x]);i>=0;i--) { 63 if(dep[anc[x][i]]>=dep[t]) { 64 pup=std::max(std::max(pup,up[x][i]),max[x][i]-pminup); 65 pmaxup=std::max(pmaxup,max[x][i]); 66 pminup=std::min(pminup,min[x][i]); 67 x=anc[x][i]; 68 } 69 } 70 for(register int i=log2(dep[y]);i>=0;i--) { 71 if(dep[anc[y][i]]>=dep[t]) { 72 pdown=std::max(std::max(pdown,down[y][i]),pmaxdown-min[y][i]); 73 pmaxdown=std::max(pmaxdown,max[y][i]); 74 pmindown=std::min(pmindown,min[y][i]); 75 y=anc[y][i]; 76 } 77 } 78 return std::max(std::max(pup,pdown),pmaxdown-pminup); 79 } 80 int main() { 81 int n=getint(); 82 for(register int i=1;i<=n;i++) { 83 w[i]=getint(); 84 } 85 for(register int i=1;i<n;i++) { 86 add_edge(getint(),getint()); 87 } 88 dfs(1,0); 89 for(register int q=getint();q;q--) { 90 const int u=getint(),v=getint(); 91 printf("%d ",solve(u,v)); 92 } 93 return 0; 94 }