DFS序+主席树
其实这还可以用树状数组搞的(伟大的XSH大佬提出的)
But,本人苦调3个多小时,调不出来,最后无奈一查题解——woc,主席树!
暴力一发,A了。。。
那么来看看主席树是怎么PC过去的。
维护什么?
- 首先,想一下若路径A能被路径B包含,那么A的两端点就在路径B上。(无需讨论链和折线的情况)
- 那么,对于每个路径一端点开个vector,把另一端点加入其中。
- 然后,主席树上,我们就储存下A路径,查找时就看一下有多少条可以被所查路径包含。
- 这样,主席树就很好维护了。每个节点先以父亲节点为前缀构建,维护每条路径中的另一端点,(注意了,由于这是以父亲节点为前缀的,那么祖先存的路径中一端点已经被当前节点包含了)。
- 由于需要把树转换成区间,我们需要用出入栈序(回溯时序号也要+1)。
怎样修改?
那么,加入一端点时,只有在其子树中的点才能包含其(不需要讨论链的情况)。因此,将in[x]++,out[x]--。(单点修改)
怎样查询
- 查询时,由于每个节点的主席树是维护到根节点的,可以先把所查路径变成两条链(x,LCA(x,y))和(y,LCA(x,y))两部分。
- 再对于每个链(a,b),那么其的值就是Query(x,a,b)+Query(y,a,b)-Query(LCA(x,y),a,b)-Query(pre_LCA(x,y),a,b)(Query(t,a,b)表示在t节点的主席树上,查询链(a,b))。
注意:这里不要混淆了。链(a,b)表示一个区间,我们需要在以t节点的主席树上查询有多少个点在此区间中(即使是链(x,LCA(x,y)),其中在以y节点的主席树上也有节点,因为两者是不等价的。另外,查询链需要用到作差法)
这样,这道题差不多就解决了。
代码:
#include<bits/stdc++.h>
#define ll long long
const int MAXN=2e5+20;
using namespace std;
int n,m,rt[MAXN],lson[MAXN*20],rson[MAXN*20],L[MAXN],R[MAXN],mark,dfsn,lg[MAXN],pre[MAXN][20],deep[MAXN],Askx[MAXN],Asky[MAXN];
ll sum[MAXN*20],ans;
vector<int> G[MAXN],A[MAXN];
void DFS_BL(int x,int fa){
L[x]=++dfsn;
pre[x][0]=fa;
deep[x]=deep[fa]+1;
for(int i=1;(1<<i)<=deep[x];i++)pre[x][i]=pre[pre[x][i-1]][i-1];
for(int i=0;i<G[x].size();i++){
int t=G[x][i];
if(t==fa)continue;
DFS_BL(t,x);
}
R[x]=++dfsn;
}
int LCA(int x,int y){
if(deep[x]<deep[y])swap(x,y);
while(deep[x]>deep[y])x=pre[x][lg[deep[x]-deep[y]]-1];
if(x==y)return x;
for(int k=lg[deep[x]]-1;k>=0;k--){
if(pre[x][k]==pre[y][k])continue;
x=pre[x][k],y=pre[y][k];
}
return pre[x][0];
}
void insert(int &node,int ot,int L,int R,int num,int pos){
node=++mark;
sum[node]=sum[ot]+num;
lson[node]=lson[ot];
rson[node]=rson[ot];
if(L==R)return;
int mid=(L+R)>>1;
if(pos<=mid)insert(lson[node],lson[ot],L,mid,num,pos);
else insert(rson[node],rson[ot],mid+1,R,num,pos);
}
ll query(int nodex,int nodey,int nodef,int nodeff,int L,int R,int st,int ed){
if(st==L&&ed==R)return sum[nodex]+sum[nodey]-sum[nodef]-sum[nodeff];
int mid=(L+R)>>1;
if(ed<=mid)return query(lson[nodex],lson[nodey],lson[nodef],lson[nodeff],L,mid,st,ed);
else if(st>=mid+1)return query(rson[nodex],rson[nodey],rson[nodef],rson[nodeff],mid+1,R,st,ed);
else return query(lson[nodex],lson[nodey],lson[nodef],lson[nodeff],L,mid,st,mid)+query(rson[nodex],rson[nodey],rson[nodef],rson[nodeff],mid+1,R,mid+1,ed);
}
void DFS(int x,int fa){
rt[x]=rt[fa];
int tmp;
for(int i=0;i<A[x].size();i++){
insert(tmp,rt[x],1,dfsn,1,L[A[x][i]]);
rt[x]=tmp;
insert(tmp,rt[x],1,dfsn,-1,R[A[x][i]]);
rt[x]=tmp;
}
for(int i=0;i<G[x].size();i++){
int t=G[x][i];
if(t==fa)continue;
DFS(t,x);
}
}
int main(){
for(int i=1;i<=MAXN-10;i++)lg[i]=lg[i-1]+((1<<lg[i-1])==i);
scanf("%d %d",&n,&m);
for(int i=1;i<=n-1;i++){
int x,y;
scanf("%d %d",&x,&y);
G[x].push_back(y);
G[y].push_back(x);
}
DFS_BL(1,0);
for(int i=1;i<=m;i++){
int x,y;
scanf("%d %d",&x,&y);
Askx[i]=x,Asky[i]=y;
A[x].push_back(y);
}
DFS(1,0);
for(int i=1;i<=m;i++){
int x=Askx[i],y=Asky[i];
int lca=LCA(x,y);
ans+=query(rt[x],rt[y],rt[lca],rt[pre[lca][0]],1,dfsn,L[lca],L[x]);
ans+=query(rt[x],rt[y],rt[lca],rt[pre[lca][0]],1,dfsn,L[lca],L[y]);
ans-=query(rt[x],rt[y],rt[lca],rt[pre[lca][0]],1,dfsn,L[lca],L[lca]);
ans--;
}
ll tot=1LL*m*(m-1)/2;
ll gcd=__gcd(tot,ans);
printf("%lld/%lld",ans/gcd,tot/gcd);
return 0;
}