题目
分析
线段树合并+树上差分。
首先我们发现答案其实就是:对于每一个点来说的连通块大小之和。
那么现在问题在于怎么来维护这个连通块的大小。
我们可以考虑对每一个点开一个线段树,保存:(dfn) 序列对应的点被路径覆盖次数和长度。
然后对于这样一类树上路径修改且每个点都要查询的问题,我们可以考虑使用线段树合并+树上差分来解决。
那么这道题就显而易见了,是直接打上两个 (+1) 标记,然后在 (fa[lca]) 处打上 (-2) 的标记。
接下来就是直接线段树合并,重点在于怎么具体维护有多少个点,我们发现如果当前区间的都是大于 (0) 的话,那么个数就是 (r-l+1) ,因为我们每一次修改的时候,一定是一个连续的链(从下到上)。
这是由我们树剖来决定的,也就是有 (logn) 个区间要进行修改的意思。
具体见代码。
代码
#include<bits/stdc++.h>
using namespace std;
template <typename T>
inline void read(T &x){
x=0;char ch=getchar();bool f=false;
while(!isdigit(ch)){if(ch=='-'){f=true;}ch=getchar();}
while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
x=f?-x:x;
return ;
}
template <typename T>
inline void write(T x){
if(x<0) putchar('-'),x=-x;
if(x>9) write(x/10);
putchar(x%10^48);
return ;
}
const int N=1e5+5;
#define ll long long
int n,m;
ll Ans;
int head[N],nex[N<<1],to[N<<1],idx;
inline void add(int u,int v){
nex[++idx]=head[u];
to[idx]=v;
head[u]=idx;
return ;
}
int fa[N],dep[N],siz[N],son[N],top[N],dfn[N],rev[N],DFN;
void dfs1(int x,int f){
fa[x]=f,dep[x]=dep[f]+1,siz[x]=1;
for(int i=head[x];i;i=nex[i]){
int y=to[i];
if(y==f) continue;
dfs1(y,x);siz[x]+=siz[y];
if(siz[y]>siz[son[x]]) son[x]=y;
}
return ;
}
void dfs2(int x){
if(x==son[fa[x]]) top[x]=top[fa[x]];
else top[x]=x;
dfn[x]=++DFN,rev[DFN]=x;
if(son[x]) dfs2(son[x]);
for(int i=head[x];i;i=nex[i]){
int y=to[i];
if(y==fa[x]||y==son[x]) continue;
dfs2(y);
}
return ;
}
inline int QueryLca(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
return dep[x]<dep[y]?x:y;
}
int Root[N];
struct SGT{
int sum,num,ls,rs;
#define sum(x) t[x].sum
#define num(x) t[x].num
#define ls(x) t[x].ls
#define rs(x) t[x].rs
}t[N*250];
int cur;
void Modify(int &x,int l,int r,int ql,int qr,int v){
if(!x) x=++cur;
if(ql<=l&&qr>=r) return sum(x)+=v,num(x)=(sum(x)>0?(r-l+1):(num(ls(x))+num(rs(x)))),void();
int mid=l+r>>1;
if(ql<=mid) Modify(ls(x),l,mid,ql,qr,v);
if(qr>mid) Modify(rs(x),mid+1,r,ql,qr,v);
num(x)=(sum(x)>0?(r-l+1):(num(ls(x))+num(rs(x))));
return ;
}
int Query(int x,int l,int r,int ql,int qr){
if(!x) return 0;
if(ql<=l&&r<=qr) return num(x);
int mid=l+r>>1,res=0;
if(ql<=mid) res+=Query(ls(x),l,mid,ql,qr);
if(qr>mid) res+=Query(rs(x),mid+1,r,ql,qr);
return res;
}
int Merge(int x,int y,int l,int r){
if(!x||!y) return x|y;
sum(x)+=sum(y);int mid=l+r>>1;
ls(x)=Merge(ls(x),ls(y),l,mid),rs(x)=Merge(rs(x),rs(y),mid+1,r);
num(x)=(sum(x)>0?(r-l+1):(num(ls(x))+num(rs(x))));
return x;
}
typedef pair<int,int> PII;
PII path[N];
int Cnt;
void GetSeq(int x,int y){
Cnt=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
path[++Cnt]=make_pair(dfn[top[x]],dfn[x]);
x=fa[top[x]];
}
if(dep[x]<dep[y]) swap(x,y);
path[++Cnt]=make_pair(dfn[y],dfn[x]);
return ;
}
void Solve(int x){
Modify(Root[x],1,n,dfn[x],dfn[x],1);
for(int i=head[x];i;i=nex[i]){
int y=to[i];
if(y==fa[x]) continue;
Solve(y);Root[x]=Merge(Root[x],Root[y],1,n);
}
Ans+=Query(Root[x],1,n,1,n)-1;
Modify(Root[x],1,n,dfn[x],dfn[x],-1);
return ;
}
int main(){
read(n),read(m);
for(int i=1;i<n;i++){
int u,v;read(u),read(v);
add(u,v),add(v,u);
}
dfs1(1,0);dfs2(1);
for(int i=1;i<=m;i++){
int s,t;read(s),read(t);
int lca=QueryLca(s,t),f=fa[lca];
GetSeq(s,t);
for(int j=1;j<=Cnt;j++){
Modify(Root[s],1,n,path[j].first,path[j].second,1);
Modify(Root[t],1,n,path[j].first,path[j].second,1);
Modify(Root[f],1,n,path[j].first,path[j].second,-2);
}
}
Solve(1);
write(Ans/2);
return 0;
}