来自51nod的双倍经验题
公共祖先 : 题意: 给你两颗树,问你有多少对点(a,b,c)满足c在第一棵树中是(a,b)的公共祖先,在第二棵树中也是(a,b)的公共祖先
双重祖先:题意:给你两棵树,问你有多少对点(u,v)满足u在第一棵树种是v的祖先,在第二棵树种也是v的祖先
乍看一下好像不是很一样,蓝鹅…
转换一下题意发现,题意都是:给你两棵树,求两棵树中的每个节点有多少个共同儿子(设为$x_i$)
然后
对于「51nod1681」就大力$ans=sum_{i=1}^{n} inom{x_i}{2}$
而对于「51nod2553」就直接$ans=sum_{i=1}^{n} x_i$
怎么统计呢...大力线段树合并!(我想出来的通解)
做法:对第一棵树跑出dfs序,在第二棵树中遍历一下,找出$[L_x,R_x]$有多少个数
而找出有多少个数可以用线段树合并来处理。(裸)
对于第二道题可以直接上线段树,遍历到x的时候对$[L_x,R_x]+1$,然后统计一下子树中节点$id_x$的权值。
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define ll long long 4 inline ll read() { 5 ll x=0,f=1; char ch=getchar(); 6 for(;ch<'0'||ch>'9';ch=getchar()) 7 if(ch=='-')f=-f; 8 for(;ch>='0'&&ch<='9';ch=getchar()) 9 x=x*10+ch-'0'; 10 return x*f; 11 } 12 inline void chkmin( int &a,int b ) { if(a>b) a=b; } 13 inline void chkmax( int &a,int b ) { if(a<b) a=b; } 14 #define _ read() 15 #define ln endl 16 const int N=2e5+5; 17 int n,vis[N]; 18 int rt[N],lc[20*N],rc[20*N],tot; 19 int l[N],r[N],cnt; 20 ll sum[20*N],ans; 21 vector < int > G_1[N],G_2[N]; 22 inline void up( int x ) { 23 sum[x]=sum[lc[x]]+sum[rc[x]]; 24 } 25 inline int newnode( int x,int l,int r ) { 26 ++tot; 27 if(l==r) return tot; 28 int mid=(l+r)>>1,now=tot; 29 if(x<=mid) lc[tot]=newnode(x,l,mid); 30 else rc[tot]=newnode(x,mid+1,r); 31 return now; 32 } 33 inline int merge( int x,int y,int l,int r ) { 34 if(!x||!y) return x+y; 35 tot++; 36 if(l==r) return tot; 37 int mid=(l+r)>>1,now=tot; 38 lc[now]=merge(lc[x],lc[y],l,mid); 39 rc[now]=merge(rc[x],rc[y],mid+1,r); 40 up(now); 41 return now; 42 } 43 inline void add( int x,int l,int r,int rt ) { 44 if(l==r) { sum[rt]++; return; } 45 int mid=(l+r)>>1; 46 if(x<=mid) add(x,l,mid,lc[rt]); 47 else add(x,mid+1,r,rc[rt]); 48 up(rt); 49 } 50 inline ll query( int L,int R,int l,int r,int rt ) { 51 if(L<=l&&r<=R) return sum[rt]; 52 int mid=(l+r)>>1; 53 ll ans=0; 54 if(L<=mid) ans+=query(L,R,l,mid,lc[rt]); 55 if(R>mid) ans+=query(L,R,mid+1,r,rc[rt]); 56 return ans; 57 } 58 inline void dfs_1( int x,int fa ) { 59 l[x]=++cnt; 60 for( int i=0;i<G_1[x].size();i++ ) 61 dfs_1(G_1[x][i],x); 62 r[x]=cnt; 63 } 64 inline void dfs_2( int x,int fa ) { 65 for( int i=0;i<G_2[x].size();i++ ) { 66 dfs_2(G_2[x][i],x); 67 rt[x]=merge(rt[x],rt[G_2[x][i]],1,n); 68 } 69 ll tmp=query(l[x],r[x],1,n,rt[x]); 70 ans+=tmp*(tmp-1)/2; 71 add(l[x],1,n,rt[x]); 72 } 73 int main() { 74 n=_; 75 for( int i=1;i<n;i++ ) { 76 int x=_,y=_; 77 G_1[x].push_back(y); 78 vis[y]=1; 79 } 80 for( int i=1;i<=n;i++ ) if(!vis[i]) { dfs_1(i,0); break; } 81 for( int i=1;i<=n;i++ ) vis[i]=0; 82 for( int i=1;i<n;i++ ) { 83 int x=_,y=_; 84 G_2[x].push_back(y); 85 vis[y]=1; 86 } 87 for( int i=1;i<=n;i++ ) 88 rt[i]=newnode(l[i],1,n); 89 for( int i=1;i<=n;i++ ) if(!vis[i]) { dfs_2(i,0); break; } 90 cout<<ans<<ln; 91 }
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define ll long long 4 inline ll read() { 5 ll x=0,f=1; char ch=getchar(); 6 for(;ch<'0'||ch>'9';ch=getchar()) 7 if(ch=='-')f=-f; 8 for(;ch>='0'&&ch<='9';ch=getchar()) 9 x=x*10+ch-'0'; 10 return x*f; 11 } 12 inline void chkmin( int &a,int b ) { if(a>b) a=b; } 13 inline void chkmax( int &a,int b ) { if(a<b) a=b; } 14 #define _ read() 15 #define ln endl 16 const int N=2e5+5; 17 int n; 18 int rt[N],lc[30*N],rc[30*N],tot; 19 int l[N],r[N],cnt; 20 ll sum[30*N],ans; 21 vector < int > G_1[N],G_2[N]; 22 inline void up( int x ) { 23 sum[x]=sum[lc[x]]+sum[rc[x]]; 24 } 25 inline int newnode( int x,int l,int r ) { 26 ++tot; 27 if(l==r) return tot; 28 int mid=(l+r)>>1,now=tot; 29 if(x<=mid) lc[tot]=newnode(x,l,mid); 30 else rc[tot]=newnode(x,mid+1,r); 31 return now; 32 } 33 inline int merge( int x,int y,int l,int r ) { 34 if(!x||!y) return x+y; 35 tot++; 36 if(l==r) { sum[tot]=1; return tot; } 37 int mid=(l+r)>>1,now=tot; 38 lc[now]=merge(lc[x],lc[y],l,mid); 39 rc[now]=merge(rc[x],rc[y],mid+1,r); 40 up(now); 41 return now; 42 } 43 inline void add( int x,int l,int r,int rt ) { 44 if(l==r) { sum[rt]++; return; } 45 int mid=(l+r)>>1; 46 if(x<=mid) add(x,l,mid,lc[rt]); 47 else add(x,mid+1,r,rc[rt]); 48 up(rt); 49 } 50 inline int query( int L,int R,int l,int r,int rt ) { 51 if(L<=l&&r<=R) return sum[rt]; 52 int mid=(l+r)>>1,ans=0; 53 if(L<=mid) ans+=query(L,R,l,mid,lc[rt]); 54 if(R>mid) ans+=query(L,R,mid+1,r,rc[rt]); 55 return ans; 56 } 57 inline void dfs_1( int x,int fa ) { 58 l[x]=++cnt; 59 for( int i=0;i<G_1[x].size();i++ ) 60 if(G_1[x][i]!=fa) 61 dfs_1(G_1[x][i],x); 62 r[x]=cnt; 63 } 64 inline void dfs_2( int x,int fa ) { 65 for( int i=0;i<G_2[x].size();i++ ) 66 if(G_2[x][i]!=fa) { 67 dfs_2(G_2[x][i],x); 68 rt[x]=merge(rt[x],rt[G_2[x][i]],1,n); 69 } 70 ans+=query(l[x],r[x],1,n,rt[x]); 71 add(l[x],1,n,rt[x]); 72 } 73 int main() { 74 n=_; 75 for( int i=1;i<n;i++ ) { 76 int x=_,y=_; 77 G_1[x].push_back(y); 78 G_1[y].push_back(x); 79 } 80 dfs_1(1,0); 81 for( int i=1;i<n;i++ ) { 82 int x=_,y=_; 83 G_2[x].push_back(y); 84 G_2[y].push_back(x); 85 } 86 for( int i=1;i<=n;i++ ) 87 rt[i]=newnode(l[i],1,n); 88 dfs_2(1,0); 89 cout<<ans<<ln; 90 }
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define ll long long 4 inline ll read() { 5 ll x=0,f=1; char ch=getchar(); 6 for(;ch<'0'||ch>'9';ch=getchar()) 7 if(ch=='-')f=-f; 8 for(;ch>='0'&&ch<='9';ch=getchar()) 9 x=x*10+ch-'0'; 10 return x*f; 11 } 12 inline void chkmin( int &a,int b ) { if(a>b) a=b; } 13 inline void chkmax( int &a,int b ) { if(a<b) a=b; } 14 #define _ read() 15 #define ln endl 16 const int N=1e5+5; 17 ll ans,sum[4*N],tag[4*N]; 18 int n,vis[N],l[N],r[N],dep[N],tot; 19 vector < int > G_1[N],G_2[N],vec[N]; 20 inline void up( int rt ) { sum[rt]=sum[rt*2]+sum[rt*2+1]; } 21 inline void down( int ln,int rn,int rt ) { 22 if(tag[rt]) { 23 tag[rt*2]+=tag[rt]; 24 tag[rt*2+1]+=tag[rt]; 25 sum[rt*2]+=tag[rt]*ln; 26 sum[rt*2+1]+=tag[rt]*rn; 27 tag[rt]=0; 28 } 29 } 30 inline void add( int L,int R,int x,int l,int r,int rt ) { 31 if(L<=l&&r<=R) { sum[rt]+=x*(r-l+1); tag[rt]+=x; return; } 32 int mid=(l+r)>>1; 33 down(mid-l+1,r-mid,rt); 34 if(L<=mid) add(L,R,x,l,mid,rt*2); 35 if(R>mid) add(L,R,x,mid+1,r,rt*2+1); 36 up(rt); 37 } 38 inline ll query( int x,int l,int r,int rt ) { 39 if(l==r) return sum[rt]; 40 int mid=(l+r)>>1; 41 down(mid-l+1,r-mid,rt); 42 if(x<=mid) return query(x,l,mid,rt*2); 43 else return query(x,mid+1,r,rt*2+1); 44 } 45 inline void dfs_1( int x,int fa ) { 46 l[x]=++tot; 47 for( int i=0;i<G_1[x].size();i++ ) 48 if(G_1[x][i]!=fa) dfs_1(G_1[x][i],x); 49 r[x]=tot; 50 } 51 inline void dfs_2( int x,int fa ) { 52 // cout<<x<<":"<<query(l[x],1,n,1)<<ln; 53 ans+=query(l[x],1,n,1); 54 add(l[x],r[x],1,1,n,1); 55 for( int i=0;i<G_2[x].size();i++ ) 56 if(G_2[x][i]!=fa) dfs_2(G_2[x][i],x); 57 add(l[x],r[x],-1,1,n,1); 58 } 59 int main() { 60 // freopen("input.txt","r",stdin); 61 n=_; 62 for( int i=1;i<n;i++ ) { 63 int x=_,y=_; 64 G_1[x].push_back(y); 65 G_1[y].push_back(x); 66 } dfs_1(1,0); 67 for( int i=1;i<n;i++ ) { 68 int x=_,y=_; 69 G_2[x].push_back(y); 70 G_2[y].push_back(x); 71 } dfs_2(1,0); 72 cout<<ans<<ln; 73 }
(注意一下「51nod1681」根不一定是1)