解法:
一个比较直观的想法是把有宝物点的虚树构出来,答案就是虚树中边的总长度*2。
但动态维护虚树比较麻烦。经过观察我们可以发现:只要出发点在虚树之中,那么答案就不会变,而且走的路径是一个环,即一个经过所有点的欧拉回路。
即我们把欧拉遍历(或dfn序)中相邻两点的距离加起来,再加上首尾距离,就是答案。
所以我们用一棵splay维护dfn序,复杂度O(MlogN)
#include<cstdio> #include<cstdlib> #include<algorithm> #include<cmath> #include<cstring> using namespace std; typedef long long ll; struct Tree{ int son[2],fa,key; }tr[200011]; ll f[100011],Ans,dep[100011]; int len[200022],next[200022],y[200022],g[100011],que[100011]; int fa[100011][17],dfn[100011],num[100011],th[100011],D[100011]; int n,m,tot,T,j,rs,x,z,q,i,tt,sum,dst,ls,root; bool pl; void star(int i,int j,int k) { tt++; next[tt]=g[i]; g[i]=tt; y[tt]=j; len[tt]=k; } void dfs(int x) { int j,k; dfn[x]=++T; num[T]=x; j=g[x]; while(j!=0){ k=y[j]; if(k!=fa[x][0]){ fa[k][0]=x; D[k]=D[x]+1; dep[k]=dep[x]+len[j]; dfs(k); } j=next[j]; } } int get(int x,int z) { int i,l,e; if(D[x]<D[z])swap(x,z); l=D[x]-D[z]; e=0; while(l){ if(l%2==1)x=fa[x][e]; l/=2; e++; } if(x==z)return x; for(i=16;i>=0;i--)if(fa[x][i]!=fa[z][i]){ x=fa[x][i]; z=fa[z][i]; } return fa[x][0]; } ll dis(int x,int z) { return dep[x]+dep[z]-2*dep[get(x,z)]; } void Find(int x,int z) { if(tr[x].key==z){ dst=x; return; } if(tr[x].key>z)Find(tr[x].son[0],z); else Find(tr[x].son[1],z); } void ins(int &x,int z,int ls) { if(x==0){ x=++tot; tr[x].fa=ls; tr[x].key=z; return; } if(tr[x].key>z)ins(tr[x].son[0],z,x); else ins(tr[x].son[1],z,x); } void rotate(int x) { int z,e; z=tr[x].fa; e=tr[z].son[1]==x; tr[z].son[e]=tr[x].son[e^1]; tr[tr[x].son[e^1]].fa=z; tr[x].son[e^1]=z; if(tr[z].fa!=0){ e=tr[tr[z].fa].son[1]==z; tr[tr[z].fa].son[e]=x; } tr[x].fa=tr[z].fa; tr[z].fa=x; } void splay(int x) { int nt,ft,e1,e2; while(tr[x].fa!=0){ if(tr[tr[x].fa].fa==0)rotate(x); else{ ft=tr[x].fa; nt=tr[ft].fa; e1=tr[ft].son[1]==x; e2=tr[nt].son[1]==ft; if(e1==e2)rotate(ft),rotate(x); else rotate(x),rotate(x); } } root=x; } int findlst(int x) { x=tr[x].son[0]; while(tr[x].son[1])x=tr[x].son[1]; return x; } int findnxt(int x) { x=tr[x].son[1]; while(tr[x].son[0])x=tr[x].son[0]; return x; } int findmax(int x) { while(tr[x].son[1])x=tr[x].son[1]; return tr[x].key; } int findmin(int x) { while(tr[x].son[0])x=tr[x].son[0]; return tr[x].key; } void del(int x) { Find(root,dfn[x]); splay(dst); ls=findlst(dst); if(!ls)ls=findmax(root); else ls=tr[ls].key; rs=findnxt(dst); if(!rs)rs=findmin(root); else rs=tr[rs].key; Ans=Ans-dis(num[ls],num[tr[dst].key])-dis(num[rs],num[tr[dst].key]); Ans+=dis(num[ls],num[rs]); if(findnxt(dst)==0){ root=tr[dst].son[0]; tr[root].fa=0; } else{ root=findnxt(dst); splay(root); if(tr[dst].son[0])tr[tr[dst].son[0]].fa=root; tr[root].son[0]=tr[dst].son[0]; } } void Work() { int i,l,r,x,j,k,lx,mx,mn; Ans=0; root=0; for(i=1;i<=m;i++){ scanf("%d",&x); lx=th[x]; th[x]^=1; if(th[x]==0)del(x); else{ ins(root,dfn[x],0); splay(tot); ls=findlst(tot); if(!ls)ls=findmax(root); else ls=tr[ls].key; rs=findnxt(tot); if(!rs)rs=findmin(root); else rs=tr[rs].key; Ans-=dis(num[ls],num[rs]); Ans+=dis(num[ls],x)+dis(x,num[rs]); } printf("%lld ",Ans); } } int main() { scanf("%d%d",&n,&m); pl=true; for(i=1;i<n;i++){ scanf("%d%d%d",&x,&z,&q); star(x,z,q); star(z,x,q); } dfs(1); for(i=1;i<=16;i++) for(j=1;j<=n;j++)fa[j][i]=fa[fa[j][i-1]][i-1]; Work(); }