题意
给定两棵有根树,两棵树均有n个节点,且根均为1号点。
问有多少对(u,v)满足:在给定的两棵树中u均为v的祖先。
思路
本题即求对于所有节点x,两棵树中公共的儿子数量之和。
对于第一颗树,求出其dfs序,然后遍历第二棵树,每达到一个节点,查询该节点的值,表示这个节点同时被两棵树的一些节点father的数量,加入答案。
然后把这个节点的子树全部更新,回溯的时候撤销更新(避免影响)。
代码
#include <bits/stdc++.h>
using namespace std;
namespace StandardIO {
template<typename T> inline void read (T &x) {
x=0;T f=1;char c=getchar();
for (; c<'0'||c>'9'; c=getchar()) if (c=='-') f=-1;
for (; c>='0'&&c<='9'; c=getchar()) x=x*10+c-'0';
x*=f;
}
template<typename T> inline void write (T x) {
if (x<0) putchar('-'),x=-x;
if (x>=10) write(x/10);
putchar(x%10+'0');
}
}
using namespace StandardIO;
namespace Solve {
#define int long long
const int N=100001;
int n,ans;
struct Tree {
int cnt;
int head[N];
struct node {
int to,next;
} edge[N<<1];
inline void add (int a,int b) {
edge[++cnt].to=b,edge[cnt].next=head[a],head[a]=cnt;
}
} t1,t2;
int dfn;
int id[N],size[N];
struct node {
int l,r,val,tag;
} tree[N<<2];
inline void pushup (int pos) {
tree[pos].val+=tree[pos<<1].val+tree[pos<<1|1].val;
}
inline void pushdown (int pos) {
if (tree[pos].tag) {
tree[pos<<1].tag+=tree[pos].tag,tree[pos<<1|1].tag+=tree[pos].tag;
tree[pos<<1].val+=(tree[pos<<1].r-tree[pos<<1].l+1)*tree[pos].tag;
tree[pos<<1|1].val+=(tree[pos<<1|1].r-tree[pos<<1|1].l+1)*tree[pos].tag;
tree[pos].tag=0;
}
}
void build (int l,int r,int pos) {
tree[pos].l=l,tree[pos].r=r;
if (l==r) return;
int mid=(l+r)>>1;
build(l,mid,pos<<1),build(mid+1,r,pos<<1|1);
}
void update (int l,int r,int v,int pos) {
if (l<=tree[pos].l&&tree[pos].r<=r) return tree[pos].val+=(tree[pos].r-tree[pos].l+1)*v,tree[pos].tag+=v,void();
pushdown(pos);
int mid=(tree[pos].l+tree[pos].r)>>1;
if (l<=mid) update(l,r,v,pos<<1);
if (mid<r) update(l,r,v,pos<<1|1);
pushup(pos);
}
int query (int x,int pos) {
if (tree[pos].l==tree[pos].r) return tree[pos].val;
pushdown(pos);
int mid=(tree[pos].l+tree[pos].r)>>1;
if (x<=mid) return query(x,pos<<1);
return query(x,pos<<1|1);
}
void dfs1 (int now,int fa) {
id[now]=++dfn,size[now]=1;
for (register int i=t1.head[now]; i; i=t1.edge[i].next) {
int to=t1.edge[i].to;
if (to==fa) continue;
dfs1(to,now),size[now]+=size[to];
}
}
void dfs2 (int now,int fa) {
ans+=query(id[now],1);
update(id[now],id[now]+size[now]-1,1,1);
for (register int i=t2.head[now]; i; i=t2.edge[i].next) {
int to=t2.edge[i].to;
if (to==fa) continue;
dfs2(to,now);
}
update(id[now],id[now]+size[now]-1,-1,1);
}
inline void MAIN () {
read(n);
for (register int i=1,x,y; i<=n-1; ++i) {
read(x),read(y);
t1.add(x,y),t1.add(y,x);
}
build(1,n,1);
dfs1(1,1);
for (register int i=1,x,y; i<=n-1; ++i) {
read(x),read(y);
t2.add(x,y),t2.add(y,x);
}
dfs2(1,1);
write(ans);
}
#undef int
}
int main () {
// freopen("3.in","r",stdin);
// freopen(".out","w",stdout);
Solve::MAIN();
}