http://poj.org/problem?id=3237 (题目链接)
树链剖分模板题,然而这150+行的程序我调了一天,历经艰辛,终于ac。。
题意
给出一个n个节点的带权树,要求维护操作:1.求出树上两点之间的边权的最大值;2.更改一条边上的权值;3.将树上两点之间的所有边权取各自的相反数。
solution
神奇的树链剖分+线段树维护查询和修改操作。
树链剖分时,我们将每条边的权值转换为除树根外每个节点上的权值(也就是对于每个节点与它父亲的边的权值转换到了自己的权值)。
之后就是标准的树链剖分后跑线段树了,那个全部取相反数的操作其实是一样的,树链剖分相关知识请见http://blog.sina.com.cn/s/blog_7a1746820100wp67.html
编程时请注意细节,邻接表写错了就悲剧了。。
data
#include <ctime> #include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; int main() { int i,j,k; freopen("aaa.in","r",stdin);freopen("aaa.in","w",stdout); puts("1"); srand((unsigned)time(NULL)); int n=rand()%1000+5; printf("%d ",n); for(i=1;i<n;i++) printf("%d %d %d ",i+1,rand()%i+1,rand()%14513546); for(i=1;i<n;i++) { int opt=rand()%3; if(opt==0) printf("CHANGE %d %d ",rand()%(n-1)+1,rand()%42534567); else if(opt==1) { int a=0,b=0; while(a==b)a=rand()%n+1,b=rand()%n+1; printf("NEGATE %d %d ",a,b); } else { int a=0,b=0; while(a==b)a=rand()%n+1,b=rand()%n+1; printf("QUERY %d %d ",a,b); } } puts("DONE"); fclose(stdin);fclose(stdout); return 0; }
代码
// poj3237 #include<algorithm> #include<iostream> #include<cstring> #include<cstdlib> #include<cstdio> #include<cmath> #define MOD 1000000007 #define inf 2147483640 #define LL long long #define free(a) freopen(a".in","r",stdin);freopen(a".out","w",stdout); using namespace std; inline int getint() { int x=0,f=1;char ch=getchar(); while (ch>'9' || ch<'0') {if (ch=='-') f=-1;ch=getchar();} while (ch>='0' && ch<='9') {x=x*10+ch-'0';ch=getchar();} return x*f; } const int maxn=100010; struct edge {int to,next,w;}e[maxn<<2]; struct tree {int l,r,tag,mn,mx;}tr[maxn<<2]; int pos[maxn],deep[maxn],head[maxn],bin[20],fa[maxn][20],size[maxn],to[maxn],bl[maxn]; int cnt,P,n; void insert(int u,int v,int w) { e[++cnt].to=v;e[cnt].next=head[u];head[u]=cnt;e[cnt].w=w; e[++cnt].to=u;e[cnt].next=head[v];head[v]=cnt;e[cnt].w=w; } void solve(int &x,int &y) { int t=x;x=-y;y=-t; } void update(int k) { tr[k].mn=min(tr[k<<1].mn,tr[k<<1|1].mn); tr[k].mx=max(tr[k<<1].mx,tr[k<<1|1].mx); } void pushdown(int k) { int l=tr[k].l,r=tr[k].r; if (l==r || !tr[k].tag) return; tr[k].tag=0; tr[k<<1].tag^=1,tr[k<<1|1].tag^=1; solve(tr[k<<1].mn,tr[k<<1].mx); solve(tr[k<<1|1].mn,tr[k<<1|1].mx); } void build(int k,int s,int t) { tr[k].l=s,tr[k].r=t,tr[k].tag=0,tr[k].mn=inf,tr[k].mx=-inf; if (s==t) return; int mid=(s+t)>>1; build(k<<1,s,mid); build(k<<1|1,mid+1,t); } void change(int k,int x,int val) { pushdown(k); int l=tr[k].l,r=tr[k].r,mid=(l+r)>>1; if (l==r) {tr[k].mn=tr[k].mx=val;return;} if (x<=mid) change(k<<1,x,val); else change(k<<1|1,x,val); update(k); } void rever(int k,int x,int y) { pushdown(k); int l=tr[k].l,r=tr[k].r,mid=(l+r)>>1; if (l==x && r==y) {solve(tr[k].mn,tr[k].mx);tr[k].tag=1;return;} if (y<=mid) rever(k<<1,x,y); else if (x>mid) rever(k<<1|1,x,y); else rever(k<<1,x,mid),rever(k<<1|1,mid+1,y); update(k); } int query(int k,int x,int y) { pushdown(k); int l=tr[k].l,r=tr[k].r,mid=(l+r)>>1; if (l==x && r==y) return tr[k].mx; if (y<=mid) return query(k<<1,x,y); else if (x>mid) return query(k<<1|1,x,y); else return max(query(k<<1,x,mid),query(k<<1|1,mid+1,y)); } void dfs1(int x) { size[x]=1; for (int i=1;i<=13;i++) { if (bin[i]<=deep[x]) fa[x][i]=fa[fa[x][i-1]][i-1]; else break; } for (int i=head[x];i;i=e[i].next) if (e[i].to!=fa[x][0]) { deep[e[i].to]=deep[x]+1; fa[e[i].to][0]=x; dfs1(e[i].to); size[x]+=size[e[i].to]; } } void dfs2(int x,int chain) { /* if (x==22) { ++P; --P; } */ bl[x]=chain; pos[x]=++P; int k=0; for (int i=head[x];i;i=e[i].next) { if (e[i].to!=fa[x][0]) { if (size[e[i].to]>size[k]) k=e[i].to; } else { to[i>>1]=pos[x];//记录每个节点在线段树上的标号 change(1,pos[x],e[i].w);//将权值插入线段树 } } if (!k) return; dfs2(k,chain); for (int i=head[x];i;i=e[i].next) if (e[i].to!=fa[x][0] && e[i].to!=k) dfs2(e[i].to,e[i].to); } int lca(int x,int y) { if (deep[x]<deep[y]) swap(x,y); int t=deep[x]-deep[y]; for (int i=0;i<=13;i++) if (t&bin[i]) x=fa[x][i]; for (int i=13;i>=0;i--) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; if (x==y) return x; return fa[x][0]; } int solvequery(int x,int f) { int mx=-inf; while (bl[x]!=bl[f]) { mx=max(mx,query(1,pos[bl[x]],pos[x])); x=fa[bl[x]][0]; } if (pos[f]+1<=pos[x]) mx=max(mx,query(1,pos[f]+1,pos[x])); return mx; } void solverever(int x,int f) { while (bl[x]!=bl[f]) { rever(1,pos[bl[x]],pos[x]); x=fa[bl[x]][0]; } if (pos[f]+1<=pos[x]) rever(1,pos[f]+1,pos[x]); } int main() { free("aaa"); int T=getint(); bin[0]=1;for (int i=1;i<15;i++) bin[i]=bin[i-1]<<1; while (T--) { P=0,cnt=1;//便于将边权转为点权 memset(head,0,sizeof(head)); memset(deep,0,sizeof(deep)); memset(fa,0,sizeof(fa)); n=getint(); for (int i=1;i<n;i++) { int u=getint(),v=getint(),w=getint(); insert(u,v,w); } build(1,1,n); dfs1(1); dfs2(1,1); char ch[10]; while (scanf("%s",ch+1)) { if (ch[1]=='D') break; int x=getint(),y=getint(); if (ch[1]=='Q') { int f=lca(x,y); printf("%d ",max(solvequery(x,f),solvequery(y,f))); } if (ch[1]=='C') change(1,to[x],y); if (ch[1]=='N') { int f=lca(x,y); solverever(x,f);solverever(y,f); } } } fclose(stdin);fclose(stdout); return 0; }