前置知识:树状数组 差分 树链剖分 LCA
对树上路径经过的点进行操作,实际上是对区间维护一个函数。
开三个树状数组维护函数的三个系数。
都是基本操作,具体看代码注释。
题外话:
上次写树剖还是两年前(? 这几天重新又学了遍 树状数组学习博客 我的树剖板子们
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; #define N 100002 #define ll long long int n,m; int sumedge,cnt; int head[N]; int siz[N],dad[N],top[N],son[N],deep[N],tpos[N]; ll tr1[N],tr2[N],tr3[N]; struct Edge { int x,y,nxt; Edge(int x=0,int y=0,int nxt=0):x(x),y(y),nxt(nxt){} }edge[N<<1]; void add(int x,int y) { edge[++sumedge]=Edge(x,y,head[x]); head[x]=sumedge; } void dfs(int x) { siz[x]=1;deep[x]=deep[dad[x]]+1; for(int i=head[x];i;i=edge[i].nxt) { int v=edge[i].y; if(v==dad[x]) continue; dad[v]=x; dfs(v); siz[x]+=siz[v]; } } void dfs_(int x) { int s=0;tpos[x]=++cnt; if(!top[x]) top[x]=x; for(int i=head[x];i;i=edge[i].nxt) { int v=edge[i].y; if(v!=dad[x]&&siz[v]>siz[s]) s=v; } if(s) { top[s]=top[x]; dfs_(s); } for(int i=head[x];i;i=edge[i].nxt) { int v=edge[i].y; if(v!=dad[x]&&v!=s) dfs_(v); } } int LCA(int x,int y) { for(;top[x]!=top[y];) { if(deep[top[x]]>deep[top[y]]) swap(x,y); y=dad[top[y]]; } if(deep[x]>deep[y]) swap(x,y); return x; } int lowbit(int x) { return x&(-x); } void add_tree(ll d[],int x,ll v) { for(int i=x;i<=n;i+=lowbit(i)) { d[i]+=v; } } ll get_sum(ll d[],int x) { ll res=0; for(int i=x;i>=1;i-=lowbit(i)) { res+=d[i]; } return res; } void change(ll d[],int stp,int edp,ll v) { add_tree(d,stp,v); add_tree(d,edp+1,-v); } void update(int x,int y,int len) { int p1=1,p2=len; //两个端点 假设区间[1,2,3,4,5]需要加[1^2,2^2,3^2,4^2,5^2],则p1=1,p2=5; for(;top[x]!=top[y];) { if(deep[top[x]]>deep[top[y]]) //跳x所在的链 { int ed=tpos[x]; int st=tpos[top[x]]; /*对区间[st~ed]操作,对于区间中下标为i的,加上(ed-i+p1)^2=i^2-2*(ed+p1)*i+(ed+p1)^2;*/ change(tr1,st,ed,1); //二次项系数 change(tr2,st,ed,-1ll*(ed+p1));//一次项 change(tr3,st,ed,1ll*(ed+p1)*(ed+p1));//常数项 p1=p1+ed-st+1; //下一个区间开始加的平方数 x=dad[top[x]]; }else{ //跳y所在的链 int ed=tpos[y]; int st=tpos[top[y]]; /*对区间[st~ed]操作,令gg=p2-(ed-st); 对于区间中下标为i的,加上(i-st+gg)^2=i^2+2*(gg-st)+(gg-st)^2;*/ int gg=p2-(ed-st); change(tr1,st,ed,1);//二次项 change(tr2,st,ed,(gg-st));//一次项 change(tr3,st,ed,1ll*(gg-st)*(gg-st));//常数项 y=dad[top[y]]; p2=gg-1; } } if(deep[x]<=deep[y]) //要从x跳到y { int st=tpos[x]; //对区间[st~ed]操作 int ed=tpos[y]; /*对区间[st~ed]操作,对于区间中下标为i的,加上(i-st+p1)^2=i^2+2*(p1-st)*i+(p1-st)*(p1-st);*/ change(tr1,st,ed,1);//二次项 change(tr2,st,ed,p1-st);//一次项 change(tr3,st,ed,1ll*(p1-st)*(p1-st)); //常数项 }else{ int st=tpos[y]; int ed=tpos[x]; /*对区间[st~ed]操作,对于区间中下标为i的,加上(ed-i+p1)^2=i^2-2*(ed+p1)*i+(ed+p1)*(ed+p1);*/ change(tr1,st,ed,1); //二次项 change(tr2,st,ed,-1ll*(ed+p1));//一次项 change(tr3,st,ed,1ll*(ed+p1)*(ed+p1));//常数项 } } int main() { scanf("%d",&n); for(int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); add(x,y); add(y,x); } dfs(1); dfs_(1); scanf("%d",&m); for(int i=1;i<=m;i++) { int od; scanf("%d",&od); if(od==1) { int x,y,lca; scanf("%d%d",&x,&y); lca=LCA(x,y); int len=deep[x]+deep[y]-2*deep[lca]+1; //从x跳到y之间的点的数目 update(x,y,len); }else { int x; scanf("%d",&x); x=tpos[x]; ll aa=get_sum(tr1,x); //二次项系数 ll bb=get_sum(tr2,x); //一次项 ll cc=get_sum(tr3,x); //常数项 ll ans=aa*x*x+2*bb*x+cc; printf("%lld ",ans); } } return 0; }