Description
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
Input
输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
Output
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
Sample Input
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
Sample Output
4
1
2
2
10
6
5
6
5
16
1
2
2
10
6
5
6
5
16
正解:树链剖分+线段树
解题报告:
维护树上一条路径上的结点权值最大值或和
没什么好说的,链剖裸题。先树链剖分再根据访问次序建立线段树,用线段树动态维护。
模板题练手。
1 //It is made by jump~ 2 #include <iostream> 3 #include <cstdlib> 4 #include <cstring> 5 #include <cstdio> 6 #include <cmath> 7 #include <algorithm> 8 using namespace std; 9 typedef long long LL; 10 const int MAXN = 30011; 11 const int inf = (1<<30); 12 int n; 13 int total,ecnt; 14 int U,VV; 15 int a[MAXN]; 16 int id[MAXN],pre[MAXN]; 17 int top[MAXN],siz[MAXN],zhongerzi[MAXN],father[MAXN],deep[MAXN]; 18 int next[MAXN*2],to[MAXN*2],first[MAXN]; 19 char ch[8]; 20 21 struct node{ 22 int l,r; 23 int _max;int _sum; 24 }jump[MAXN*4]; 25 26 void link(int x,int y){ next[++ecnt]=first[x]; first[x]=ecnt; to[ecnt]=y; } 27 28 int getint() 29 { 30 int w=0,q=0; 31 char c=getchar(); 32 while((c<'0' || c>'9') && c!='-') c=getchar(); 33 if (c=='-') q=1, c=getchar(); 34 while (c>='0' && c<='9') w=w*10+c-'0', c=getchar(); 35 return q ? -w : w; 36 } 37 38 void build(int root,int l,int r){ 39 jump[root].l=l;jump[root].r=r; 40 if(jump[root].l==jump[root].r) { 41 jump[root]._sum=jump[root]._max=a[ pre[l] ]; 42 return ; 43 } 44 int lc=root*2,rc=root*2+1; 45 int mid=l+(r-l)/2; 46 build(lc,l,mid); build(rc,mid+1,r); 47 jump[root]._sum=jump[lc]._sum+jump[rc]._sum; 48 jump[root]._max=max(jump[lc]._max,jump[rc]._max); 49 } 50 51 void dfs1(int u,int fa){ 52 siz[u]=1; 53 for(int i=first[u];i;i=next[i]) { 54 int v=to[i]; 55 if(v!=fa) { 56 father[v]=u; 57 deep[v]=deep[u]+1; 58 dfs1(v,u); 59 siz[u]+=siz[v]; 60 if(siz[v]>siz[ zhongerzi[u] ]) zhongerzi[u]=v; 61 } 62 } 63 } 64 65 void dfs2(int u,int fa){ 66 id[u]=++total; pre[total]=u; 67 if(zhongerzi[u]) top[zhongerzi[u]]=top[u],dfs2(zhongerzi[u],u); 68 for(int i=first[u];i;i=next[i]) { 69 int v=to[i]; 70 if(v==fa || v==zhongerzi[u]) continue; 71 top[v]=v; 72 dfs2(v,u); 73 } 74 } 75 76 int query_sum(int root,int x,int y){ 77 if(jump[root].l>=x && jump[root].r<=y) return jump[root]._sum; 78 int da=0; 79 int mid=jump[root].l+(jump[root].r-jump[root].l)/2; 80 int lc=root*2,rc=root*2+1; 81 if(x<=mid) da+=query_sum(lc,x,y); 82 if(y>mid) da+=query_sum(rc,x,y); 83 return da; 84 } 85 86 87 int query_max(int root,int x,int y){ 88 if(jump[root].l>=x && jump[root].r<=y) return jump[root]._max; 89 int da=-inf; 90 int mid=jump[root].l+(jump[root].r-jump[root].l)/2; 91 int lc=root*2,rc=root*2+1; 92 if(x<=mid) da=max(da,query_max(lc,x,y)); 93 if(y>mid) da=max(da,query_max(rc,x,y)); 94 return da; 95 } 96 97 int find_max(int x,int y){ 98 int f1=top[x],f2=top[y]; 99 int daan=-inf; 100 while(f1!=f2){ 101 if(deep[f1]<deep[f2]) swap(f1,f2),swap(x,y); 102 daan=max(daan,query_max(1,id[f1],id[x])); 103 x=father[f1]; 104 f1=top[x]; 105 } 106 if(deep[x]<deep[y]) swap(x,y); 107 daan=max(daan,query_max(1,id[y],id[x])); 108 return daan; 109 } 110 111 int find_sum(int x,int y){ 112 int f1=top[x],f2=top[y]; 113 int daan=0; 114 while(f1!=f2){ 115 if(deep[f1]<deep[f2]) swap(f1,f2),swap(x,y); 116 daan+=query_sum(1,id[f1],id[x]); 117 x=father[f1]; f1=top[x]; 118 } 119 if(deep[x]<deep[y]) swap(x,y); 120 daan+=query_sum(1,id[y],id[x]); 121 return daan; 122 } 123 124 void update(int root,int o,int add){ 125 if(jump[root].l==jump[root].r){ 126 jump[root]._sum+=add; 127 jump[root]._max+=add;return ; 128 } 129 int lc=root*2,rc=root*2+1; 130 int mid=jump[root].l+(jump[root].r-jump[root].l)/2; 131 if(o<=mid) update(lc,o,add); else update(rc,o,add); 132 jump[root]._sum=jump[lc]._sum+jump[rc]._sum; 133 jump[root]._max=max(jump[lc]._max,jump[rc]._max); 134 } 135 136 int main() 137 { 138 n=getint(); 139 int x,y; 140 for(int i=1;i<n;i++){ 141 x=getint();y=getint(); 142 next[++ecnt]=first[x]; first[x]=ecnt; to[ecnt]=y; 143 next[++ecnt]=first[y]; first[y]=ecnt; to[ecnt]=x; 144 } 145 146 deep[1]=1; dfs1(1,0); 147 top[1]=1; dfs2(1,0); 148 149 for(int i=1;i<=n;i++) a[i]=getint(); 150 build(1,1,n); 151 int Q=getint(); 152 153 for(int i=1;i<=Q;i++){ 154 scanf("%s",ch); 155 if(ch[1]=='M'){ 156 printf("%d ",find_max(x,y)); 157 } 158 else if(ch[1]=='S'){ 159 x=getint();y=getint(); 160 printf("%d ",find_sum(x,y)); 161 } 162 else{ 163 U=getint();VV=getint(); 164 update(1,id[U],VV-a[U]);a[U]=VV; 165 } 166 } 167 return 0; 168 }