• Snow的追寻--线段树维护树的直径


    Snow终于得知母亲是谁,他现在要出发寻找母亲。王国中的路由于某种特殊原因,成为了一棵有n个节点的根节点
    为1的树,但由于"Birds are everywhere.",他得到了种种不一样的消息,每份消息中都会告诉他有两棵子树是禁
    忌之地,于是他向你求助了。他给出了q个形如"x y"的询问,表示他不能走到x和y的子树中,由于走的路径越长他
    遇见母亲的概率越大但是他只能走一条不经过重复节点的路径,现在他想知道对于每组询问他能走的最长路径是多
    少,如果没有,输出零。
    第一行两个正整数n和q(1≤n,q≤100000)
    第二到第n行每行两个整数u,v表示u和v之间有一条边连接,边的长度为1。
    接下来q行每行两个x,y表示一组询问,意义如题目描述。
    1≤n≤100000,1<=q<=50000
    Output
    q行,输出见题目描述
    Sample Input
    5 2
    1 3
    3 2
    3 4
    2 5
    2 4
    5 4
    Sample Output
    1
    2
    样例解释
    询问1中2和4的子树不能走,最长路径为(1,3)长度为1
    询问2中5和4的子树不能走,最长路径为(1,3,2)长度为2

    Sol:

    很明显的每个询问就是在求将两棵子树去掉后剩下的树的直径。我们先可以得出该树的dfs序,那么对于一颗子树就变成了序列上的一个区间,那么我们可以用线段树,维护一个区间表示的点的直径,对于两个区间,直径的合并就是从四个端点中任选两个连成的路径,选出其中长度最长的,即为合并后的直径,时间复杂度O(N*log2N)

    /*
    对树进行dfs遍历,形成一个长度为N的序列。
    要去掉的两个子树,在dfs序中是连续的。
    从整个序列中去掉这两个序列,可能形成二个或三个连续的序列
    对序列进行合并求直径。
    每个序列有左右端点,形成的新的直径,有四种选择。对于端点之间的距离利用lca来求就好了。
    对于文后图标样例,形成一个dfs序列,其中45及78是要去掉的
    12 45 6 78 3  
    于是合并12 6 3这三个区间就好了 
    */
    #include<cstdio>
    #include<iostream>
    #include<algorithm>
    #define ls now<<1,l,mid
    #define rs now<<1|1,mid+1,r
    #define rep(i,x) for(int i=head[x],v=e[i].to;i;i=e[i].nxt,v=e[i].to)
    using namespace std;
    const int maxn=100010;
    struct fk
    {
    	int to,nxt;
    }
    e[maxn<<1];
    int cnt,n,q,tot,head[maxn],dfn[maxn],p[maxn],last[maxn],dep[maxn],f[maxn][20];
    struct fq{int sum,x,y;}
    t[maxn<<2],ans;
    void ins(int u,int v)
    {
    	e[++cnt].to=v;
    	e[cnt].nxt=head[u];
    	head[u]=cnt;
    }
    void dfs(int x,int fa)
    {
        dfn[x]=++tot;//x进入的时间点 
    	p[tot]=x;//第tot个点是x 
    	f[x][0]=fa;
    	dep[x]=dep[fa]+1;
        rep(i,x)
    	    if(v!=fa)
    		    dfs(v,x);
    	last[x]=tot;
    }
    int lca(int x,int y)
    {
        if(dep[x]<dep[y])
    	   swap(x,y);
        for(int i=19;i>=0;i--)
    	   x=dep[f[x][i]]>dep[y]?f[x][i]:x;
        if(dep[x]>dep[y])
    	    x=f[x][0];
        for(int i=19;i>=0;i--)
    	    if(f[x][i]!=f[y][i])
    		     x=f[x][i],y=f[y][i];
        return x==y?x:f[x][0];
    }
    int dis(int x,int y) //求x,y两点的距离 
    {
    	if(!x||!y)
    	   return 0;
    	int z=lca(x,y);
    	    return dep[x]+dep[y]-2*dep[z];}
    void merge(fq &now,fq x,fq y)
    //将x,y所代表的区间进行合并,结果放到now中 
    {
        int a,b,c,d,e;
        a=dis(x.x,y.x);//新直径可能为x左点与y左点的距离 
    	b=dis(x.x,y.y);//新直径可能为x左点与y右点的距离
    	c=dis(x.y,y.x);
    	d=dis(x.y,y.y);
        e=max(a,max(b,max(c,d)));//取最大值 
        if(a==e)
    	    now.x=x.x,now.y=y.x,now.sum=a;
    	if(b==e)
    	    now.x=x.x,now.y=y.y,now.sum=b;
        if(c==e)
    	   now.x=x.y,now.y=y.x,now.sum=c;
    	if(d==e)
    	   now.x=x.y,now.y=y.y,now.sum=d;
        if(x.sum>now.sum)//x区间的直径大于之 
    	   now.x=x.x,now.y=x.y,now.sum=x.sum;
        if(y.sum>now.sum)//y区间的直径大于之
    	   now.x=y.x,now.y=y.y,now.sum=y.sum;
        if(!now.sum)
    	   now.x=now.y=0;
    }
    void build(int now,int l,int r)
    {
        if(l==r)
    	 {
    			t[now].x=t[now].y=p[l];
    			return ;
    	 }
        int mid=(l+r)>>1;
    	build(ls);
    	build(rs);
    	merge(t[now],t[now<<1],t[now<<1|1]);
    }
    void get_ans(int now,int l,int r,int x,int y)
    //get_ans(1,1,n,1,dfn[u]-1);
    //now根结点编号,l,r左右区间 
    {
        if(x<=l&&r<=y)
    	  {
    			merge(ans,ans,t[now]);
    			return ;
    	  }
        int mid=(l+r)>>1;
        if(x<=mid)
    	   get_ans(ls,x,y);
        if(y>mid)
    	   get_ans(rs,x,y);
    }
    int main()
    {
        scanf("%d%d",&n,&q);
    	int u,v;
        for(int i=1;i<n;i++)
    	     scanf("%d%d",&u,&v),ins(u,v),ins(v,u);
        dfs(1,0);
    	for(int j=1;j<20;j++)
    	    for(int i=1;i<=n;i++)
    		     f[i][j]=f[f[i][j-1]][j-1];
        build(1,1,n);
        while(q--)
        {
            scanf("%d%d",&u,&v);
            if(v==1||u==1) //去掉的是根结点 
    		  {
    				puts("0");
    				continue;
    		  }
            ans.sum=ans.x=ans.y=0;
            if(dfn[u]>dfn[v]) //让u进入的时间更小 
    		   swap(u,v);
            get_ans(1,1,n,1,dfn[u]-1);//从1开始到u进入前的 
            get_ans(1,1,n,last[u]+1,dfn[v]-1);//从u离开后v进来之前 
            if(last[v]<=last[u])
    		//看谁离开的时间更大,从离开后的那个时间到n之一段也要加进来 
    		    get_ans(1,1,n,last[u]+1,n);
            else 
    		     get_ans(1,1,n,last[v]+1,n);
            printf("%d
    ",ans.sum);
        }
    }
    

      

     参考下这个文章:https://blog.csdn.net/rzO_KQP_Orz/article/details/52280811

  • 相关阅读:
    显示非模式窗口和模式窗口
    delphi 版本号
    数字证书和签名
    DLL知道自己的位置
    拖动处理
    驱动配置相关
    python sturct模块操作C数据
    Lambda学习笔记
    【转】update select
    [转]视频格式分类
  • 原文地址:https://www.cnblogs.com/cutemush/p/11830887.html
Copyright © 2020-2023  润新知