( ext{Solution:})
菜鸡太菜了想了好久没有思路……只知道要求树上链的并集但不知道咋整……虽然题目算法线段树合并和树上差分看到题就能想到……但是怎么做还是有思维难度……(对笔者来说)
看了好久的题解都没有看懂 这次写的详细一点。
- 所谓矩阵面积并
某些题解中说了一句 “可以用树剖套扫描线” “就是一个矩阵面积并” 的做法,复杂度 (O(nlog^3 n).)
实际上,我们对树进行树剖后,每一条路径 ((s,t)) 都可以被划分为 (log n) 个线段。我们可以把每一段被划分的重链区间看做矩形,那我们对每个点求的链并集,实际上就是一个抽象的 “矩形面积并” 。
复杂度是树剖的 (log) 和扫面线本身的两个 (log). 我没有实现这种做法。
- 线段树合并、树上差分与树剖
常规的 (log^2 n) 做法是这个。
由于有序对不好算,我们可以算无序对的个数,除以二即可。
记 (S_u) 表示 (u) 可以到达的点的集合(不包括它自己)。那么 (ans=frac{sum S_i}{2}.)
如何对每一个点计算出它的 (S) 呢?
首先明确一点:我们对每一个点建立的线段树是一棵 以 dfs序 为基础的线段树,它维护了整棵树的信息。
那么,对每一个点都有这样一棵线段树,对于每一个语言传递过程,我们在路上的每一个点上都进行一次区间覆盖操作,最终每一个点上线段树上被覆盖的长度就是它的 (S) .
显然这玩意复杂度不对。
首先,对树上路径修改,我们需要用到树剖。
那么对于一条路径上的点,我们可以自然想到用树上差分的思路来优化。
并不知道一棵支持区间修改的树咋合并) 上文所述,我们需要维护一棵线段树上被覆盖的长度,并且是 全局询问 。
那这个东西就长得很像扫描线了。 维护区间被覆盖的次数来更新即可。
那么合并呢?
合并要注意一下细节:每一次合并到的点都要记录信息,因为这种写法的区间覆盖次数没有什么下传操作,不要只在叶子上维护 pushup 操作,这样维护的信息是错误而不全面的。
其他写法都一样,复杂度是一个 (log).
那么总结就是:考虑暴力,每一个点都记录一下链上信息,又因为需要树上修改路径需要树剖操作,而对于一条路径上的点我们可以考虑树上差分优化,进而利用线段树合并来解决这题。
时间复杂度:(O(nlog^2 n).)
空间分析
之前从神鱼那里见到,空间复杂度是 (2nlog n) 的。
计算一下:(Num=2cdot 10^5 cdot log 10^5=3.32cdot 10^6) 级别。然而代码中,进行修改的操作达到了 (O(mlog n)) 级别的个数,也就是多了个 (log) ,约为(5.5cdot 10^7)级别。空间没有卡满,代码开到 (2cdot 10^7) 可以过去。
#include<bits/stdc++.h>
using namespace std;
const int MAXN=2e7+10;
int ls[MAXN],rs[MAXN],topp,rub[MAXN],node;
int cnt[MAXN],len[MAXN],id[MAXN],rk[MAXN];
int rt[MAXN],pa[MAXN],siz[MAXN],son[MAXN];
int dep[MAXN],head[MAXN],tot,n,m,top[MAXN];
long long ans;
struct E {
int nxt,to;
} e[MAXN];
inline int Max(int x,int y){return x>y?x:y;}
inline void add(int x,int y) {
e[++tot] = ( E ) {
head [ x ] , y
} ;
head[x]=tot;
}
void dfs1(int x,int fa) {
pa [ x ] = fa ;
siz [ x ] = 1 ;
dep [ x ] = dep [ fa ] + 1 ;
for ( int i=head[x]; i; i=e[i].nxt) {
int j=e[i].to;
if(j==fa)continue;
dfs1(j,x);
siz[x]+=siz[j];
if(siz[j]>siz[son[x]])son[x]=j;
}
}
int dfstime;
void dfs2(int u,int t) {
top [ u ] = t ;
id [ u ] = ++ dfstime ;
if ( ! son [ u ] )
return ;
dfs2 ( son [ u ] , t ) ;
for ( int i = head [ u ] ; i ; i = e [ i ] .nxt ) {
int j = e [ i ] .to ;
if ( j == son [ u ] || j == pa [ u ] )
continue ;
dfs2 ( j , j ) ;
}
}
inline void del(int x) {
rub[++topp]=x;
len[x]=cnt[x]=ls[x]=rs[x]=0;
}
inline int New() {
if(topp)return rub[topp--];
return ++node;
}
inline void pushup(int x,int l,int r) {
if(cnt[x]>0)
len[x]=r-l+1;
else
len[x]=len[ls[x]]+len[rs[x]];
}
void change(int &x,int L,int R,int l,int r,int v) {
if(!x)x=New();
if ( l <= L && R <= r ) {
cnt[x]+=v;
pushup(x,L,R);
return;
}
int mid=(L+R)>>1;
if(l<=mid)change(ls[x],L,mid,l,r,v);
if(mid<r)change(rs[x],mid+1,R,l,r,v);
pushup(x,L,R);
}
int merge(int x,int y,int l,int r) {
if(!x||!y)return x+y;
cnt[x]+=cnt[y];
if(l==r) {
pushup(x,l,r);
del(y);
return x;
}
int mid=(l+r)>>1;
ls[x]=merge(ls[x],ls[y],l,mid);
rs[x]=merge(rs[x],rs[y],mid+1,r);
pushup(x,l,r);
del(y);
return x;
}
int query(int x,int L,int R,int l,int r) {
if(L>=l&&R<=r)return len[x];
int mid=(L+R)>>1,val=0;
if(l<=mid)val+=query(ls[x],L,mid,l,r);
if(mid<r)val+=query(rs[x],mid+1,R,l,r);
return val;
}
void changes(int root,int x,int y,int v) {
while(top[x]!=top[y]) {
if(dep[top[x]]>=dep[top[y]]) {
change(rt[root],1,n,id[top[x]],id[x],v);
x=pa[top[x]];
} else {
change(rt[root],1,n,id[top[y]],id[y],v);
y=pa[top[y]];
}
}
if(id[x]<=id[y])change(rt[root],1,n,id[x],id[y],v);
else change(rt[root],1,n,id[y],id[x],v);
}
inline int LCA(int x,int y) {
while ( top [ x ] != top [ y ] ) {
if(dep[top[x]]>=dep[top[y]])x=pa[top[x]];
else y=pa[top[y]];
}
return dep [ x ] < dep [ y ] ? x : y ;
}
void dfs3 ( int x ) {
for ( int i = head [ x ] ; i ; i = e [ i ] .nxt ) {
int j = e [ i ] .to ;
if ( j == pa [ x ] )
continue ;
dfs3 ( j ) ;
rt [ x ] = merge ( rt [ x ] , rt [ j ] , 1 , n ) ;
}
ans += Max(query ( rt [ x ] , 1 , n , 1 , n ) -1,0) ;
}
signed main() {
freopen("1.in","r",stdin);
freopen("my.out","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1; i<n; ++i) {
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs1(1,1);
dfs2(1,1);
for(; m; m--) {
int s,t;
scanf("%d%d",&s,&t);
int lca=LCA(s,t);
changes(s,s,t,1);
changes(t,s,t,1);
changes(lca,s,t,-1);
if(lca==1)continue;
changes(pa[lca],s,t,-1);
}
dfs3(1);
ans>>=1;
printf("%lld
",ans);
return 0;
}