@
题目链接:luogu P3384 【模板】树链剖分
先上完整代码,变量名解释[1]
#include<cstdio>
#include<algorithm>
#include<iostream>
using namespace std;
typedef long long ll;
#define N 500005
#define RI register int
int tot=0,n,m,rt,md;
int fa[ N ],deep[ N ],head[ N ],size[ N ],son[ N ],id[ N ],w[ N ],nw[ N ],top[ N ];
struct EDGE{
int to,next;
}e[ N ];
inline void add( int from , int to ){
e[ ++ tot ].to = to;
e[ tot ].next = head[ from ];
head[ from ] = tot;
}
template<class T>
inline void read(T &res){
static char ch;T flag = 1;
while( ( ch = getchar() ) < '0' || ch > '9' ) if( ch == '-' ) flag = -1;
res = ch - 48;
while( ( ch = getchar() ) >= '0' && ch <= '9' ) res = res * 10 + ch - 48;
res *= flag;
}
struct NODE{
ll sum,flag;
NODE *ls,*rs;
NODE(){
sum = flag = 0;
ls = rs = NULL;
}
inline void pushdown( int l , int r )
{
if( flag )
{
int midd = ( l + r ) >> 1;
ls->flag += flag;
rs->flag += flag;
ls->sum += flag * ( midd - l + 1 );
rs->sum += flag * ( r - midd );
flag = 0;
}
}
inline void update()
{
sum = ls->sum + rs->sum;
}
}tree[ N * 2 + 5 ],*p = tree,*root;
NODE *build( int l , int r )
{
NODE *nd = ++p;
if( l == r )
{
nd->sum = nw[ l ];
return nd;
}
int mid = ( l + r ) >> 1;
nd->ls = build( l , mid );
nd->rs = build( mid + 1 , r );
nd->update();
return nd;
}
ll sum( int l , int r , int x , int y , NODE *nd )
{
if( x <= l && r <= y )
{
return nd->sum;
}
nd->pushdown( l , r );
int mid = ( l + r ) >> 1;
ll res = 0;
if( x <= mid )
res += sum( l , mid , x , y , nd->ls );
if( y >= mid + 1 )
res += sum( mid + 1 , r , x , y , nd->rs );
return res;
}
void modify( int l , int r , int x , int y , ll add , NODE *nd )
{
if( x <= l && r <= y )
{
nd->sum += ( r - l + 1 ) * add;
nd->flag += add;
return;
}
int mid = ( l + r ) >> 1;
nd->pushdown( l , r );
if( x <= mid )
modify( l , mid , x , y , add , nd->ls );
if( y > mid )
modify( mid + 1 , r , x , y , add , nd->rs );
nd->update();
}
void dfs1( int p ){
size[ p ] = 1;
deep[ p ] = deep[ fa[ p ] ] + 1;
for( int i = head[ p ] ; i ; i = e[ i ].next ){
int k = e[ i ].to;
if( k == fa[ p ] )
continue;
fa[ k ] = p;
dfs1( k );
size[ p ] += size[ k ];
if( size[ son[ p ] ] < size[ k ] || !son[ p ] )
son[ p ] = k;
}
}
void dfs2( int p , int tp ){
id[ p ] = ++tot;
nw[ tot ] = w[ p ];
top[ p ] = tp;
if( son[ p ] )
dfs2( son[ p ] , tp );
for( int i = head[ p ] ; i ; i = e[ i ].next ){
int k = e[ i ].to;
if( k == fa[ p ] || k == son[ p ] )
continue;
dfs2( k , k );
}
}
inline void ope1( int x , int y , ll add ){
while( top[ x ] != top[ y ] ){
if( deep[ top[ x ] ] < deep[ top[ y ] ] )
swap( x , y );
modify( 1 , n , id[ top[ x ] ] , id[ x ] , add , root );
x = fa[ top[ x ] ];
}
if( deep[ x ] > deep[ y ] )
swap( x , y );
modify( 1 , n , id[ x ] , id[ y ] , add , root );
}
inline ll ope2( int x , int y ){
ll res = 0;
while( top[ x ] != top[ y ] ){
if( deep[ top[ x ] ] < deep[ top[ y ] ] )
swap( x , y );
res += sum( 1 , n , id[ top[ x ] ] , id[ x ] , root );
x = fa[ top[ x ] ];
}
if( deep[ x ] > deep[ y ] )
swap( x , y );
res += sum( 1 , n , id[ x ] , id[ y ] , root );
return res;
}
inline void ope3( int x , int add ){
modify( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , add , root );
}
inline ll ope4( int x ){
return sum( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , root );
}
int main()
{
cin>>n>>m>>rt>>md;
for( RI i = 1 ; i <= n ; i ++ )
read( w[ i ] );
for( RI i = 1 ; i <= n - 1 ; i ++ ){
int x,y;
read( x ),read( y );
add( x , y );
add( y , x );
}
dfs1( rt ),tot = 0;
dfs2( rt , rt );
root = build( 1 , n );
for( RI i = 1 ; i <= m ; i ++ ){
int f;
read( f );
switch( f ){
case 1:{
int x,y;
ll add;
read( x ),read( y ),read( add );
ope1( x , y , add );
break;
}
case 2:{
int x,y;
read( x ),read( y );
printf( "%lld\n" , ope2( x , y ) % md );
break;
}
case 3:{
int x;
ll add;
read( x ),read( add );
ope3( x , add );
break;
}
case 4:{
int x;
read( x );
printf( "%lld\n" , ope4( x ) % md );
break;
}
}
}
return 0;
}
前置知识
请先能够熟练写出线段树并了解\(dfs\)序的性质
预处理
预处理分两次\(dfs\)
第一次处理出各个结点的深度,\(size\),重儿子,父亲。
第二次处理出重链,\(dfs\)序和每个点的\(top\)。
dfs1:
void dfs1( int p ){
size[ p ] = 1;
deep[ p ] = deep[ fa[ p ] ] + 1;
for( int i = head[ p ] ; i ; i = e[ i ].next ){
int k = e[ i ].to;
if( k == fa[ p ] )
continue;
fa[ k ] = p;
dfs1( k );
size[ p ] += size[ k ];
if( size[ son[ p ] ] < size[ k ] || !son[ p ] )
son[ p ] = k;
}
}
dfs2:
void dfs2( int p , int tp ){
id[ p ] = ++tot;//每个点在dfs序里的位置
nw[ tot ] = w[ p ];
top[ p ] = tp;
if( son[ p ] )
dfs2( son[ p ] , tp );//重链
for( int i = head[ p ] ; i ; i = e[ i ].next ){
int k = e[ i ].to;
if( k == fa[ p ] || k == son[ p ] )
continue;
dfs2( k , k );//轻链
}
}
维护
为了更加高效的查询,我们选择用线段树来维护\(dfs\)序(树状数组等数据结构也可)。
没什么技术含量,直接套模板即可。
struct NODE{
ll sum,flag;
NODE *ls,*rs;
NODE(){
sum = flag = 0;
ls = rs = NULL;
}
inline void pushdown( int l , int r )
{
if( flag )
{
int midd = ( l + r ) >> 1;
ls->flag += flag;
rs->flag += flag;
ls->sum += flag * ( midd - l + 1 );
rs->sum += flag * ( r - midd );
flag = 0;
}
}
inline void update()
{
sum = ls->sum + rs->sum;
}
}tree[ N * 2 + 5 ],*p = tree,*root;
NODE *build( int l , int r )
{
NODE *nd = ++p;
if( l == r )
{
nd->sum = nw[ l ];
return nd;
}
int mid = ( l + r ) >> 1;
nd->ls = build( l , mid );
nd->rs = build( mid + 1 , r );
nd->update();
return nd;
}
ll sum( int l , int r , int x , int y , NODE *nd )
{
if( x <= l && r <= y )
{
return nd->sum;
}
nd->pushdown( l , r );
int mid = ( l + r ) >> 1;
ll res = 0;
if( x <= mid )
res += sum( l , mid , x , y , nd->ls );
if( y >= mid + 1 )
res += sum( mid + 1 , r , x , y , nd->rs );
return res;
}
void modify( int l , int r , int x , int y , ll add , NODE *nd )
{
if( x <= l && r <= y )
{
nd->sum += ( r - l + 1 ) * add;
nd->flag += add;
return;
}
int mid = ( l + r ) >> 1;
nd->pushdown( l , r );
if( x <= mid )
modify( l , mid , x , y , add , nd->ls );
if( y > mid )
modify( mid + 1 , r , x , y , add , nd->rs );
nd->update();
}
查询
这是核心操作(敲黑板)。
子树有关操作
子树查询
由于\(dfs\)序的性质,以一个点为根的子树在\(dfs\)序中一定是连续的,所以我们只需要进行一次区间查询,需要查询的区间为:
[根结点在\(dfs\)序中的位置,根结点在\(dfs\)序中的位置+\(size\) - 1 ]
复杂度为\(O(logn)\)
代码如下:
inline ll ope4( int x ){
return sum( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , root );
}
子树修改
同理,进行一次区间修改
复杂度为\(O(logn)\)
代码如下:
inline void ope3( int x , int add ){
modify( 1 , n , id[ x ] , id[ x ] + size[ x ] - 1 , add , root );
}
树链有关操作
这才是树剖的精髓所在啊!(战术后仰 )
这里主要会利用重链在\(dfs\)序中一定是连续的性质,一定要记住,否则你将无法理解接下来的操作
链查询
操作流程:
- 若两个点的top不同,则让top较深的点爬升到它的top的father,每次爬升进行一次区间查询[2],把结果加到res上,直到top相等为止
- 此时两点的top为原来两点的LCA,且其中深度较浅的点就是LCA,再进行一次区间查询即可。
最坏时间复杂度\(O(log_{2}n)\)
代码如下:
inline ll ope2( int x , int y ){
ll res = 0;
while( top[ x ] != top[ y ] ){
if( deep[ top[ x ] ] < deep[ top[ y ] ] )//把x调整为top深度更深的的点
swap( x , y );
res += sum( 1 , n , id[ top[ x ] ] , id[ x ] , root );
x = fa[ top[ x ] ];
}
if( deep[ x ] > deep[ y ] )
swap( x , y );
res += sum( 1 , n , id[ x ] , id[ y ] , root );
return res;
}
链修改
同理,爬升过程一模一样,只需要将链查询的区间查询改为区间修改即可。
最坏时间复杂度O(log2n)
代码如下:
inline void ope1( int x , int y , ll add ){
while( top[ x ] != top[ y ] ){
if( deep[ top[ x ] ] < deep[ top[ y ] ] )
swap( x , y );
modify( 1 , n , id[ top[ x ] ] , id[ x ] , add , root );
x = fa[ top[ x ] ];
}
if( deep[ x ] > deep[ y ] )
swap( x , y );
modify( 1 , n , id[ x ] , id[ y ] , add , root );
}