传送阵:http://acm.hdu.edu.cn/showproblem.php?pid=4897
题目大意:一棵树,三个操作:1、将某条链取反,2、将与某条链相邻的边取反,3、查询某条链上为1的边数
树链剖分直接上
某条边是否被修改取决于这条边以及这条边的两个端点
对于第一个操作相当于修改边,第二个操作相当于修改点
对于修改边:将标记下放给儿子结点,直接修改即可,修改点也是直接修改,不过要另设一个标记记录
查询时将所有区间合并即可
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cstdlib> #define maxn 100010 using namespace std; inline int read(){ int s=0;char ch=getchar(); for(;ch<'0'||ch>'9';ch=getchar()); for(;ch>='0'&&ch<='9';ch=getchar())s=s*10+ch-'0'; return s; } int fa[maxn][18],dep[maxn],f[maxn]; int to[maxn<<1],Next[maxn<<1],tot,h[maxn]; int pos[maxn],sz,size[maxn]; struct node{ int lb,ld,rd,w,s,md,mb; void clear(){ lb=ld=rd=w=s=mb=md=0; } friend node operator + (node a,node b){ node ans; ans.s=a.s+b.s+1; ans.w=a.w+b.w+(a.rd^b.lb^b.ld); ans.ld=a.ld; ans.lb=a.lb; ans.rd=b.rd; return ans; } }t[maxn<<2]; int n,q; void mem(){ tot=0;sz=0; memset(h,0,sizeof(h)); memset(fa,0,sizeof(fa)); memset(dep,0,sizeof(dep)); memset(f,0,sizeof(f)); memset(pos,0,sizeof(pos)); memset(size,0,sizeof(size)); memset(t,0,sizeof(t)); } void add(int x,int y){ tot++;to[tot]=y;Next[tot]=h[x];h[x]=tot; } void dfs(int x){ for(int i=1;i<=17;++i) if(dep[x]<(1<<i))break; else fa[x][i]=fa[fa[x][i-1]][i-1]; for(int i=h[x];i;i=Next[i]){ int v=to[i]; if(dep[v])continue; fa[v][0]=x; dep[v]=dep[x]+1; dfs(v); size[x]+=size[v]; }size[x]++; } void dfs(int x,int ff){ pos[x]=++sz;f[x]=ff;int mx=0; for(int i=h[x];i;i=Next[i]) if(dep[to[i]]>dep[x]&&size[to[i]]>size[mx]) mx=to[i]; if(!mx)return; dfs(mx,ff); for(int i=h[x];i;i=Next[i]) if(dep[to[i]]>dep[x]&&mx!=to[i]) dfs(to[i],to[i]); } int lca(int x,int y){ if(dep[x]<dep[y])swap(x,y); int d=dep[x]-dep[y]; for(int i=0;i<=17;++i) if(d&(1<<i)) x=fa[x][i]; if(x==y)return x; for(int i=17;i>=0;--i) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; if(x==y)return x; return fa[x][0]; } #define L(k) (k<<1) #define R(k) (k<<1|1) void pushdown(int k){ if(t[k].md){ t[L(k)].md^=1;t[R(k)].md^=1; t[L(k)].ld^=1;t[R(k)].ld^=1; t[L(k)].rd^=1;t[R(k)].rd^=1; t[k].md=0; } if(t[k].mb){ t[L(k)].mb^=1;t[R(k)].mb^=1; t[L(k)].w=t[L(k)].s-t[L(k)].w; t[R(k)].w=t[R(k)].s-t[R(k)].w; t[L(k)].lb^=1;t[R(k)].lb^=1; t[k].mb=0; } } void update(int k){ node a=t[k]; t[k]=t[L(k)]+t[R(k)]; t[k].mb=a.mb; t[k].md=a.md; } void build(int k,int l,int r){ if(l==r)return; int mid=(l+r)>>1; build(L(k),l,mid); build(R(k),mid+1,r); update(k); } void change1(int k,int l,int r,int x,int y){ if(x<=l&&r<=y){ t[k].mb^=1; t[k].w=t[k].s-t[k].w; t[k].lb^=1; return; } pushdown(k); int mid=(l+r)>>1; if(x<=mid)change1(L(k),l,mid,x,y); if(y>mid)change1(R(k),mid+1,r,x,y); update(k); } void work1(int x,int y){ int k=lca(x,y); while(f[x]!=f[k]){ change1(1,1,n,pos[f[x]],pos[x]);x=fa[f[x]][0]; }if(x!=k)change1(1,1,n,pos[k]+1,pos[x]); while(f[y]!=f[k]){ change1(1,1,n,pos[f[y]],pos[y]);y=fa[f[y]][0]; }if(y!=k)change1(1,1,n,pos[k]+1,pos[y]); } void change2(int k,int l,int r,int x,int y){ if(x<=l&&r<=y){ t[k].md^=1; t[k].ld^=1; t[k].rd^=1; return; } pushdown(k); int mid=(l+r)>>1; if(x<=mid)change2(L(k),l,mid,x,y); if(y>mid)change2(R(k),mid+1,r,x,y); update(k); } void work2(int x,int y){ int k=lca(x,y); while(f[x]!=f[k]){ change2(1,1,n,pos[f[x]],pos[x]);x=fa[f[x]][0]; }if(x!=k)change2(1,1,n,pos[k]+1,pos[x]); while(f[y]!=f[k]){ change2(1,1,n,pos[f[y]],pos[y]);y=fa[f[y]][0]; }change2(1,1,n,pos[k],pos[y]); } node ask(int k,int l,int r,int x,int y){ if(x<=l&&r<=y)return t[k]; pushdown(k); int mid=(l+r)>>1; node ans;ans.clear();int flag=0; if(x<=mid)ans=ask(L(k),l,mid,x,y),flag=1; if(y>mid){ if(flag)ans=ans+ask(R(k),mid+1,r,x,y); else ans=ask(R(k),mid+1,r,x,y); } update(k); return ans; } int ask(int k,int l,int r,int x){ if(l==r)return t[k].ld; pushdown(k); int mid=(l+r)>>1; int ans; if(x<=mid)ans=ask(L(k),l,mid,x); else ans=ask(R(k),mid+1,r,x); update(k); return ans; } node work3(int k,int x){ node ans;ans.clear(); int flag=0; while(f[x]!=f[k]){ if(flag)ans=ask(1,1,n,pos[f[x]],pos[x])+ans; else ans=ask(1,1,n,pos[f[x]],pos[x]); flag=1;x=fa[f[x]][0]; } if(x!=k){ if(flag)ans=ask(1,1,n,pos[k]+1,pos[x])+ans; else ans=ask(1,1,n,pos[k]+1,pos[x]); } return ans; } void getans(int x,int y){ int k=lca(x,y); node L=work3(k,x); node R=work3(k,y); int tmp=ask(1,1,sz,pos[k]); int ans=L.w+R.w; if(x!=k)ans+=(L.lb^L.ld^tmp); if(y!=k)ans+=(R.lb^R.ld^tmp); printf("%d ",ans); } int main(){ int T;scanf("%d",&T); while(T--){ mem(); n=read(); for(int i=1;i<n;++i){ int a=read(),b=read(); add(a,b);add(b,a); } dep[1]=1;dfs(1); dfs(1,1); build(1,1,sz); q=read(); for(int i=1;i<=q;++i){ int opt=read(),a=read(),b=read(); if(opt==1)work1(a,b); if(opt==2)work2(a,b); if(opt==3)getans(a,b); } } return 0; }