题面
题解
前置知识:
首先对求的值做一个转化:相当于对(1 leq i leq n),求出(S[i])表示1~n中可以与i开展贸易的点数。
一个点j(( eq i))与i能够开展贸易的充要条件是(exists x in[1,m])使得路径(path (s_x,t_x))通过点i,j。
因此,(S [i])就是所有通过i的(path(s_x,t_x)),这些路径的并集中点的个数。也就是这些路径的端点形成的虚树。
性质:k个点(u_1,u_2,…,u_k)(按dfs序升序)形成的虚树的大小是(sum_{i=1}^kdep_{u_i}-sum_{i=1}^{k} dep_{lca(u_i,u_{i + 1})}),其中(u_{k+1}=u_1)
可以使用dfs序证明,这里略去。
因此考虑对于原树中的每一个点u,维护(f[u],g[u],left[u],right[u]),使得:
-
所有通过u的“统一语言”路径,它们的端点按照dfs序排序后,形成(v_1,v_2,…,v_k)的序列。
-
(f[u]=sum_{i=1}^{k}dep_{v_k})。
-
(g[u]=sum_{i=1}^{k-1}dep_{lca(v_k,v_{k+1})})。
-
(left[u]=v_1,right[u]=v_k)
其中(left,right)是用于支持合并以及统计答案。
实现时,可以使用线段树。首先对于每一个(x in [1,m]),
- 在(s_x)处打((s_x,1))和((t_x,1))的标记。
- 在(t_x)处打((s_x,1))和((t_x,1))的标记。
- 在(lca(s_x,t_x))处打((s_x,-1))和((t_x,-1))的标记。
- 在(lca(s_x,t_x))的父亲处打((s_x,-1))和((t_x,-1))的标记。
然后,对于原树进行一次dfs,每一个原树上节点u对应的线段树首先是它所有子节点的线段树之并;
其次,按照u节点上打的每一个标记,对u对应的线段树进行更新
那么(S[u])就是(f[u]-g[u]-dep_{lca(left[u],right[u])})啦。
总时间复杂度(O(n log n))。
代码
#include<bits/stdc++.h>
using namespace std;
#define rg register
#define In inline
#define ll long long
const int N = 1e5;
const int TN = 9 * 17 * N;
typedef pair<int,int>pii;
namespace IO{
In int read(){
int s = 0,ww = 1;
char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-')ww = -1;ch = getchar();}
while('0' <= ch && ch <= '9'){s = 10 * s + ch - '0';ch = getchar();}
return s * ww;
}
In void write(int x){
if(x < 0)putchar('-'),x = -x;
if(x > 9)write(x / 10);
putchar('0' + x % 10);
}
}
using namespace IO;
struct edge{
int des,next;
}e[2*N+5];
int s[N+5],t[N+5],fa[N+5];
int head[N+5],in[N+5],dfn[N+5],D[N+5],E[2*N+5];
ll dep[2*N+5];
int cnt,En,dn;
int n,m;
In void addedge(int a,int b){
cnt++;
e[cnt].des = b;
e[cnt].next = head[a];
head[a] = cnt;
}
void dfs1(int u,int f){
E[++En] = u;
in[u] = En;
D[++dn] = u;
dfn[u] = dn;
fa[u] = f;
dep[u] = dep[fa[u]] + 1;
for(rg int i = head[u];i;i = e[i].next){
int v = e[i].des;
if(v == f)continue;
dfs1(v,u);
E[++En] = u;
}
}
int lg[2*N+5];
struct ST{
int m[2*N+5][21];
void prepro(){
for(rg int i = 2;i <= 2 * N;i++)lg[i] = lg[i>>1] + 1;
for(rg int i = 1;i <= En;i++)m[i][0] = i;
for(rg int j = 1;j <= 20;j++)
for(rg int i = 1;i + (1<<(j-1)) <= En;i++){
int x = m[i][j-1],y = m[i+(1<<(j-1))][j-1];
m[i][j] = dep[E[x]] < dep[E[y]] ? x : y;
}
}
In int query(int l,int r){
int d = lg[r-l+1];
int x = m[l][d],y = m[r+1-(1<<d)][d];
return dep[E[x]] < dep[E[y]] ? x : y;
}
In int lca(int u,int v){
if(in[u] > in[v])swap(u,v);
return E[query(in[u],in[v])];
}
}S;
int rt[N+5];
struct SegTree{
ll f[TN+5],g[TN+5];
int left[TN+5],right[TN+5],lc[TN+5],rc[TN+5];
int cnt;
In void pushup(int u){
int l = lc[u],r = rc[u];
f[u] = f[l] + f[r];
if(!f[l]){
g[u] = g[r];
left[u] = left[r];
right[u] = right[r];
return;
}
if(!f[r]){
g[u] = g[l];
left[u] = left[l];
right[u] = right[l];
return;
}
left[u] = left[l],right[u] = right[r];
g[u] = g[l] + g[r] + dep[S.lca(D[right[l]],D[left[r]])];
}
In ll query(int u){
if(!f[u])return 0;
return f[u] - g[u] - dep[S.lca(D[right[u]],D[left[u]])];
}
void ud(int u,int l,int r,int x,ll d){
if(l == r){
f[u] += d * dep[D[x]];
if(!f[u])left[u] = right[u] = f[u] = g[u] = 0;
else{
int n = f[u] / dep[D[x]];
g[u] = (n - 1) * dep[D[x]];
left[u] = right[u] = x;
}
return;
}
int m = (l + r) >> 1;
if(x <= m){
if(!lc[u])lc[u] = ++cnt;
ud(lc[u],l,m,x,d);
}
else{
if(!rc[u])rc[u] = ++cnt;
ud(rc[u],m + 1,r,x,d);
}
pushup(u);
}
int merge(int u,int v,int l,int r){
if(!u || !v)return u + v;
if(l == r){
f[u] += f[v];
int n = f[u] / dep[D[l]];
if(!n)g[u] = left[u] = right[u] = 0;
else g[u] = 1ll * (n - 1) * dep[D[l]],left[u] = right[u] = l;
return u;
}
int m = (l + r) >> 1;
lc[u] = merge(lc[u],lc[v],l,m);
rc[u] = merge(rc[u],rc[v],m + 1,r);
pushup(u);
return u;
}
}T;
vector<pii>v[N+5];
ll ans[N+5];
void dfs2(int u){
rt[u] = ++T.cnt;
for(rg int i = head[u];i;i = e[i].next){
int v = e[i].des;
if(v == fa[u])continue;
dfs2(v);
rt[u] = T.merge(rt[u],rt[v],1,n);
}
for(rg int i = 0;i < v[u].size();i++){
int id = v[u][i].first,dx = v[u][i].second;
T.ud(rt[u],1,n,dfn[s[id]],dx);
T.ud(rt[u],1,n,dfn[t[id]],dx);
}
ans[u] = T.query(rt[u]);
}
int main(){
// freopen("L3046.in","r",stdin);
// freopen("L3046.out","w",stdout);
n = read(),m = read();
for(rg int i = 1;i < n;i++){
int u = read(),v = read();
addedge(u,v);
addedge(v,u);
}
dfs1(1,0);
S.prepro();
for(rg int i = 1;i <= m;i++){
s[i] = read(),t[i] = read();
if(dfn[s[i]] > dfn[t[i]])swap(s[i],t[i]);
v[s[i]].push_back(make_pair(i,1));
v[t[i]].push_back(make_pair(i,1));
int Lca = S.lca(s[i],t[i]);
v[Lca].push_back(make_pair(i,-1));
v[fa[Lca]].push_back(make_pair(i,-1));
}
dfs2(1);
ll rt = 0;
for(rg int i = 1;i <= n;i++)rt += ans[i];
rt >>= 1;
cout << rt << endl;
return 0;
}