其实呢,我也不理解这道题咋做,等以后有时间再研究研究
#include <bits/stdc++.h> #define ll long long #define maxn 100002 using namespace std; void setIO(string s) { string in=s+".in"; freopen(in.c_str(),"r",stdin); } struct Union { int p[maxn]; void init() { for(int i=0;i<maxn;++i) p[i]=i; } int find(int x) { return p[x]==x?x:p[x]=find(p[x]); } }tr; int n,edges,m; ll val[maxn]; int hd[maxn],to[maxn<<1],nex[maxn<<1],fa[21][maxn],nx[400][maxn]; int dep[maxn],key[maxn]; void addedge(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } void dfs(int u,int ff) { dep[u]=dep[ff]+1, fa[0][u]=ff, nx[1][u]=ff, nx[0][u]=u; for(int i=2;i<=m;++i) nx[i][u]=nx[i-1][ff]; for(int i=1;i<21;++i) fa[i][u]=fa[i-1][fa[i-1][u]]; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(v^ff) dfs(v, u); } } int LCA(int x,int y) { if(dep[x]^dep[y]) { if(dep[x] > dep[y]) swap(x,y); for(int i=20;i>=0;--i) if(dep[fa[i][y]]>=dep[x]) y=fa[i][y]; } if(x==y) return x; for(int i=20;i>=0;--i) if(fa[i][x] ^ fa[i][y]) x=fa[i][x],y=fa[i][y]; return fa[0][y]; } int up(int x,int k) { if(k<=m) return nx[k][x]; for(int i=20;i>=0;--i) { if(key[i]<=k) x=fa[i][x], k-=key[i]; if(!k) break; } return x; } void modify(int x) { if(val[x]==1) return; val[x]=sqrt(val[x]); if(val[x]==1) tr.p[x]=tr.find(fa[0][x]); } int jump(int x,int y,int f,int k) { if(dep[y]-dep[f]>=k) return up(y,k); return up(x,dep[x]+dep[y]-(dep[f]<<1)-k); } int get(int x, int k) { if (k > m) return up(x, k); int y = tr.find(fa[0][x]); return up(y, (k - (dep[x] - dep[y]) % k) % k); } void update(int x,int y,int k) { int f=LCA(x,y), len=dep[x]+dep[y]-(dep[f]<<1); if(len%k) modify(y),y=jump(x,y,f,len%k),f=LCA(x,y); while(dep[x]>=dep[f]) modify(x),x=get(x,k); while(dep[y]>dep[f]) modify(y),y=get(y,k); } ll query(int x,int y,int k) { int f=LCA(x,y),len=dep[x]+dep[y]-(dep[f]<<1); ll res=0; if(len%k) { int a=len%k; res+=val[y]; // printf("%d %d ",dep[x]-dep[y],a); y=jump(x,y,f,len%k); // y=up(x,11); f=LCA(x,y); } res+=(dep[x]+dep[y]-(dep[f]<<1))/k+1; while(dep[x]>=dep[f]) res+=val[x]-1,x=get(x,k); while(dep[y]>dep[f]) res+=val[y]-1,y=get(y,k); return res; } int main() { // setIO("input"); scanf("%d",&n),m=233; key[0]=1; for(int i=1;i<=22;++i) key[i]=key[i-1]*2; for(int i=1;i<=n;++i) scanf("%lld",&val[i]); for(int i=1;i<n;++i) { int u,v; scanf("%d%d",&u,&v); addedge(u,v),addedge(v,u); } dfs(1,0); for(int i=1;i<=n;++i) { if(val[i]==1) tr.p[i]=fa[0][i]; else tr.p[i]=i; } int Q; scanf("%d",&Q); for(int i=1;i<=Q;++i) { int op,x,y,k; scanf("%d%d%d%d",&op,&x,&y,&k); // printf("%d %d %d %d ",i,op,x,y); if(op==0) update(x,y,k); else printf("%lld ",query(x,y,k)); } return 0; }