题目大意:
给一棵树,每条边可能是黑色或白色(起始都是白色),有三种操作:
1、将u到v路径上所有边颜色翻转(黑->白,白->黑)
2、将只有一个点在u到v路径上的边颜色翻转
3、查询u到v路径上颜色为黑色的边数
如果只有1、3操作很好做,直接树链剖分+线段树,重点是2操作(废话)。
可以发现只有lca需要翻转和它父节点之间的边,因此将这条边暴力翻转。
现在剩下需要翻转的就是u到v路径上所有点与他们子节点之间的边。
我们将u到v的路径分成若干条重链,可以发现每条重链除了链低端的那个点,其他点要反转的是他们与他们所有轻儿子之间的边(这里注意特判lca少翻转了一条轻边)。
这么多边显然不能都翻转,那么我们再开一棵线段树,将重链上的点打标记,代表这些点与他们所有轻儿子之间的边需要翻转。
至于重链低端那个点还要暴力翻转它与重儿子间的边。
但现在只处理了每条重链,没有考虑两条链之间连接的那条边?
假设y链接在x链下面,那么x链低端那个点需要翻转的是除了一个轻边之外的其他轻边和一条重边,那么只需要将重边和不需要翻转的轻边翻转之后再打标记就好了。
将边下传到点,开两棵线段树,一棵维护边的颜色,另一棵维护给轻儿子翻转的标记(这棵线段树建议标记永久化),每次修改时注意链与链之间轻边的翻转即可。
#include<set> #include<map> #include<stack> #include<queue> #include<cmath> #include<cstdio> #include<vector> #include<bitset> #include<cstring> #include<iostream> #include<algorithm> #define ll long long using namespace std; int T; int n,m; int x,y; int tot; int opt; int num; int a[800010]; int b[800010]; int d[100010]; int f[100010]; int s[100010]; int to[200010]; int sum[800010]; int son[100010]; int top[100010]; int size[100010]; int head[100010]; int next[200010]; void add(int x,int y) { tot++; next[tot]=head[x]; head[x]=tot; to[tot]=y; } void dfs(int x) { d[x]=d[f[x]]+1; size[x]=1; for(int i=head[x];i;i=next[i]) { if(to[i]!=f[x]) { f[to[i]]=x; dfs(to[i]); size[x]+=size[to[i]]; if(size[to[i]]>size[son[x]]) { son[x]=to[i]; } } } } void dfs2(int x,int tp) { s[x]=++num; top[x]=tp; if(son[x]) { dfs2(son[x],tp); } for(int i=head[x];i;i=next[i]) { if(to[i]!=f[x]&&to[i]!=son[x]) { dfs2(to[i],to[i]); } } } void updata(int rt,int l,int r,int L,int R) { if(L<=l&&r<=R) { b[rt]^=1; return ; } int mid=(l+r)>>1; if(L<=mid) { updata(rt<<1,l,mid,L,R); } if(R>mid) { updata(rt<<1|1,mid+1,r,L,R); } } int downdata(int rt,int l,int r,int k) { if(l==r) { return b[rt]; } int res=b[rt]; int mid=(l+r)>>1; if(k<=mid) { return res^downdata(rt<<1,l,mid,k); } else { return res^downdata(rt<<1|1,mid+1,r,k); } } void pushup(int rt) { sum[rt]=sum[rt<<1]+sum[rt<<1|1]; } void pushdown(int rt,int l,int r) { if(a[rt]) { int mid=(l+r)>>1; a[rt<<1]^=1; a[rt<<1|1]^=1; sum[rt<<1]=(mid-l+1)-sum[rt<<1]; sum[rt<<1|1]=(r-mid)-sum[rt<<1|1]; a[rt]=0; } } void change1(int rt,int l,int r,int k) { if(l==r) { sum[rt]^=1; return ; } pushdown(rt,l,r); int mid=(l+r)>>1; if(k<=mid) { change1(rt<<1,l,mid,k); } else { change1(rt<<1|1,mid+1,r,k); } pushup(rt); } void change2(int rt,int l,int r,int L,int R) { if(L<=l&&r<=R) { a[rt]^=1; sum[rt]=(r-l+1)-sum[rt]; return ; } pushdown(rt,l,r); int mid=(l+r)>>1; if(L<=mid) { change2(rt<<1,l,mid,L,R); } if(R>mid) { change2(rt<<1|1,mid+1,r,L,R); } pushup(rt); } int query1(int rt,int l,int r,int k) { if(l==r) { return sum[rt]; } pushdown(rt,l,r); int mid=(l+r)>>1; if(k<=mid) { return query1(rt<<1,l,mid,k); } else { return query1(rt<<1|1,mid+1,r,k); } } int query2(int rt,int l,int r,int L,int R) { if(L<=l&&r<=R) { return sum[rt]; } pushdown(rt,l,r); int mid=(l+r)>>1; int res=0; if(L<=mid) { res+=query2(rt<<1,l,mid,L,R); } if(R>mid) { res+=query2(rt<<1|1,mid+1,r,L,R); } return res; } void rotate1(int x,int y) { while(top[x]!=top[y]) { if(d[top[x]]<d[top[y]]) { swap(x,y); } change2(1,1,n,s[top[x]],s[x]); x=f[top[x]]; } if(d[x]>d[y]) { swap(x,y); } if(x==y) { return ; } change2(1,1,n,s[x]+1,s[y]); } void rotate2(int x,int y) { if(son[x]) { change1(1,1,n,s[x]+1); } if(son[y]) { change1(1,1,n,s[y]+1); } while(top[x]!=top[y]) { if(d[top[x]]<d[top[y]]) { swap(x,y); } updata(1,1,n,s[top[x]],s[x]); change1(1,1,n,s[top[x]]); x=f[top[x]]; change1(1,1,n,s[x]+1); } if(d[x]>d[y]) { swap(x,y); } change1(1,1,n,s[x]); change1(1,1,n,s[x]+1); updata(1,1,n,s[x],s[y]); } int ask(int x,int y) { int res=0; int ans; while(top[x]!=top[y]) { if(d[top[x]]<d[top[y]]) { swap(x,y); } if(top[x]!=x) { res+=query2(1,1,n,s[top[x]]+1,s[x]); } ans=query1(1,1,n,s[top[x]]); x=f[top[x]]; res+=(ans^downdata(1,1,n,s[x])); } if(d[x]>d[y]) { swap(x,y); } if(x==y) { return res; } res+=query2(1,1,n,s[x]+1,s[y]); return res; } int main() { scanf("%d",&T); while(T--) { num=0; tot=0; memset(a,0,sizeof(a)); memset(b,0,sizeof(b)); memset(f,0,sizeof(f)); memset(d,0,sizeof(d)); memset(s,0,sizeof(s)); memset(to,0,sizeof(to)); memset(son,0,sizeof(son)); memset(top,0,sizeof(top)); memset(sum,0,sizeof(sum)); memset(size,0,sizeof(size)); memset(head,0,sizeof(head)); memset(next,0,sizeof(next)); scanf("%d",&n); for(int i=1;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); } dfs(1); dfs2(1,1); scanf("%d",&m); for(int i=1;i<=m;i++) { scanf("%d%d%d",&opt,&x,&y); if(opt==1) { rotate1(x,y); } else if(opt==2) { rotate2(x,y); } else { printf("%d ",ask(x,y)); } } } }