题目
题目链接:https://www.luogu.com.cn/problem/P7735
小 W 有一棵 (n) 个结点的树,树上的每一条边可能是轻边或者重边。接下来你需要对树进行 (m) 次操作,在所有操作开始前,树上所有边都是轻边。操作有以下两种:
- 给定两个点 (a) 和 (b),首先对于 (a) 到 (b) 路径上的所有点 (x)(包含 (a) 和 (b)),你要将与 (x) 相连的所有边变为轻边。然后再将 (a) 到 (b) 路径上包含的所有边变为重边。
- 给定两个点 (a) 和 (b),你需要计算当前 (a) 到 (b) 的路径上一共包含多少条重边。
(n,mleq 10^5)。
思路
修改操作的话可以先把所有与这条链相连的边都覆盖为 (0),然后再把这一条链上的点覆盖为 (1)。
把边的颜色扔到点上,重链剖分,然后考虑一次修改操作中其中一条重链 ((x,y)) 应该如何处理。
记 (z) 是 (x) 的重儿子,那么需要覆盖为 (0) 的部分有:
- ((z,y)) 这条重链的所有点。
- ((x,y)) 这条重链的所有点的轻儿子。
- (x) 的重儿子。
所以考虑把重链和轻儿子的贡献分开算。维护两棵线段树,第一棵就是重链剖分常规的线段树,把每一条重链放在一个区间中,第二棵线段树则是把每一个点的所有轻儿子放到一个区间中,且每一条重链的轻儿子也要在区间中。这样重新编号应该不难实现。
剩余的部分就很裸了。时间复杂度 (O(nlog^2 n))。
但是这样写常数很大,oisdoaiu 大爷给我讲了一种常数似乎小很多的做法。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=100010;
int Q,n,m,tot,head[N],id1[N],id2[N],id3[N],siz[N],fa[N],son[N],cnt[N],dep[N],top[N];
int read()
{
int d=0; char ch=getchar();
while (!isdigit(ch)) ch=getchar();
while (isdigit(ch)) d=(d<<3)+(d<<1)+ch-48,ch=getchar();
return d;
}
struct edge
{
int next,to;
}e[N*2];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
void dfs1(int x,int fat)
{
fa[x]=fat; dep[x]=dep[fat]+1; siz[x]=1; son[x]=0;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fat)
{
dfs1(v,x); siz[x]+=siz[v];
if (siz[v]>siz[son[x]]) son[x]=v;
}
}
}
void dfs2(int x,int tp)
{
top[x]=tp; id1[x]=++tot;
if (son[x]) dfs2(son[x],tp);
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa[x] && v!=son[x]) dfs2(v,v);
}
}
void dfs3(int x)
{
id3[x]=tot+1; cnt[x]=0;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=son[x] && v!=fa[x])
id2[v]=++tot,cnt[x]++;
}
if (son[x]) dfs3(son[x]);
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=son[x] && v!=fa[x]) dfs3(v);
}
}
struct SegTree
{
int sum[N*4],lazy[N*4];
void pushdown(int x,int l,int r)
{
if (lazy[x]!=-1)
{
int mid=(l+r)>>1;
sum[x*2]=lazy[x]*(mid-l+1); lazy[x*2]=lazy[x];
sum[x*2+1]=lazy[x]*(r-mid); lazy[x*2+1]=lazy[x];
lazy[x]=-1;
}
}
void update(int x,int l,int r,int ql,int qr,int v)
{
if (ql>qr) return;
if (ql<=l && qr>=r)
return (void)(sum[x]=v*(r-l+1),lazy[x]=v);
pushdown(x,l,r);
int mid=(l+r)>>1;
if (ql<=mid) update(x*2,l,mid,ql,qr,v);
if (qr>mid) update(x*2+1,mid+1,r,ql,qr,v);
sum[x]=sum[x*2]+sum[x*2+1];
}
int query(int x,int l,int r,int ql,int qr)
{
if (ql>qr) return 0;
if (ql<=l && qr>=r) return sum[x];
pushdown(x,l,r);
int mid=(l+r)>>1,res=0;
if (ql<=mid) res+=query(x*2,l,mid,ql,qr);
if (qr>mid) res+=query(x*2+1,mid+1,r,ql,qr);
return res;
}
}seg1,seg2;
void clear(int x,int y)
{
for (;top[x]!=top[y];x=fa[top[x]])
{
if (dep[top[x]]<dep[top[y]]) swap(x,y);
seg1.update(1,1,n,id1[top[x]]+1,id1[x],0);
seg2.update(1,1,n,id3[top[x]],id3[x]+cnt[x]-1,0);
if (son[x]) seg1.update(1,1,n,id1[son[x]],id1[son[x]],0);
}
if (dep[x]<dep[y]) swap(x,y);
seg1.update(1,1,n,id1[y]+1,id1[x],0);
seg2.update(1,1,n,id3[y],id3[x]+cnt[x]-1,0);
if (son[x]) seg1.update(1,1,n,id1[son[x]],id1[son[x]],0);
if (y!=1 && son[fa[y]]==y) seg1.update(1,1,n,id1[y],id1[y],0);
if (y!=1 && son[fa[y]]!=y) seg2.update(1,1,n,id2[y],id2[y],0);
}
void update(int x,int y)
{
for (;top[x]!=top[y];x=fa[top[x]])
{
if (dep[top[x]]<dep[top[y]]) swap(x,y);
seg1.update(1,1,n,id1[top[x]]+1,id1[x],1);
seg2.update(1,1,n,id2[top[x]],id2[top[x]],1);
}
if (dep[x]<dep[y]) swap(x,y);
seg1.update(1,1,n,id1[y]+1,id1[x],1);
}
void query(int x,int y)
{
int ans=0;
for (;top[x]!=top[y];x=fa[top[x]])
{
if (dep[top[x]]<dep[top[y]]) swap(x,y);
ans+=seg1.query(1,1,n,id1[top[x]]+1,id1[x]);
ans+=seg2.query(1,1,n,id2[top[x]],id2[top[x]]);
}
if (dep[x]<dep[y]) swap(x,y);
ans+=seg1.query(1,1,n,id1[y]+1,id1[x]);
cout<<ans<<"
";
}
void prework()
{
memset(head,-1,sizeof(head));
memset(seg1.lazy,-1,sizeof(seg1.lazy));
memset(seg2.lazy,-1,sizeof(seg2.lazy));
memset(seg1.sum,0,sizeof(seg1.sum));
memset(seg2.sum,0,sizeof(seg2.sum));
tot=0;
}
int main()
{
Q=read();
while (Q--)
{
prework();
n=read(); m=read();
for (int i=1,x,y;i<n;i++)
{
x=read(); y=read();
add(x,y); add(y,x);
}
tot=0; dfs1(1,0); dfs2(1,1);
tot=0; dfs3(1);
while (m--)
{
int opt=read(),x=read(),y=read();
if (opt==1) clear(x,y),update(x,y);
if (opt==2) query(x,y);
}
}
return 0;
}