每个点的答案为所有经过该点的链的并的大小。得链并即为所有经过该点的链的端点构成的最小连通块,设端点按 (dfs) 序排序后为 (a_i),得最小连通块的边数为:
[large sum_{i=1}^{cnt} dep_{a_i}-sum_{i=1}^{cnt-1} dep_{operatorname{lca}(a_i,a_{i+1})}-dep_{operatorname{lca}(a_1,a_{cnt})}
]
即所有端点的深度减去排序后相邻点的 (operatorname{lca}) 的深度。
用线段树维护 (dfs) 序,添加路径用树上差分,更新信息用线段树合并即可。
#include<bits/stdc++.h>
#define maxn 200010
#define maxm 8000010
#define mid ((l+r)>>1)
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
x=0;char c=getchar();bool flag=false;
while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
if(flag)x=-x;
}
int n,m,cnt,tot;
int rt[maxn],f[maxn][19],dep[maxn],dfn[maxn],rev[maxn];
int ls[maxm],rs[maxm],val[maxm],mx[maxm],mn[maxm];
ll ans;
ll sum[maxm];
struct edge
{
int to,nxt;
}e[maxn];
int head[maxn],edge_cnt;
void add(int from,int to)
{
e[++edge_cnt]={to,head[from]},head[from]=edge_cnt;
}
void dfs_pre(int x,int fa)
{
dep[x]=dep[f[x][0]=fa]+1,rev[dfn[x]=++cnt]=x;
for(int i=1;i<=17;++i) f[x][i]=f[f[x][i-1]][i-1];
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(y==fa) continue;
dfs_pre(y,x);
}
}
int lca(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
for(int i=17;i>=0;--i)
if(f[x][i]&&dep[f[x][i]]>=dep[y])
x=f[x][i];
if(x==y) return x;
for(int i=17;i>=0;--i)
if(f[x][i]&&f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}
int get(int x,int y)
{
if(!x||!y) return 0;
return dep[lca(rev[x],rev[y])];
}
void pushup(int cur)
{
mx[cur]=mx[rs[cur]]?mx[rs[cur]]:mx[ls[cur]];
mn[cur]=mn[ls[cur]]?mn[ls[cur]]:mn[rs[cur]];
sum[cur]=sum[ls[cur]]+sum[rs[cur]]-get(mx[ls[cur]],mn[rs[cur]]);
}
void modify(int l,int r,int pos,int v,int &cur)
{
if(!cur) cur=++tot;
if(l==r)
{
if((val[cur]+=v)>0) mx[cur]=mn[cur]=l,sum[cur]=dep[rev[l]];
else mx[cur]=mn[cur]=sum[cur]=0;
return;
}
if(pos<=mid) modify(l,mid,pos,v,ls[cur]);
else modify(mid+1,r,pos,v,rs[cur]);
pushup(cur);
}
int merge(int x,int y,int l,int r)
{
if(!x||!y) return x+y;
if(l==r)
{
if((val[x]+=val[y])>0) mx[x]=mn[x]=l,sum[x]=dep[rev[l]];
else mx[x]=mn[x]=sum[x]=0;
return x;
}
ls[x]=merge(ls[x],ls[y],l,mid);
rs[x]=merge(rs[x],rs[y],mid+1,r);
pushup(x);
return x;
}
void dfs_ans(int x)
{
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(y==f[x][0]) continue;
dfs_ans(y),rt[x]=merge(rt[x],rt[y],1,n);
}
ans+=sum[rt[x]]-get(mx[rt[x]],mn[rt[x]]);
}
void update(int x,int y,int id)
{
int anc=lca(x,y);
modify(1,n,id,1,rt[x]);
modify(1,n,id,1,rt[y]);
modify(1,n,id,-1,rt[anc]);
if(f[anc][0]) modify(1,n,id,-1,rt[f[anc][0]]);
}
int main()
{
read(n),read(m);
for(int i=1;i<n;++i)
{
int x,y;
read(x),read(y);
add(x,y),add(y,x);
}
dfs_pre(1,0);
for(int i=1;i<=m;++i)
{
int x,y;
read(x),read(y);
update(x,y,dfn[x]),update(x,y,dfn[y]);
}
dfs_ans(1),printf("%lld",ans/2);
return 0;
}