数链剖分
给一棵树,树链剖分后,树上任意一条链上的节点都可以用O(logn)个连续的区间表示。
声明:
参考课件来自宋泽宇(我不认识的...),讲授来自Accelerator
定义
一个点的:size表示子树中包括该点本身的节点个数,重儿子表示儿子中size最大的点
一个点和它重儿子之间的边为重边, 除重边的其他边为轻边。
重边构成的链为重链(最上面的重儿子他爸也在重链上)。
性质
一个点只能在且一定在一条重链上(叶子相当于一个重链)
因为每个点只能有一个重儿子,所以两个重链不会相交
定义一个点的top为该点所在重链中深度最小的点。
步骤
首先,dfs根节点找重儿子
void dfs1(int x, int fa) {
a[x].deep = a[fa].deep + 1;
a[x].size = 1;
a[x].fa = fa;
for(int i = head[x]; i; i=e[i].next )
if(e[i].y != fa) {
dfs1(e[i].y , x);
a[x].size += a[e[i].y ].size ;
a[x].son = a[a[x].son ].size > a[e[i].y].size ? a[x].son : e[i].y ;
}
}
之后,再dfs一遍找到每个点的top
如果要用数据结构维护序列还需要处理出每个点在序列上的位置
void dfs2(int x, int tp) {
a[x].in = ++clock;
a[x].tp = tp;
//如果有重儿子要先dfs重儿子, 如果没有,就为叶子结点,自然也就不能dfs下去了
if(a[x].son == 0) return ;
dfs2(a[x].son, tp);
for(int i = head[x]; i; i = e[i].next )
if(e[i].y != a[x].son && e[i].y != a[x].fa ) //注意判断是不是重儿子(没必要dfs)和祖先(不能dfs)
dfs2(e[i].y , e[i].y );
a[x].out = clock;
}
(偷偷补一个课件觉得简单所以没写,而我又刚学的知识:
in,out为树链剖分序(一种dfs序),对于找 i 的子树,对 i 的子树求和,i 子树加这些操作,只需对这样的 j 修改/查询, 即满足in[i] <= in[j] <= out[i])
//可以自己手画一下看看,加深印象
对于一条链上的询问,每次让top深度大的点跳到这个点的top的父亲。(不然不就跳超过了嘛
并处理从这个点到它的top这段重链。直到两个点的top相同。
此时处理两点之间的这段重链。
证明:
关于链上询问做法的证明:
由于dfs时优先处理每个点的重儿子
故一条重链在序列上是连续的一段区间。
复杂度证明:
每次跳到top的父亲都会走一条轻边。会使子树大小至少扩大2倍
故有O(logn)段区间。
lca:
两个点top相同后深度小的点一定是原来两点的lca
树链剖分模板(仅供参考,手比较丑,勿怪):(我个傻子...居然需要对着家里迷迷糊糊写的代码查错.....
注意dfs2()中,在用线段树维护序列的情况下,处理每个点在序列上的位置
注意!!!
千万不要像博主一样,定义一个clock!!在iostream库里,clock是个关键字!!(在IDE中可能不会像gdb一样提醒你哦),小心CE! 改为_clock即可(我这懒得改了2333....)
#include<cstdio>
#include<algorithm>
using namespace std;
const int MAXN = 100000+99;
const int MAXM = MAXN<<1;
int n,m,s,p;
int cnt, head[MAXN];
struct node{
int tp, size, fa, son, deep, in, out;
//top, size,它爸, 重儿子 , 深度 , dfs序
}a[MAXN];
int clock;
struct seg{
int y,next;
}e[MAXM];
void add_edge(int x, int y) {
e[++cnt].y = y;
e[cnt].next = head[x];
head[x] = cnt;
}
void dfs1(int x, int fa) {
a[x].size = 1;
a[x].deep = a[fa].deep + 1;
a[x].fa = fa;
for(int i = head[x]; i; i = e[i].next )
if(e[i].y != fa) {
dfs1(e[i].y , x);
a[x].size += a[e[i].y ].size ;
a[x].son = a[a[x].son].size > a[e[i].y].size ? a[x].son : e[i].y ;
}
}
//-----------------------------------------------------------------
int arr[MAXN], pos[MAXN];
struct tree{
int sum, add;
}tr[MAXN<<2];//in为下标建立线段树
void dfs2(int x, int tp) {
a[x].in = ++clock;
a[x].tp = tp;
pos[clock] = arr[x];//维护每个点在序列上的位置
if(a[x].son) dfs2(a[x].son , tp);
for(int i = head[x]; i; i = e[i].next)
if(e[i].y != a[x].son && e[i].y != a[x].fa )
dfs2(e[i].y , e[i].y );
a[x].out = clock;
}
void pushup(int o) {tr[o].sum = (tr[o<<1].sum + tr[o<<1|1].sum)%p ;}
void build(int o, int l, int r) {
tr[o].add = 0;
if(l == r) {
tr[o].sum = pos[l]%p;
return ;
}
int mid = (l+r)>>1;
build(o<<1, l, mid);
build(o<<1|1, mid+1, r);
pushup(o);
}
void pushdown(int o, int l, int r) {
if(tr[o].add == 0) return ;
tr[o<<1].add = (tr[o<<1].add + tr[o].add)%p;
tr[o<<1|1].add = (tr[o<<1|1].add + tr[o].add)%p;
int mid = (l+r)>>1;
tr[o<<1].sum = (tr[o<<1].sum + tr[o].add*(mid-l+1) )%p;
tr[o<<1|1].sum = (tr[o<<1|1].sum + tr[o].add*(r-mid) )%p;
tr[o].add = 0;
}
void optadd(int o, int l, int r, int ql, int qr, int k) {
if(ql <= l && r <= qr) {
tr[o].add = (tr[o].add + k)%p;
tr[o].sum = (tr[o].sum + k*(r-l+1))%p;
return ;
}
pushdown(o, l, r);
int mid = (l+r)>>1;
if(ql <= mid) optadd(o<<1, l, mid, ql, qr, k);
if(mid < qr) optadd(o<<1|1, mid+1, r, ql, qr, k);
pushup(o);
}
int query(int o, int l, int r, int ql, int qr) {
if(ql <= l && r <= qr) return tr[o].sum;
pushdown(o, l, r);
int mid = (l+r)>>1, ans = 0;
if(ql <= mid) ans = (ans + query(o<<1, l, mid, ql, qr))%p;
if(mid < qr) ans = (ans + query(o<<1|1, mid+1, r, ql, qr))%p;
return ans;
}
void ttt_update(int x, int y, int k) {//意为“跳跳跳”
while(a[x].tp != a[y].tp) {
if(a[a[x].tp].deep < a[a[y].tp].deep) swap(x, y);
optadd(1, 1, clock, a[a[x].tp].in, a[x].in, k);
x = a[a[x].tp].fa;
}
if(a[x].deep > a[y].deep) swap(x, y);
optadd(1, 1, clock, a[x].in, a[y].in, k);
}
int ttt_query(int x, int y) {
int ans = 0;
while(a[x].tp != a[y].tp) {
if(a[a[x].tp].deep < a[a[y].tp].deep) swap(x, y);
ans = (ans + query(1, 1, clock, a[a[x].tp].in, a[x].in))%p;
x = a[a[x].tp].fa;
}
if(a[x].deep > a[y].deep) swap(x, y);
ans = (ans + query(1, 1, clock, a[x].in, a[y].in))%p;
return ans;
}
int main() {
scanf("%d%d%d%d",&n,&m,&s,&p);
int x,y;
for(int i = 1; i <= n; i++) scanf("%d",&arr[i]);
for(int i = 1; i < n; i++) {
scanf("%d%d",&x, &y);
add_edge(x, y);
add_edge(y, x);
}
dfs1(s, 0);
dfs2(s, s);
build(1, 1, clock);
// for(int i = 1; i <= n; i++) printf("%d : in = %d, out = %d
tr[%d] = %d
--------
",i, a[i].in , a[i].out , i, tr[i].sum );
int cmd, z;
for(int i = 1; i <= m; i++) {
scanf("%d", &cmd);
if(cmd == 1) {
scanf("%d%d%d",&x, &y, &z);
ttt_update(x, y, z);
} else if(cmd == 2) {
scanf("%d%d",&x, &y);
printf("%d
", ttt_query(x, y));
} else if(cmd == 3) {
scanf("%d%d",&x,&z);
optadd(1, 1, clock, a[x].in, a[x].out, z);
} else {
scanf("%d",&x);
printf("%d
",query(1, 1, clock, a[x].in, a[x].out));
}
}
}