2019ICPC上海站 F-A Simple Problem On A Tree
题意
给定一颗(n)个结点的树,每个点都有一个点权(a_i),有(Q)次询问,询问有四种:
- (1~u~v~w),将(u)到(v)的路径上的点的点权赋值为(w)。
- (2~u~v~w),将(u)到(v)的路径上的点的点权加上(w)。
- (3~u~v~w),将(u)到(v)的路径上的点的点权乘上(w)。
- (4~u~v),询问(u)到(v)的路径上的点的点权立方和。
分析
((x+w)^3=x^3+3x^2w+3xw^2+w^3,(x+w)^2=x^2+2xw+w^2),这样展开一下,我们就发现只需要在线段树上维护区间立方和(x^3),区间平方和(x^2),区间和(x),对于修改操作需要三个lazy标记,分别记录赋值、加权和、乘积,下推标记的时候要注意优先级,对于赋值操作,要先将另两个标记初始化,再打上赋值标记,对于乘积操作,要把加权和标记也对应的乘上(w)。由于是树上操作,再套个树链剖分就好了。
Code
#include<bits/stdc++.h>
#define rep(i,x,n) for(int i=x;i<=n;i++)
#define per(i,n,x) for(int i=n;i>=x;i--)
#define sz(a) int(a.size())
#define rson mid+1,r,p<<1|1
#define pii pair<int,int>
#define lson l,mid,p<<1
#define ll long long
#define pb push_back
#define mp make_pair
#define se second
#define fi first
using namespace std;
const double eps=1e-8;
const int mod=1e9+7;
const int N=1e5+10;
const int inf=1e9;
int T,Q,n;
vector<int>g[N];
int sz[N],f[N],d[N],top[N],son[N],p[N],id[N],tot;
ll a[N],sum[N<<2][3],tag[N<<2][3];
ll cal(ll x,int i){
ll y=1;
for(int j=0;j<i;j++) y=y*x%mod;
return y;
}
void pp(int p){
for(int i=0;i<3;i++) sum[p][i]=(sum[p<<1][i]+sum[p<<1|1][i])%mod;
}
void bd(int l,int r,int p){
tag[p][0]=1,tag[p][1]=0,tag[p][2]=-1;
if(l==r){
for(int i=0;i<3;i++) sum[p][i]=cal(a[id[l]],i+1);
return;
}
int mid=l+r>>1;
bd(lson);bd(rson);
pp(p);
}
void change(int l,int r,int p,ll k1,ll k2,ll k3){
if(k3!=-1){
for(int i=0;i<3;i++) sum[p][i]=(r-l+1)*cal(k3,i+1)%mod;
tag[p][0]=1,tag[p][1]=0,tag[p][2]=k3;
}
for(int i=0;i<3;i++) sum[p][i]=sum[p][i]*cal(k1,i+1)%mod;
sum[p][2]=(sum[p][2]+3*sum[p][1]*k2%mod+3*sum[p][0]*cal(k2,2)%mod+(r-l+1)*cal(k2,3)%mod)%mod;
sum[p][1]=(sum[p][1]+2*sum[p][0]*k2%mod+(r-l+1)*cal(k2,2)%mod)%mod;
sum[p][0]=(sum[p][0]+(r-l+1)*k2%mod)%mod;
tag[p][0]=tag[p][0]*k1%mod;
tag[p][1]=(tag[p][1]*k1%mod+k2%mod)%mod;
}
void up(int dl,int dr,int l,int r,int p,ll k1,ll k2,ll k3){
if(l==dl&&r==dr){
change(l,r,p,k1,k2,k3);
return;
}
int mid=l+r>>1;
if(tag[p][0]!=1||tag[p][1]!=0||tag[p][2]!=-1){
change(lson,tag[p][0],tag[p][1],tag[p][2]);
change(rson,tag[p][0],tag[p][1],tag[p][2]);
tag[p][0]=1,tag[p][1]=0,tag[p][2]=-1;
}
if(dr<=mid) up(dl,dr,lson,k1,k2,k3);
else if(dl>mid) up(dl,dr,rson,k1,k2,k3);
else up(dl,mid,lson,k1,k2,k3),up(mid+1,dr,rson,k1,k2,k3);
pp(p);
}
ll qy(int dl,int dr,int l,int r,int p){
if(l==dl&&r==dr) return sum[p][2];
int mid=l+r>>1;
if(tag[p][0]!=1||tag[p][1]!=0||tag[p][2]!=-1){
change(lson,tag[p][0],tag[p][1],tag[p][2]);
change(rson,tag[p][0],tag[p][1],tag[p][2]);
tag[p][0]=1,tag[p][1]=0,tag[p][2]=-1;
}
if(dr<=mid) return qy(dl,dr,lson);
else if(dl>mid) return qy(dl,dr,rson);
else return (qy(dl,mid,lson)+qy(mid+1,dr,rson))%mod;
}
void dfs(int u){
sz[u]=1;d[u]=d[f[u]]+1;
for(int x:g[u]){
if(x==f[u]) continue;
f[x]=u;
dfs(x);
sz[u]+=sz[x];
if(sz[x]>sz[son[u]]) son[u]=x;
}
}
void dfs1(int u,int t){
top[u]=t;p[u]=++tot;id[tot]=u;
if(son[u]) dfs1(son[u],t);
for(int x:g[u]){
if(x==son[u]||x==f[u]) continue;
dfs1(x,x);
}
}
void modify(int x,int y,ll k1,ll k2,ll k3){
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
up(p[top[x]],p[x],1,n,1,k1,k2,k3);
x=f[top[x]];
}
if(d[x]<d[y]) swap(x,y);
up(p[y],p[x],1,n,1,k1,k2,k3);
}
ll solve(int x,int y){
ll ret=0;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
ret=(ret+qy(p[top[x]],p[x],1,n,1))%mod;
x=f[top[x]];
}
if(d[x]<d[y]) swap(x,y);
ret=(ret+qy(p[y],p[x],1,n,1))%mod;
return ret;
}
int main(){
scanf("%d",&T);
for(int cas=1;cas<=T;cas++){
tot=0;
scanf("%d",&n);
for(int i=2,x,y;i<=n;i++){
scanf("%d%d",&x,&y);
g[x].pb(y);g[y].pb(x);
}
for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
dfs(1);dfs1(1,1);
bd(1,n,1);
scanf("%d",&Q);
printf("Case #%d:
",cas);
while(Q--){
int op,u,v,w;
scanf("%d%d%d",&op,&u,&v);
if(op==1){
scanf("%d",&w);
modify(u,v,1,0,w);
}else if(op==2){
scanf("%d",&w);
modify(u,v,1,w,-1);
}else if(op==3){
scanf("%d",&w);
modify(u,v,w,0,-1);
}else{
printf("%lld
",solve(u,v));
}
}
for(int i=1;i<=n;i++){
g[i].clear();
son[i]=0;
}
}
return 0;
}