• 【Learning】虚树题目汇总


    这里主要是我虚树的练习记录。
    关于虚树的建树,参考了 https://www.cnblogs.com/chenhuan001/p/5639482.html 。感谢!至于为什么虚树建树要打那么长,只是为了方便理解。

    套路:建虚树后树形DP,或者在建虚树过程中统计答案。

    fread好快啊!
    T1 xsy1633
    建树时直接统计即可。

    #include<cstdio>
    #include<algorithm>
    using namespace std;
    const int N=100005;
    int n,m,k,u,v,d,cnt,idx,pos,p[N],head[N],to[N*2],nxt[N*2],dd[N*2];
    int siz[N],son[N],fa[N],dep[N],dis[N],dfn[N],top[N],stk[N];
    char ch[20000005];
    inline int rd(){
        int ret=0;
        while(ch[pos]<'0'||ch[pos]>'9'){
            pos++;
        }
        while(ch[pos]>='0'&&ch[pos]<='9'){
            ret=ret*10+ch[pos]-'0';
            pos++;
        }
        return ret;
    }
    bool cmp(int a,int b){
        return dfn[a]<dfn[b];
    }
    void adde(int u,int v,int d){
        to[++cnt]=v;
        nxt[cnt]=head[u];
        dd[cnt]=d;
        head[u]=cnt;
    }
    void dfs(int u){
        siz[u]=1;
        int v;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            if(v!=fa[u]){
                fa[v]=u;
                dep[v]=dep[u]+1;
                dis[v]=dis[u]+dd[i];
                dfs(v);
                siz[u]+=siz[v];
                if(!son[u]||siz[son[u]]<siz[v]){
                    son[u]=v;
                }
            }
        }
    }
    void dfs(int u,int tp){
        dfn[u]=++idx;
        top[u]=tp;
        if(son[u]){
            dfs(son[u],tp);
        }
        int v;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            if(v!=fa[u]&&v!=son[u]){
                dfs(v,v);
            }
        }
    }
    int lca(int u,int v){
        while(top[u]!=top[v]){
            if(dep[top[u]]>dep[top[v]]){
                u=fa[top[u]];
            }else{
                v=fa[top[v]];
            }
        }
        if(dep[u]<dep[v]){
            return u;
        }else{
            return v;
        }
    }
    void solve(){
        int ans=0;
        stk[stk[0]=1]=1;
        for(int i=1;i<=k;i++){
            int tmp=lca(stk[stk[0]],p[i]);
            if(tmp!=stk[stk[0]]){
                while(stk[0]>1){
                    if(dfn[stk[stk[0]-1]]>dfn[tmp]){
                        ans+=dis[stk[stk[0]]]-dis[stk[stk[0]-1]];
                        stk[0]--;
                    }else if(stk[stk[0]-1]==tmp){
                        ans+=dis[stk[stk[0]]]-dis[tmp];
                        stk[0]--;
                        break;
                    }else if(dfn[stk[stk[0]-1]]<dfn[tmp]){
                        ans+=dis[stk[stk[0]]]-dis[tmp];
                        stk[0]--;
                        stk[++stk[0]]=tmp;
                        break;
                    }
                }
            }
            stk[++stk[0]]=p[i];
        }
        while(stk[0]>1){
            ans+=dis[stk[stk[0]]]-dis[stk[stk[0]-1]];
            stk[0]--;
        }
        printf("%d
    ",ans);
    }
    int main(){
        fread(ch,20000000,1,stdin);
        n=rd();
        for(int i=1;i<n;i++){
            u=rd(),v=rd(),d=rd();
            adde(u,v,d);
            adde(v,u,d);
        }
        dfs(1);
        dfs(1,1);
        m=rd();
        for(int i=1;i<=m;i++){
            k=rd();
            for(int j=1;j<=k;j++){
                p[j]=rd();
            }
            sort(p+1,p+k+1,cmp);
            solve();
        }
        return 0;
    }

    T2 bzoj2286 虚树+树形dp

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    using namespace std;
    const int N=250005;
    typedef long long ll;
    int n,m,k,u,v,d,cnt,idx,pos,h[N],stk[N],head[N],to[N*2],nxt[N*2],dd[N*2];
    int siz[N],son[N],fa[N],dep[N],top[N],dfn[N];
    ll mind[N],f[N];
    bool ck[N];
    char ch[20000005];
    inline int rd(){
        register int ret=0;
        while(ch[pos]<'0'||ch[pos]>'9'){
            pos++;
        }
        while(ch[pos]>='0'&&ch[pos]<='9'){
            ret=ret*10+ch[pos]-'0';
            pos++;
        }
        return ret;
    }
    bool cmp(int a,int b){
        return dfn[a]<dfn[b];
    }
    void adde(int u,int v,int d){
        to[++cnt]=v;
        nxt[cnt]=head[u];
        dd[cnt]=d;
        head[u]=cnt;
    }
    void dfs(int u){
        siz[u]=1;
        int v;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            if(v!=fa[u]){
                fa[v]=u;
                dep[v]=dep[u]+1;
                mind[v]=min(mind[u],1LL*dd[i]);
                dfs(v);
                siz[u]+=siz[v];
                if(!son[u]||siz[son[u]]<siz[v]){
                    son[u]=v;
                }
            }
        }
    }
    void dfs(int u,int tp){
        dfn[u]=++idx;
        top[u]=tp;
        if(son[u]){
            dfs(son[u],tp);
        }
        int v;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            if(v!=fa[u]&&v!=son[u]){
                dfs(v,v);
            }
        }
    }
    int lca(int u,int v){
        while(top[u]!=top[v]){
            if(dep[top[u]]>dep[top[v]]){
                u=fa[top[u]];
            }else{
                v=fa[top[v]];
            }
        }
        if(dep[u]<dep[v]){
            return u;
        }else{
            return v;
        }
    }
    void build(){
        stk[stk[0]=1]=1;
        for(int i=1;i<=k;i++){
            int tmp=lca(stk[stk[0]],h[i]);
            if(tmp!=stk[stk[0]]){
                while(stk[0]>1){
                    if(dfn[stk[stk[0]-1]]>dfn[tmp]){
                        adde(stk[stk[0]-1],stk[stk[0]],0);
                        stk[0]--;
                    }else if(stk[stk[0]-1]==tmp){
                        adde(tmp,stk[stk[0]],0);
                        stk[0]--;
                        break;
                    }else if(dfn[stk[stk[0]-1]]<dfn[tmp]){
                        adde(tmp,stk[stk[0]],0);
                        stk[0]--;
                        stk[++stk[0]]=tmp;
                        break;
                    }
                }
            }
            stk[++stk[0]]=h[i];
        }
        while(stk[0]>1){
            adde(stk[stk[0]-1],stk[stk[0]],0);
            stk[0]--;
        }
    }
    void dp(int u){
        ll sum=0,v;
        f[u]=mind[u];
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            dp(v);
            sum+=f[v];
        }
        if(!ck[u]){
            f[u]=min(sum,f[u]);
        }
        head[u]=ck[u]=0;
    }
    int main(){
        fread(ch,20000000,1,stdin);
        n=rd();
        for(int i=1;i<n;i++){
            u=rd(),v=rd(),d=rd();
            adde(u,v,d);
            adde(v,u,d);
        }
        mind[1]=1e15;
        dfs(1);
        dfs(1,1);
        m=rd();
        memset(head,0,sizeof(head));
        for(int i=1;i<=m;i++){
            k=rd();
            for(int j=1;j<=k;j++){
                h[j]=rd();
                ck[h[j]]=true;
            }
            sort(h+1,h+k+1,cmp);
            cnt=0;
            build();
            dp(1);
            printf("%lld
    ",f[1]);
        }
        return 0;
    }

    T3 bzoj3611 虚树+树形dp

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    using namespace std;
    const int N=1000005,inf=0x7f7f7f7f;
    int n,m,k,u,v,cnt,idx,pos,p[N],head[N],to[N*2],nxt[N*2];
    int siz[N],son[N],fa[N],dep[N],top[N],dfn[N],stk[N];
    long long sum,maxn,minn,mx[N],mi[N];
    bool ck[N];
    char ch[100000005];
    inline int rd(){
        int ret=0;
        while(ch[pos]<'0'||ch[pos]>'9'){
            pos++;
        }
        while(ch[pos]>='0'&&ch[pos]<='9'){
            ret=ret*10+ch[pos]-'0';
            pos++;
        }
        return ret;
    }
    bool cmp(int a,int b){
        return dfn[a]<dfn[b];
    }
    void adde(int u,int v){
        to[++cnt]=v;
        nxt[cnt]=head[u];
        head[u]=cnt;
    }
    void dfs(int u){
        siz[u]=1;
        int v;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            if(v!=fa[u]){
                fa[v]=u;
                dep[v]=dep[u]+1;
                dfs(v);
                siz[u]+=siz[v];
                if(!son[u]||siz[son[u]]<siz[v]){
                    son[u]=v;
                }
            }
        }
    }
    void dfs(int u,int tp){
        dfn[u]=++idx;
        top[u]=tp;
        if(son[u]){
            dfs(son[u],tp);
        }
        int v;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            if(v!=fa[u]&&v!=son[u]){
                dfs(v,v);
            }
        }
    }
    int lca(int u,int v){
        while(top[u]!=top[v]){
            if(dep[top[u]]>dep[top[v]]){
                u=fa[top[u]];
            }else{
                v=fa[top[v]];
            }
        }
        if(dep[u]<dep[v]){
            return u;
        }else{
            return v;
        }
    }
    void dfssize(int u){
        siz[u]=ck[u];
        int v;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            dfssize(v);
            siz[u]+=siz[v];
        }
    }
    void dp(int u){
        int v;
        mx[u]=ck[u]?0:-inf;
        mi[u]=ck[u]?0:inf;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            dp(v);
            sum+=1LL*siz[v]*(siz[1]-siz[v])*(dep[v]-dep[u]);
            maxn=max(maxn,mx[u]+mx[v]+dep[v]-dep[u]);
            minn=min(minn,mi[u]+mi[v]+dep[v]-dep[u]);
            mx[u]=max(mx[u],mx[v]+dep[v]-dep[u]);
            mi[u]=min(mi[u],mi[v]+dep[v]-dep[u]);
        }
        head[u]=ck[u]=0;
    }
    void solve(){
        stk[stk[0]=1]=1;
        for(int i=1;i<=k;i++){
            int tmp=lca(stk[stk[0]],p[i]);
            if(tmp!=stk[stk[0]]){
                while(stk[0]>1){
                    if(dfn[stk[stk[0]-1]]>dfn[tmp]){
                        adde(stk[stk[0]-1],stk[stk[0]]);
                        stk[0]--;
                    }else if(stk[stk[0]-1]==tmp){
                        adde(tmp,stk[stk[0]]);
                        stk[0]--;
                        break;
                    }else if(dfn[stk[stk[0]-1]]<dfn[tmp]){
                        adde(tmp,stk[stk[0]]);
                        stk[stk[0]]=tmp;
                        break;
                    }
                }
            }
            if(stk[stk[0]]!=p[i]){
                stk[++stk[0]]=p[i];
            }
        }
        while(stk[0]>1){
            adde(stk[stk[0]-1],stk[stk[0]]);
            stk[0]--;
        }
        sum=0;
        minn=inf;
        maxn=-inf;
        dfssize(1);
        dp(1);
        printf("%lld %lld %lld
    ",sum,minn,maxn);
    }
    int main(){
        fread(ch,100000000,1,stdin);
        n=rd();
        for(int i=1;i<n;i++){
            u=rd(),v=rd();
            adde(u,v);
            adde(v,u);
        }
        dfs(1);
        dfs(1,1);
        memset(head,0,sizeof(head));
        m=rd();
        for(int i=1;i<=m;i++){
            k=rd();
            for(int j=1;j<=k;j++){
                p[j]=rd();
                ck[p[j]]=true;
            }
            sort(p+1,p+k+1,cmp);
            cnt=0;
            solve();
        }
        return 0;
    }

    T4 bzoj3572 虚树+树形dp+倍增

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    using namespace std;
    const int N=300005;
    int n,m,k,u,v,cnt,idx,pos,p[N],head[N],to[N*2],nxt[N*2],dep[N],fa[N][20],dfn[N],siz[N],stk[N];
    int bel[N],left[N],sum[N],tmp[N];
    char ch[20000005];
    inline int rd(){
        register int ret=0;
        while(ch[pos]<'0'||ch[pos]>'9'){
            pos++;
        }
        while(ch[pos]>='0'&&ch[pos]<='9'){
            ret=ret*10+ch[pos]-'0';
            pos++;
        }
        return ret;
    }
    bool cmp(int a,int b){
        return dfn[a]<dfn[b];
    }
    void adde(int u,int v){
        to[++cnt]=v;
        nxt[cnt]=head[u];
        head[u]=cnt;
    }
    void dfs(int u){
        dfn[u]=++idx;
        siz[u]=1;
        for(int i=1;(1<<i)<=dep[u];i++){
            fa[u][i]=fa[fa[u][i-1]][i-1];
        }
        int v;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            if(v!=fa[u][0]){
                fa[v][0]=u;
                dep[v]=dep[u]+1;
                dfs(v);
                siz[u]+=siz[v];
            }
        }
    }
    int lca(int u,int v){
        if(dep[u]<dep[v]){
            swap(u,v);
        }
        int d=dep[u]-dep[v];
        for(int i=0;i<=19;i++){
            if(d&(1<<i)){
                u=fa[u][i];
            }
        }
        if(u==v){
            return u;
        }
        for(int i=19;i>=0;i--){
            if(fa[u][i]!=fa[v][i]){
                u=fa[u][i];
                v=fa[v][i];
            }
        }
        return fa[u][0];
    }
    int dis(int u,int v){
        return dep[u]+dep[v]-2*dep[lca(u,v)];
    }
    void build(){
        stk[stk[0]=1]=1;
        for(int i=1;i<=k;i++){
            int tmp=lca(stk[stk[0]],p[i]);
            if(tmp!=stk[stk[0]]){
                while(stk[0]>1){
                    if(dfn[stk[stk[0]-1]]>dfn[tmp]){
                        adde(stk[stk[0]-1],stk[stk[0]]);
                        stk[0]--;
                    }else if(stk[stk[0]-1]==tmp){
                        adde(tmp,stk[stk[0]]);
                        stk[0]--;
                        break;
                    }else if(dfn[stk[stk[0]-1]]<dfn[tmp]){
                        adde(tmp,stk[stk[0]]);
                        stk[stk[0]]=tmp;
                        break;
                    }
                }
            }
            if(p[i]!=stk[stk[0]]){
                stk[++stk[0]]=p[i];
            }
        }
        while(stk[0]>1){
            adde(stk[stk[0]-1],stk[stk[0]]);
            stk[0]--;
        }
    }
    void dp1(int u){
        left[u]=siz[u];
        int v,d1,d2;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            dp1(v);
            if(bel[v]){
                if(!bel[u]){
                    bel[u]=bel[v];
                }else{
                    d1=dis(u,bel[u]),d2=dis(u,bel[v]);
                    if(d2<d1||(d2==d1&&bel[v]<bel[u])){
                        bel[u]=bel[v];
                    }
                }
            }
        }
    }
    void dp2(int u){
        int v,d1,d2;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            if(!bel[v]){
                bel[v]=bel[u];
            }else{
                d1=dis(v,bel[v]),d2=dis(v,bel[u]);
                if(d2<d1||(d2==d1&&bel[u]<bel[v])){
                    bel[v]=bel[u];
                }
            }
            dp2(v);
        }
    }
    void work(int u,int v){
        int tmp=v,mid=v,nxt,d1,d2;
        for(int i=19;i>=0;i--){
            if(dep[fa[tmp][i]]>dep[u]){
                tmp=fa[tmp][i];
            }
        }
        left[u]-=siz[tmp];
        if(bel[u]==bel[v]){
            sum[bel[u]]+=siz[tmp]-siz[v];
            return;
        }
        for(int i=19;i>=0;i--){
            nxt=fa[mid][i];
            if(dep[nxt]>dep[u]){
                d1=dis(bel[u],nxt);
                d2=dis(bel[v],nxt);
                if(d2<d1||(d2==d1&&bel[v]<bel[u])){
                    mid=nxt;
                }
            }
        }
        sum[bel[u]]+=siz[tmp]-siz[mid];
        sum[bel[v]]+=siz[mid]-siz[v];
    }
    void dp3(int u){
        int v;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            work(u,v);
            dp3(v);
        }
        sum[bel[u]]+=left[u];
    }
    void clear(int u){
        for(int i=head[u];i;i=nxt[i]){
            clear(to[i]);
        }
        bel[u]=sum[u]=head[u]=left[u]=0;
    }
    int main(){
        fread(ch,20000000,1,stdin);
        n=rd();
        for(int i=1;i<n;i++){
            u=rd(),v=rd();
            adde(u,v);
            adde(v,u);
        }
        dfs(1);
        m=rd();
        memset(head,0,sizeof(head));
        for(int i=1;i<=m;i++){
            k=rd();
            for(int j=1;j<=k;j++){
                p[j]=rd();
                bel[p[j]]=tmp[j]=p[j];
            }
            sort(p+1,p+k+1,cmp);
            cnt=0;
            build();
            dp1(1);
            dp2(1);
            dp3(1);
            for(int j=1;j<=k;j++){
                printf("%d ",sum[tmp[j]]);
            }
            puts("");
            clear(1);
        }
        return 0;
    }

    T5 codeforses631D 虚树+树形dp
    终于成功注册codeforses,辛辛苦苦把验证码搞出来了

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    using namespace std;
    const int N=100005;
    int n,m,k,u,v,cnt,idx,p[N],head[N],to[N*2],nxt[N*2];
    int dfn[N],fa[N][20],dep[N],stk[N];
    bool ck[N],exi[N];
    bool cmp(int a,int b){
        return dfn[a]<dfn[b];
    }
    void adde(int u,int v){
        to[++cnt]=v;
        nxt[cnt]=head[u];
        head[u]=cnt;
    }
    void dfs(int u){
        dfn[u]=++idx;
        for(int i=1;(1<<i)<=dep[u];i++){
            fa[u][i]=fa[fa[u][i-1]][i-1];
        }
        int v;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            if(v!=fa[u][0]){
                fa[v][0]=u;
                dep[v]=dep[u]+1;
                dfs(v);
            }
        }
    }
    int lca(int u,int v){
        if(dep[u]<dep[v]){
            swap(u,v);
        }
        int d=dep[u]-dep[v];
        for(int i=0;(1<<i)<=d;i++){
            if(d&(1<<i)){
                u=fa[u][i];
            }
        }
        if(u==v){
            return u;
        }
        for(int i=18;i>=0;i--){
            if(fa[u][i]!=fa[v][i]){
                u=fa[u][i];
                v=fa[v][i];
            }
        }
        return fa[u][0];
    }
    void build(){
        stk[stk[0]=1]=1;
        for(int i=1;i<=k;i++){
            int tmp=lca(stk[stk[0]],p[i]);
            if(tmp!=stk[stk[0]]){
                while(stk[0]>1){
                    if(dfn[stk[stk[0]-1]]>dfn[tmp]){
                        adde(stk[stk[0]-1],stk[stk[0]]);
                        stk[0]--;
                    }else if(stk[stk[0]-1]==tmp){
                        adde(tmp,stk[stk[0]]);
                        stk[0]--;
                        break;
                    }else if(dfn[stk[stk[0]-1]]<dfn[tmp]){
                        adde(tmp,stk[stk[0]]);
                        stk[stk[0]]=tmp;
                        break;
                    }
                }
            }
            if(p[i]!=stk[stk[0]]){
                stk[++stk[0]]=p[i];
            }
        }
        while(stk[0]>1){
            adde(stk[stk[0]-1],stk[stk[0]]);
            stk[0]--;
        }
    }
    int dp(int u){
        int sum=0,tot=0,v;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            sum+=dp(v);
            tot+=exi[v];
        }
        if(ck[u]){
            exi[u]=1;
            return sum+tot;
        }else{
            exi[u]=tot==1;
            return sum+(tot>1); 
        }
    }
    void clear(int u){
        for(int i=head[u];i;i=nxt[i]){
            clear(to[i]);
        }
        head[u]=ck[u]=exi[u]=0;
    }
    bool judge(){
        for(int i=1;i<=k;i++){
            if(ck[fa[p[i]][0]]){
                return false;
            }
        }
        return true;
    }
    int main(){
        scanf("%d",&n);
        for(int i=1;i<n;i++){
            scanf("%d%d",&u,&v);
            adde(u,v);
            adde(v,u);
        }
        dfs(1);
        memset(head,0,sizeof(head));
        scanf("%d",&m);
        for(int i=1;i<=m;i++){
            scanf("%d",&k);
            for(int j=1;j<=k;j++){
                scanf("%d",&p[j]);
                ck[p[j]]=true;
            }
            if(!judge()){
                puts("-1");
                for(int j=1;j<=k;j++){
                    ck[p[j]]=false;
                }
                continue;
            }
            sort(p+1,p+k+1,cmp);
            cnt=0;
            build();
            printf("%d
    ",dp(1));
            clear(1);
        }
        return 0;
    }

    T6 bzoj3991 乱入的一道dfs序+平衡树
    为何xsy上要把这道题分类带哦虚树上呢?==
    作死手打平衡树233

    #include<cstdio>
    #include<algorithm>
    using namespace std;
    typedef long long ll;
    const int N=100005;
    int n,m,u,v,cnt,idx,pos,dfn[N],fa[N][20],dep[N],head[N],to[N*2],nxt[N*2];
    ll d,ans,dd[N*2],dis[N];
    bool ck[N];
    char ch[20000005];
    inline int rd(){
        register int ret=0;
        while(ch[pos]<'0'||ch[pos]>'9'){
            pos++;
        }
        while(ch[pos]>='0'&&ch[pos]<='9'){
            ret=ret*10+ch[pos]-'0';
            pos++;
        }
        return ret;
    }
    void adde(int u,int v,ll d){
        to[++cnt]=v;
        nxt[cnt]=head[u];
        dd[cnt]=d;
        head[u]=cnt;
    }
    void dfs(int u){
        dfn[u]=++idx;
        for(int i=1;(1<<i)<=dep[u];i++){
            fa[u][i]=fa[fa[u][i-1]][i-1];
        }
        int v;
        for(int i=head[u];i;i=nxt[i]){
            v=to[i];
            if(v!=fa[u][0]){
                fa[v][0]=u;
                dep[v]=dep[u]+1;
                dis[v]=dis[u]+dd[i];
                dfs(v);
            }
        }
    }
    int lca(int u,int v){
        if(dep[u]<dep[v]){
            swap(u,v);
        }
        int d=dep[u]-dep[v];
        for(int i=0;i<=17;i++){
            if(d&(1<<i)){
                u=fa[u][i];
            }
        }
        if(u==v){
            return u;
        }
        for(int i=17;i>=0;i--){
            if(fa[u][i]!=fa[v][i]){
                u=fa[u][i];
                v=fa[v][i];
            }
        }
        return fa[u][0];
    }
    ll dist(int u,int v){
        return dis[u]+dis[v]-2*dis[lca(u,v)];
    }
    struct ScapegoatTree{
        int root,cnt,*goat,val[N],ch[N][2],siz[N],tot[N],del[N],mmp[N],pos[N];
        int rnk(int x){
            int k=root,ret=1;
            while(k){
                if(dfn[x]<=dfn[val[k]]){
                    k=ch[k][0];
                }else{
                    ret+=siz[ch[k][0]]+del[k];
                    k=ch[k][1];
                }
            }
            return ret;
        }
        int kth(int x){
            if(x==0){
                x=siz[root];
            }else if(x==siz[root]+1){
                x=1;
            }
            int k=root;
            while(k){
                if(del[k]&&x==siz[ch[k][0]]+1){
                    return val[k];
                }else if(x<=siz[ch[k][0]]+del[k]){
                    k=ch[k][0];
                }else{
                    x-=siz[ch[k][0]]+del[k];
                    k=ch[k][1];
                }
            }
        }
        int newnode(int &k,int x){
            if(mmp[0]){
                k=mmp[mmp[0]--];
            }else{
                k=++cnt;
            }
            ch[k][0]=ch[k][1]=0;
            siz[k]=tot[k]=del[k]=1;
            val[k]=x;
        }
        void dfs(int k){
            if(!k){
                return;
            }
            dfs(ch[k][0]);
            if(del[k]){
                pos[++pos[0]]=k;
            }else{
                mmp[++mmp[0]]=k;
            }
            dfs(ch[k][1]);
        }
        void build(int &k,int l,int r){
            if(l>r){
                k=0;
                return;
            }
            int mid=(l+r)/2;
            k=pos[mid];
            build(ch[k][0],l,mid-1);
            build(ch[k][1],mid+1,r);
            siz[k]=siz[ch[k][0]]+siz[ch[k][1]]+1;
            tot[k]=tot[ch[k][0]]+tot[ch[k][1]]+1;
        }
        void rebuild(int &k){
            pos[0]=0;
            dfs(k);
            build(k,1,pos[0]);
        }
        void insert(int &k,int x){
            if(!k){
                newnode(k,x);
                return;
            }
            siz[k]++;
            tot[k]++;
            insert(ch[k][dfn[x]>dfn[val[k]]],x);
            if(tot[k]*0.75<max(tot[ch[k][0]],tot[ch[k][1]])){
                goat=&k;
            }
        }
        void insert(int x){
            goat=NULL;
            insert(root,x);
            if(goat){
                rebuild(*goat);
            }
        }
        void remove(int &k,int x){
            siz[k]--;
            if(del[k]&&x==siz[ch[k][0]]+1){
                del[k]=0;
                return;
            }
            if(x<=siz[ch[k][0]]+del[k]){
                remove(ch[k][0],x);
            }else{
                remove(ch[k][1],x-siz[ch[k][0]]-del[k]);
            }
        }
        void remove(int x){
            remove(root,rnk(x));
            if(siz[root]<tot[root]*0.75){
                rebuild(root);
            }
        }
    }sgt;
    int main(){
        fread(ch,20000000,1,stdin);
        n=rd(),m=rd();
        for(int i=1;i<n;i++){
            u=rd(),v=rd(),d=rd();
            adde(u,v,d);
            adde(v,u,d);
        }
        dfs(1);
        for(int i=1;i<=m;i++){
            u=rd();
            if(!ck[u]){
                if(sgt.siz[sgt.root]==0){
                    sgt.insert(u);
                }else{
                    sgt.insert(u);
                    int tmp=sgt.rnk(u),pre=sgt.kth(tmp-1),sub=sgt.kth(tmp+1);
                    ans+=dist(u,pre)+dist(u,sub)-dist(pre,sub);
                }
                ck[u]=true;
            }else{
                if(sgt.siz[sgt.root]==1){
                    sgt.remove(u);
                }else{
                    int tmp=sgt.rnk(u),pre=sgt.kth(tmp-1),sub=sgt.kth(tmp+1);
                    ans-=dist(u,pre)+dist(u,sub)-dist(pre,sub);
                    sgt.remove(u);
                }
                ck[u]=false;
            }
            printf("%lld
    ",ans);
        }
        return 0;
    }
  • 相关阅读:
    threading库知识点补充
    数字中 数组排序
    python 多线程 thread (控制主线程跑完,子线程也关闭) 和 (等待子线程跑完,主线程才关闭)
    进程和线程理解
    上下文与 with语句 (如打开文件open的巧妙写法)
    Python中字符串String去除出换行符和空格的问题( , )
    去掉字符空格的方法
    postman 参数化(含文本)
    python之数组元素去重
    sql中limit使用方法
  • 原文地址:https://www.cnblogs.com/2016gdgzoi471/p/9476907.html
Copyright © 2020-2023  润新知