题目:https://www.lydsy.com/JudgeOnline/problem.php?id=2286
https://www.luogu.org/problemnew/show/P2495
学习(抄)了 hzwer 的代码,觉得写得很好。http://hzwer.com/6188.html
有一个 “如果排序后第 i 个关键点和第 i-1 个关键点的 lca 是第 i-1 个关键点,就舍弃第 i 个关键点” 的操作,觉得很好。
把 hd[ ] 数组清空写在了 dfs 里,觉得很好。
自己一开始写了一个倍增找链上边权最小值,用来给虚树的边赋值,参考之后发现只要记录一个 “到根的路径上的最小边权” 就行了。
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } ll Mx(ll a,ll b){return a>b?a:b;} ll Mn(ll a,ll b){return a<b?a:b;} const int N=250005,K=20;const ll INF=3e10+5;//for dp int n,hd[N],xnt,to[N<<1],nxt[N<<1],w[N<<1]; int dep[N],pre[N][K],bin[K],dfn[N],tim; ll mn[N]; bool cmp(int a,int b){return dfn[a]<dfn[b];} void add(int x,int y,int z){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;w[xnt]=z;} void dfs(int cr,int fa) { dfn[cr]=++tim; dep[cr]=dep[fa]+1; pre[cr][0]=fa; for(int t=1;bin[t]<=dep[cr];t++) pre[cr][t]=pre[pre[cr][t-1]][t-1]; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { mn[v]=Mn(mn[cr],w[i]); dfs(v,cr); } } int get_lca(int x,int y) { if(dep[x]<dep[y])swap(x,y); int d=dep[x]-dep[y]; for(int t=0;bin[t]<=d;t++) if(d&bin[t])x=pre[x][t]; if(x==y)return x; for(int t=17;t>=0;t--) if(pre[x][t]!=pre[y][t]) x=pre[x][t],y=pre[y][t]; return pre[x][0]; } namespace Tr{ int hd[N],xnt,to[N],nxt[N]; int p[N],tot,sta[N],top; ll dp[N]; void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} void get_tr() { xnt=0; sort(p+1,p+tot+1,cmp); int lm=tot; p[tot=1]=p[1]; for(int i=2;i<=lm;i++) if(get_lca(p[i],p[tot])!=p[tot])p[++tot]=p[i]; sta[top=1]=1; for(int i=1;i<=tot;i++) { int u=p[i], lca=get_lca(u,sta[top]); while(top&&dfn[lca]<dfn[sta[top]]) { if(dfn[sta[top-1]]<dfn[lca]) add(lca,sta[top]); else add(sta[top-1],sta[top]); top--; } if(sta[top]!=lca)sta[++top]=lca; sta[++top]=u; } for(int i=1;i<top;i++)add(sta[i],sta[i+1]); } void dfs(int cr) { if(!hd[cr]){dp[cr]=mn[cr];return;} dp[cr]=0; for(int i=hd[cr],v;i;i=nxt[i]) { dfs(v=to[i]); dp[cr]+=dp[v]; } hd[cr]=0;////// dp[cr]=Mn(dp[cr],mn[cr]); } void solve() { int k=rdn(); tot=0; for(int i=1,d;i<=k;i++) d=rdn(),p[++tot]=d; get_tr(); dfs(1); printf("%lld ",dp[1]); } } int main() { bin[0]=1;for(int i=1;i<=18;i++)bin[i]=bin[i-1]<<1; n=rdn(); for(int i=1,u,v,z;i<n;i++) u=rdn(),v=rdn(),z=rdn(),add(u,v,z),add(v,u,z); mn[1]=INF; dfs(1,0); int Q=rdn(); while(Q--)Tr::solve(); return 0; }