首先可以想到对每个点统计出所有经过它的链的并所包含的点数,然后可以直接得到答案。根据实现不同有下面几种方法。
三个log:假如对每个点都存下经过它的链并S[x],那么每新加一条路径进来的时候,相当于在路径上所有点的S中都加入这条路径。树剖之后,相当于对log个区间中的点都加入log个区间。具体实现有树剖后线段树维护虚树、矩形扫描线、线段树+set存区间等多种方法,这里不再多说。
两个log:先树剖,然后对每个点开一棵线段树存储它的S,由于题中没有修改,所以可以树上差分+线段树合并,这样可以将方法一中“需要修改的区间数”的log去掉了。
一个log:发现就是对每个点求所有经过它的路径的端点的斯坦纳树(这里一个点集的斯坦纳树是指原树上最小的点集,满足包含这个点集且连通)。考虑如何暴力求一个点集的斯坦纳树,那显然就是将它们按DFS序排序后,所有点深度之和减去每对相邻点LCA的深度和。为了方便我们将点集中强制加入根,最后求出结果后再减去所有点LCA的深度的两倍。以DFS序为下标建线段树,每个点维护它所代表的DFS区间中,所有在点集中的点(加上根)构成的斯坦纳树的大小。两个区间的合并就是两边的斯坦纳树大小之和,减去左边区间里在点集中的DFS序最大的点与右边区间里在点集中的DFS序最小的点的LCA的深度,于是再维护区间里在点集中的DFS序最大/小的点分别是谁即可。同样使用树上差分+线段树合并,就可以将方法一中“每个修改区间中要加入的区间数”的log去掉了。
(参考https://www.luogu.org/blog/Sooke/solution-p5327)
1 #include<cstdio> 2 #include<vector> 3 #include<algorithm> 4 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 5 #define For(i,x) for (int i=h[x],k; i; i=nxt[i]) 6 typedef long long ll; 7 using namespace std; 8 9 const int N=200010,M=6400010,K=18; 10 int n,m,x,y,tim,cnt,nd,d[N],lg[N],rt[N],fa[N],dfn[N],st[N][20]; 11 int v[M],ls[M],rs[M],s[M],t[M],c[M],h[N],to[N],nxt[N]; 12 ll ans; 13 vector<int>del[N]; 14 15 void add(int u,int v){ to[++cnt]=v; nxt[cnt]=h[u]; h[u]=cnt; } 16 17 void dfs(int x){ 18 d[x]=d[fa[x]]+1; st[++tim][0]=x; dfn[x]=tim; 19 For(i,x) if ((k=to[i])!=fa[x]) fa[k]=x,dfs(k),st[++tim][0]=x; 20 } 21 22 void init(){ 23 rep(j,1,lg[tim]) rep(i,1,tim-(1<<j)+1){ 24 int x=st[i][j-1],y=st[i+(1<<(j-1))][j-1]; 25 st[i][j]=d[x]<d[y] ? x : y; 26 } 27 } 28 29 int lca(int x,int y){ 30 if (!x || !y) return 0; 31 x=dfn[x]; y=dfn[y]; 32 if (x>y) swap(x,y); 33 int t=lg[y-x+1]; x=st[x][t]; y=st[y-(1<<t)+1][t]; 34 return d[x]<d[y] ? x : y; 35 } 36 37 void upd(int x){ 38 v[x]=v[ls[x]]+v[rs[x]]-d[lca(t[ls[x]],s[rs[x]])]; 39 s[x]=s[ls[x]] ? s[ls[x]] : s[rs[x]]; 40 t[x]=t[rs[x]] ? t[rs[x]] : t[ls[x]]; 41 } 42 43 void mdf(int &x,int L,int R,int p,int k){ 44 if (!x) x=++nd; 45 if (L==R){ c[x]+=k; v[x]=(c[x]?d[p]:0); s[x]=t[x]=(c[x]?p:0); return; } 46 int mid=(L+R)>>1; 47 if (dfn[p]<=mid) mdf(ls[x],L,mid,p,k); else mdf(rs[x],mid+1,R,p,k); 48 upd(x); 49 } 50 51 int merge(int x,int y,int L,int R){ 52 if (!x || !y) return x|y; 53 if (L==R){ c[x]+=c[y]; v[x]|=v[y]; s[x]|=s[y]; t[x]|=t[y]; return x; } 54 int mid=(L+R)>>1; ls[x]=merge(ls[x],ls[y],L,mid); rs[x]=merge(rs[x],rs[y],mid+1,R); 55 upd(x); return x; 56 } 57 58 void solve(int x){ 59 For(i,x) if ((k=to[i])!=fa[x]) solve(k); 60 int ed=del[x].size()-1; 61 rep(i,0,ed) mdf(rt[x],1,tim,del[x][i],-1); 62 ans+=v[rt[x]]-d[lca(s[rt[x]],t[rt[x]])]; rt[fa[x]]=merge(rt[fa[x]],rt[x],1,tim); 63 } 64 65 int main(){ 66 freopen("a.in","r",stdin); 67 freopen("a.out","w",stdout); 68 scanf("%d%d",&n,&m); 69 rep(i,2,n<<1) lg[i]=lg[i>>1]+1; 70 rep(i,2,n) scanf("%d%d",&x,&y),add(x,y),add(y,x); 71 dfs(1); init(); 72 rep(i,1,m){ 73 scanf("%d%d",&x,&y); int l=lca(x,y); 74 mdf(rt[x],1,tim,x,1); mdf(rt[x],1,tim,y,1); 75 mdf(rt[y],1,tim,x,1); mdf(rt[y],1,tim,y,1); 76 del[l].push_back(x); del[l].push_back(y); 77 del[fa[l]].push_back(x); del[fa[l]].push_back(y); 78 } 79 solve(1); printf("%lld ",ans/2); 80 return 0; 81 }