显然的想法是对每个点求出能通过某种语言到的点个数,然后加起来(/2)就是答案.每次加入一条路径,就可以更新路径上所有点到达其他点的状态.那个我们用线段树维护,每次对路径上所有点的线段树上该路径对应的dfn区间覆盖(用树剖处理),最后统计每个线段树上有值的位置个数
注意每次是对一条路径上的线段树操作,路径修改可以联想到树上差分,即两端点做正权修改,lca和lca的父亲做负权修改.本题类似,我们在两端点处对对应区间执行+1操作,在lca和lca的父亲处对对应区间执行-1,然后套个线段树合并,就可以在每个节点处统计答案
// luogu-judger-enable-o2
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<vector>
#include<cmath>
#include<ctime>
#include<queue>
#include<map>
#include<set>
#define LL long long
#define db double
using namespace std;
const int N=1e5+10;
int rd()
{
int x=0,w=1;char ch=0;
while(ch<'0'||ch>'9'){if(ch=='-') w=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
return x*w;
}
int to[N<<1],nt[N<<1],hd[N],tot=1;
void add(int x,int y)
{
++tot,to[tot]=y,nt[tot]=hd[x],hd[x]=tot;
++tot,to[tot]=x,nt[tot]=hd[y],hd[y]=tot;
}
int fa[N],de[N],sz[N],hs[N],top[N],dfn[N],ti;
void dfs1(int x)
{
sz[x]=1;
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
if(y==fa[x]) continue;
fa[y]=x,de[y]=de[x]+1,dfs1(y);
sz[x]+=sz[y],hs[x]=sz[hs[x]]>sz[y]?hs[x]:y;
}
}
void dfs2(int x,int ntp)
{
dfn[x]=++ti,top[x]=ntp;
if(hs[x]) dfs2(hs[x],ntp);
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
if(y==fa[x]||y==hs[x]) continue;
dfs2(y,y);
}
}
int sb[N*200],ch[N*200][2],tg[N*200],rt[N],tt;
#define mid ((l+r)>>1)
void psup(int o,int len){sb[o]=tg[o]?len:sb[ch[o][0]]+sb[ch[o][1]];}
void modif(int &o,int l,int r,int ll,int rr,int x)
{
if(!o) o=++tt;
if(ll<=l&&r<=rr){tg[o]+=x,psup(o,r-l+1);return;}
if(ll<=mid) modif(ch[o][0],l,mid,ll,rr,x);
if(rr>mid) modif(ch[o][1],mid+1,r,ll,rr,x);
psup(o,r-l+1);
}
int merge(int o1,int o2,int l,int r)
{
if(!o1||!o2) return o1+o2;
tg[o1]+=tg[o2];
ch[o1][0]=merge(ch[o1][0],ch[o2][0],l,mid);
ch[o1][1]=merge(ch[o1][1],ch[o2][1],mid+1,r);
psup(o1,r-l+1);
return o1;
}
struct qu
{
int l,r,x;
};
vector<qu> qq[N];
int glca(int x,int y)
{
while(top[x]!=top[y])
{
if(de[top[x]]<de[top[y]]) swap(x,y);
x=fa[top[x]];
}
return de[x]<de[y]?x:y;
}
int n,m;
LL ans;
void dfs3(int x)
{
for(int i=hd[x];i;i=nt[i])
{
int y=to[i];
if(y==fa[x]) continue;
dfs3(y),rt[x]=merge(rt[x],rt[y],1,n);
}
int nn=qq[x].size();
for(int i=0;i<nn;++i) modif(rt[x],1,n,qq[x][i].l,qq[x][i].r,qq[x][i].x);
ans+=max(sb[rt[x]]-1,0);
}
int main()
{
n=rd(),m=rd();
for(int i=1;i<n;++i) add(rd(),rd());
de[1]=1,dfs1(1),dfs2(1,1);
while(m--)
{
int x=rd(),y=rd(),lca=glca(x,y),xx=x,yy=y;
while(top[xx]!=top[yy])
{
if(de[top[xx]]<de[top[yy]]) swap(xx,yy);
qq[x].push_back((qu){dfn[top[xx]],dfn[xx],1});
qq[y].push_back((qu){dfn[top[xx]],dfn[xx],1});
qq[lca].push_back((qu){dfn[top[xx]],dfn[xx],-1});
qq[fa[lca]].push_back((qu){dfn[top[xx]],dfn[xx],-1});
xx=fa[top[xx]];
}
if(de[xx]>de[yy]) swap(xx,yy);
qq[x].push_back((qu){dfn[xx],dfn[yy],1});
qq[y].push_back((qu){dfn[xx],dfn[yy],1});
qq[lca].push_back((qu){dfn[xx],dfn[yy],-1});
qq[fa[lca]].push_back((qu){dfn[xx],dfn[yy],-1});
}
dfs3(1);
printf("%lld
",ans>>1);
return 0;
}