分析
首先我们知道如果在一棵树上加一条边一定会构成一个环,而删掉环上任意一条边都不改变连通性。我们把这一性质扩展到这个题上不难发现如果一条树边不在任意一个新边构成的环里则删掉这条边之后可以删掉任意一条新边,对方案数的贡献是m。而如果它只在一个新边构成的环中则要删除这条边和对应的新边,对方案数的贡献是1。而如果它在至少两个新边构成的环中则无论如何也不能将图分成两半,所以对方案数的贡献为0。在知道这些之后我们考虑如何维护一条边在几个由新边构成的环中,那我们自然考虑到了LCA,对于每一条新边将其LCA路径上的边的值都加1.所以我们只需要维护这个值就行了。据说可以用倍增+差分维护,但我并不会,我是用树剖维护的。我们考虑对于原来的树,除根节点外的每一个点入度一定为1,所以我们不在边上累加答案,而用这条边连接的两个点中深度较深的点来代表这条边,最后用2~n这几个点上的值便可以求出方案数。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cctype>
#include<cmath>
#include<cstdlib>
#include<queue>
#include<ctime>
#include<vector>
#include<set>
#include<map>
#include<stack>
using namespace std;
const int LOG = 20;
vector<int>v[100010];
int ans,id[100010],col[440000],n,m;
int son[100010],siz[100010],cnt,dep[100010],fa[100010],acc[100010];
inline void dfs(int x,int la){
int maxn=0;siz[x]=1;
for(int i=0;i<v[x].size();i++)
if(v[x][i]!=la){
fa[v[x][i]]=x;
dep[v[x][i]]=dep[x]+1;
dfs(v[x][i],x);
siz[x]+=siz[v[x][i]];
if(siz[v[x][i]]>maxn){
maxn=siz[v[x][i]];
son[x]=v[x][i];
}
}
return;
}
inline void dfs2(int x,int ac){
id[x]=++cnt;
acc[x]=ac;
if(!son[x])return;
dfs2(son[x],ac);
for(int i=0;i<v[x].size();i++)
if(v[x][i]!=fa[x]&&v[x][i]!=son[x])
dfs2(v[x][i],v[x][i]);
return;
}
inline void update(int le,int ri,int wh,int x,int y,int k){
if(x>y)return;
if(le>=x&&ri<=y){
col[wh]+=k;
return;
}
int mid=(le+ri)>>1;
if(col[wh]){
col[wh<<1]+=col[wh];
col[wh<<1|1]+=col[wh];
col[wh]=0;
}
if(mid>=x)update(le,mid,wh<<1,x,y,k);
if(mid<y)update(mid+1,ri,wh<<1|1,x,y,k);
return;
}
inline int q(int le,int ri,int wh,int pl){
if(le==ri)return col[wh];
int mid=(le+ri)>>1,ans;
if(col[wh]){
col[wh<<1]+=col[wh];
col[wh<<1|1]+=col[wh];
col[wh]=0;
}
if(mid>=pl)ans=q(le,mid,wh<<1,pl);
else ans=q(mid+1,ri,wh<<1|1,pl);
return ans;
}
inline void solve(int x,int y){
while(acc[x]!=acc[y]){
if(dep[acc[x]]<dep[acc[y]])swap(x,y);
update(1,n,1,id[acc[x]],id[x],1);
x=fa[acc[x]];
}
if(id[x]>id[y])swap(x,y);
update(1,n,1,id[x]+1,id[y],1);
return;
}
int main(){
int i,j,k;
scanf("%d%d",&n,&m);
for(i=1;i<n;i++){
int x,y;
scanf("%d%d",&x,&y);
v[x].push_back(y);
v[y].push_back(x);
}
fa[1]=0,dep[1]=1;
dfs(1,0);
dfs2(1,1);
for(i=1;i<=m;i++){
int x,y;
scanf("%d%d",&x,&y);
solve(x,y);
}
for(i=2;i<=n;i++){
int x=q(1,n,1,id[i]);
if(x==0)ans+=m;
else if(x==1)ans+=1;
}
printf("%d
",ans);
return 0;
}