题目大意
有一棵 $n$($n\leq 1.5\times 10^5$)个节点的二叉树,有点权 $x$,边权 $w$;$q$($q\leq 2\times 10^5$)组询问,每组询问给出 $u,l,r$,求点权在 $[l,r]$ 内的点到点 $u$ 的距离之和,强制在线。
题解
边分治:
边分树的每个点记一个数组,记录子树中每个点到重心边的端点的距离,按点的点权排序;查询时直接在点 $u$ 在边分树上的每个祖先的兄弟的那个数组中二分。复杂度 $\Theta(n\log^2 n)$。
树剖+可持久化线段树:
发现点 $x$ 到点 $u$ 的距离等于 $dep(u)+dep(x)-2\cdot dep(\operatorname{lca}(u,x))$,那么就可以对每个符合要求的点,把它的每个祖先都加上这个祖先到它父亲的边权。这样所有符合要求的点到点 $u$ 的距离之和为 $S+dep(u)\cdot c-2T$,其中 $S$ 是符合要求的点的深度之和,$c$ 是符合要求的点的个数,$T$ 是 $u$ 和它的祖先上一共加了多少。
这个用可持久化线段树维护:将所有点按点权排序,依次加入;查询时用 $r$ 版本减去 $l-1$ 版本。复杂度 $\Theta(n\log^2 n)$。
代码
边分治
#include<algorithm>
#include<cmath>
#include<complex>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<iomanip>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
// Loop helpers: rep counts i up from x to y, dwn counts i down from x to y (both inclusive).
// `register` is omitted: C++17 removed that storage class and clang rejects it.
#define rep(i,x,y) for(int i=(x);i<=(y);++i)
#define dwn(i,x,y) for(int i=(x);i>=(y);--i)
// Capacity: vertices, and directed edges (each undirected edge is stored twice).
#define maxn 150010
#define maxm (maxn<<1)
// Iterate the adjacency list of u; k is the edge index.
#define view(u,k) for(int k=fir[u];k!=-1;k=nxt[k])
#define LL long long
// Children of BST node u and the midpoint of [l,r]; these rely on locals named u, l, r.
#define ls son[u][0]
#define rs son[u][1]
#define mi ((l+r)>>1)
using namespace std;
// Read one (optionally negative) decimal int from stdin, skipping any leading
// characters that are neither digits nor '-'.
// Returns 0 when end of input is reached before any digit, instead of looping forever.
int read()
{
    // Keep getchar()'s int result so EOF is distinguishable and isdigit never sees a negative char.
    int x=0,f=1,ch=getchar();
    while(ch!=EOF&&!isdigit(ch)&&ch!='-')ch=getchar();
    if(ch=='-')f=-1,ch=getchar();
    while(ch!=EOF&&isdigit(ch))x=x*10+(ch-'0'),ch=getchar();
    return x*f;
}
// Print a signed 64-bit integer in decimal followed by a newline.
// The magnitude is taken in unsigned arithmetic so LLONG_MIN does not overflow on negation.
void write(long long x)
{
    unsigned long long mag=x<0?0ULL-static_cast<unsigned long long>(x):static_cast<unsigned long long>(x);
    if(x<0)putchar('-');
    char buf[20];
    int len=0;
    // do/while so that 0 still emits one digit.
    do{buf[len++]=static_cast<char>('0'+mag%10);mag/=10;}while(mag);
    while(len)putchar(buf[--len]);
    putchar('\n');
}
// Entry of the per-component list filled by getsiz: x = distance to the component root, k = vertex weight.
// NOTE(review): x (and num below) are int while dep is LL; distances above INT_MAX would be truncated — confirm input bounds.
struct node{int x,k;}s[maxn];
// ag - vertex weights; vis - nonzero once an edge is cut (holds the edge-divide node id);
// siz - subtree sizes inside the component being scanned; fir/v/nxt - adjacency list (edges k and k^1 are a pair);
// tp - number of entries in s; cnt - next edge index; wt/mnsz - best cut edge and the size of its larger side;
// nowsiz - size of the component being divided; st - sparse table of vertices over the Euler tour.
int n,q,ag[maxn],vis[maxm],siz[maxn],fir[maxn],v[maxm],nxt[maxm],tp,cnt,wt,mnsz,nowsiz,st[20][maxm];
// lg - floor(log2) table; dfn - first Euler-tour position; rt/to - per edge-divide node: BST root and
// the vertex its distances are measured from; nd - BST node counter; ql/qr - decoded query range;
// son/num/trsiz - BST children, the node's own distance, subtree element count; tim - Euler-tour length;
// bac - edge-divide leaf id of each vertex; tmp - element count accumulated by add().
int lg[maxm],dfn[maxn],rt[maxn<<3],nd,ql,qr,son[maxn<<5][2],num[maxn<<5],to[maxn<<3],tim,bac[maxn],trsiz[maxn<<5],tmp;
// w - edge weights; dep - distance to the current component root; sum - BST subtree distance sums;
// key - the node's own weight; dis - distance from vertex 1; A - modulus for decoding queries; ans - last answer.
LL w[maxm],dep[maxn],sum[maxn<<5],key[maxn<<5],dis[maxn],A,ans;
// Order list entries by increasing vertex weight.
bool cmp(node x,node y)
{
    return y.k>x.k;
}
void ade(int u1,int v1,int w1){v[cnt]=v1,w[cnt]=w1,nxt[cnt]=fir[u1],fir[u1]=cnt++;}
int LCA(int x,int y)
{
x=dfn[x],y=dfn[y];
if(x>y)swap(x,y);
int len=y-x+1;
return dis[st[lg[len]][x]]<dis[st[lg[len]][y-(1<<lg[len])+1]]?st[lg[len]][x]:st[lg[len]][y-(1<<lg[len])+1];
}
// Weighted tree distance between x and y: both depths minus the LCA's depth twice.
LL dist(int x,int y)
{
    const LL up=dis[LCA(x,y)];
    return (dis[x]-up)+(dis[y]-up);
}
void getsiz(int u,int fa)
{
siz[u]=1;s[++tp].k=ag[u],s[tp].x=dep[u];
view(u,k)if(!vis[k]&&!vis[k^1]&&v[k]!=fa)
{
dep[v[k]]=dep[u]+w[k],getsiz(v[k],u),siz[u]+=siz[v[k]];
if(max(siz[v[k]],nowsiz-siz[v[k]])<mnsz)wt=k,mnsz=max(siz[v[k]],nowsiz-siz[v[k]]);;
}
}
int build(int l,int r)
{
int u=++nd;
if(l==r){sum[u]=s[l].x,key[u]=s[l].k,num[u]=s[l].x,trsiz[u]=1;return u;}
if(l<=mi-1)ls=build(l,mi-1);
if(mi+1<=r)rs=build(mi+1,r);
sum[u]=sum[ls]+sum[rs]+s[mi].x,key[u]=s[mi].k,num[u]=s[mi].x,trsiz[u]=trsiz[ls]+trsiz[rs]+1;
return u;
}
void add(int u,int lim,LL f)
{
if(!u)return;
if(key[u]>lim)return add(ls,lim,f);
if(ls)ans+=f*sum[ls],tmp+=f*trsiz[ls];ans+=f*num[u],tmp+=f;return add(rs,lim,f);
}
// Recursively edge-divide the component containing u (sumsiz vertices; edges with vis!=0 are already cut).
// Edge-divide nodes are numbered heap-style: the children of tr are tr<<1 and tr<<1|1.
// The side that contains u (the DFS root) goes to tr<<1.
// Every non-root node tr records to[tr]=u, plus in rt[tr] a BST (see build) over the component's vertices.
// The BST is keyed by weight and stores each vertex's distance to u.
// A single-vertex component becomes a leaf: bac[u]=tr and a one-node BST at distance 0.
// NOTE(review): tr doubles at every level while rt/to are sized maxn<<3; confirm the edge-divide
// depth for the input never pushes tr past that bound.
void getwt(int u,int sumsiz,int tr)
{
// Leaf: the vertex alone, at distance 0 from itself.
if(sumsiz==1){to[tr]=u,bac[u]=tr,rt[tr]=++nd,sum[nd]=num[nd]=0,trsiz[nd]=1,key[nd]=ag[u];return;}
// Fill s[1..tp] with the component's vertices (distances measured from u) and choose the cut edge wt.
dep[u]=0,tp=0,nowsiz=sumsiz,mnsz=n+1,getsiz(u,0);int now=wt,nxt1=v[now],nxt2=v[now^1];
// The root component is never a sibling in getans, so no BST is built for tr==1.
if(tr!=1)
{
sort(s+1,s+tp+1,cmp),rt[tr]=build(1,tp),to[tr]=u;
}
// Cut the chosen edge in both directions.
vis[now]=vis[now^1]=tr;
// siz[] comes from the DFS rooted at u, so the endpoint with the larger siz is the parent.
// The child side has siz[child] vertices and the parent side has nowsiz-siz[child].
// siz[child] is read after the first recursive call; that call only visits the other side, so it is unchanged.
if(siz[nxt1]>siz[nxt2])getwt(nxt1,nowsiz-siz[nxt2],tr<<1),getwt(nxt2,siz[nxt2],tr<<1|1);
else getwt(nxt2,nowsiz-siz[nxt1],tr<<1),getwt(nxt1,siz[nxt1],tr<<1|1);
}
// Answer one query for vertex u into ans. Walk from u's edge-divide leaf to the root.
// At each level, take the sibling component. It lies across a cut edge whose endpoint on
// that side is to[sib]. Add (total distance of its matching vertices to to[sib]) +
// (their count) * dist(to[sib],u).
void getans(int u)
{
    ans=0;
    for(int cur=bac[u];cur>1;cur>>=1)
    {
        const int sib=cur^1;
        tmp=0;
        add(rt[sib],qr,1);
        add(rt[sib],ql-1,-1);
        ans+=(LL)tmp*dist(to[sib],u);
    }
}
void getdep(int u,int fa)
{
dfn[u]=++tim;st[0][tim]=u;
view(u,k)if(v[k]!=fa){dis[v[k]]=dis[u]+w[k],getdep(v[k],u),st[0][++tim]=u;}
}
int main()
{
memset(fir,-1,sizeof(fir));
n=read(),q=read(),A=read();
rep(i,1,n)ag[i]=read();
rep(i,1,n-1)
{
int x=read(),y=read(),z=read();
ade(x,y,z),ade(y,x,z);
}
getdep(1,0);lg[0]=-1;
rep(i,1,tim)lg[i]=lg[i>>1]+1;
rep(k,1,lg[tim])for(int i=1;i+(1<<k)-1<=tim;++i)
st[k][i]=dis[st[k-1][i]]<dis[st[k-1][i+(1<<(k-1))]]?st[k-1][i]:st[k-1][i+(1<<(k-1))];
getwt(1,n,1);
while(q--)
{
int u=read();LL a=read(),b=read();
ql=min((a+ans)%A,(b+ans)%A),qr=max((a+ans)%A,(b+ans)%A);
getans(u),write(ans);
}
return 0;
}
树剖+可持久化线段树
#include<algorithm>
#include<cmath>
#include<complex>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<iomanip>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
// Loop helpers: rep counts i up from x to y, dwn counts i down from x to y (both inclusive).
// `register` is omitted: C++17 removed that storage class and clang rejects it.
#define rep(i,x,y) for(int i=(x);i<=(y);++i)
#define dwn(i,x,y) for(int i=(x);i>=(y);--i)
// Capacity: vertices, and directed edges (each undirected edge is stored twice).
#define maxn 150010
#define maxm (maxn<<1)
// Iterate the adjacency list of u; k is the edge index.
#define view(u,k) for(int k=fir[u];k!=-1;k=nxt[k])
#define LL long long
// Children of persistent-tree node u, and the midpoint of [l,r] (rely on locals u, l, r).
#define ls son[u][0]
#define rs son[u][1]
#define mi ((l+r)>>1)
// Heap-indexed children of segment u2 (used for the ad[] base-value array).
#define ls2 (u2<<1)
#define rs2 (u2<<1|1)
using namespace std;
// Read one (optionally negative) decimal int from stdin, skipping any leading
// characters that are neither digits nor '-'.
// Returns 0 when end of input is reached before any digit, instead of looping forever.
int read()
{
    // Keep getchar()'s int result so EOF is distinguishable and isdigit never sees a negative char.
    int x=0,f=1,ch=getchar();
    while(ch!=EOF&&!isdigit(ch)&&ch!='-')ch=getchar();
    if(ch=='-')f=-1,ch=getchar();
    while(ch!=EOF&&isdigit(ch))x=x*10+(ch-'0'),ch=getchar();
    return x*f;
}
// Print a signed 64-bit integer in decimal followed by a newline.
// The magnitude is taken in unsigned arithmetic so LLONG_MIN does not overflow on negation.
void write(long long x)
{
    unsigned long long mag=x<0?0ULL-static_cast<unsigned long long>(x):static_cast<unsigned long long>(x);
    if(x<0)putchar('-');
    char buf[20];
    int len=0;
    // do/while so that 0 still emits one digit.
    do{buf[len++]=static_cast<char>('0'+mag%10);mag/=10;}while(mag);
    while(len)putchar(buf[--len]);
    putchar('\n');
}
// Vertex record sorted by weight in main: id = vertex, ag = its weight.
struct node{int id,ag;}s[maxn];
// n/q - vertex and query counts; ag - vertex weights; siz - subtree sizes; fir/v/nxt/cnt - adjacency list.
int n,q,ag[maxn],siz[maxn],fir[maxn],v[maxm],nxt[maxm],cnt;
// HLD: fa - parent; dfn - DFS-order position; top - chain head; wson - heavy child; bac - vertex at a dfn position;
// tim - dfn counter. Persistent segment tree: rt - root of each version; nd - node counter; son - children;
// ver - version that created each node; rnk[i] - version holding s[1..i] (equal weights share one version).
// ql/qr - decoded query range.
int fa[maxn],dfn[maxn],top[maxn],rt[maxn],nd,ql,qr,son[maxn<<6][2],wson[maxn],tim,bac[maxn],ver[maxn<<6],rnk[maxn];
// w - edge weights; dep - distance from vertex 1; sum[i] - sum of dep over s[1..i];
// tr - node total excluding ancestors' marks; mk - permanent +1 marks; ad - per segment, sum of twice the
// parent-edge weights in its range; tofa - weight of the edge to the parent; A - query modulus; ans - last answer.
LL w[maxm],dep[maxn],sum[maxn],tr[maxn<<6],mk[maxn<<6],ad[maxn<<2],tofa[maxn],A,ans;
void ade(int u1,int v1,int w1){v[cnt]=v1,w[cnt]=w1,nxt[cnt]=fir[u1],fir[u1]=cnt++;}
// Order vertex records by increasing weight.
bool cmp1(node x,node y)
{
    return y.ag>x.ag;
}
void getson(int u)
{
siz[u]=1;
view(u,k)if(v[k]!=fa[u])
{
fa[v[k]]=u,dep[v[k]]=dep[u]+w[k],tofa[v[k]]=w[k],getson(v[k]),siz[u]+=siz[v[k]];
if(!wson[u]||siz[wson[u]]<siz[v[k]])wson[u]=v[k];
}
}
void gettop(int u,int anc)
{
dfn[u]=++tim,bac[tim]=u,top[u]=anc;
if(wson[u])gettop(wson[u],anc);
view(u,k)if(v[k]!=wson[u]&&v[k]!=fa[u])gettop(v[k],v[k]);
}
// Fill ad[] over dfn range [l,r] at heap index u2. A leaf holds twice the parent-edge weight
// of the vertex at that position; an inner segment holds the sum of its halves.
void build(int u2,int l,int r)
{
    if(l<r)
    {
        const int mid=(l+r)>>1;
        build(u2<<1,l,mid);
        build(u2<<1|1,mid+1,r);
        ad[u2]=ad[u2<<1]+ad[u2<<1|1];
    }
    else
    {
        ad[u2]=2*tofa[bac[l]];
    }
}
// Persistent range add with permanent (non-pushed-down) marks. Adds +1 to every position of [x,y]
// in the tree rooted at u (heap index u2 for ad[], covering [l,r]) and returns the new root.
// A node already created for version vers is modified in place. Otherwise a copy tagged with vers is
// made, so nodes of earlier versions are never written.
// Invariant: tr[node] = tr[children] + mk[node]*ad[u2], i.e. the range total excluding ancestors' marks
// (ask() passes those down as admk).
int add(int u,int u2,int l,int r,int x,int y,int vers)
{
int nu;
// Reuse only nodes of the current version; node 0 (empty) is excluded by the `u` check and by ver[0]=-1 set in main.
if(vers==ver[u]&&u)nu=u;
else nu=++nd,ver[nu]=vers;
// Start from the old node's mark and children (ls/rs expand to son[u][0]/son[u][1]).
mk[nu]=mk[u];son[nu][0]=ls,son[nu][1]=rs;
// Fully covered: record the +1 here instead of descending.
if(x<=l&&r<=y){mk[nu]++;tr[nu]=tr[ls]+tr[rs]+mk[nu]*ad[u2];return nu;}
if(x<=mi)son[nu][0]=add(ls,ls2,l,mi,x,y,vers);
if(y>mi)son[nu][1]=add(rs,rs2,mi+1,r,x,y,vers);
tr[nu]=tr[son[nu][0]]+tr[son[nu][1]]+mk[nu]*ad[u2];
return nu;
}
// Sum over positions [x,y] in the version rooted at u, where admk is the total of marks
// collected from ancestors (each applies to the whole sub-range via ad[]).
LL ask(int u,int u2,int l,int r,int x,int y,LL admk)
{
    if(u==0&&admk==0)return 0; // empty subtree with nothing inherited
    if(x<=l&&r<=y)return tr[u]+ad[u2]*admk;
    const int mid=(l+r)>>1;
    const LL carry=admk+mk[u];
    LL res=0;
    if(x<=mid)res+=ask(son[u][0],u2<<1,l,mid,x,y,carry);
    if(y>mid)res+=ask(son[u][1],u2<<1|1,mid+1,r,x,y,carry);
    return res;
}
// Add +1 on every vertex from u up to the root (vertex 1), chain by chain.
// The first update reads version t2 and writes version t1; later chains continue on t1.
void addrd(int u,int t1,int t2)
{
    for(;;)
    {
        rt[t1]=add(rt[t2],1,1,n,dfn[top[u]],dfn[u],t1);
        if(top[u]==1)break;
        u=fa[top[u]];
        t2=t1;
    }
}
int getrnk(int x)
{
int l=1,r=n,ans=0;
while(l<=r){if(s[mi].ag<=x)ans=max(ans,mi),l=mi+1;else r=mi-1;}
return ans;
}
void getans(int u)
{
int pl=getrnk(ql-1),pr=getrnk(qr);ans=sum[pr]-sum[pl]+dep[u]*(LL)(pr-pl);
while(top[u]!=1)ans+=ask(rt[rnk[pl]],1,1,n,dfn[top[u]],dfn[u],0)-ask(rt[rnk[pr]],1,1,n,dfn[top[u]],dfn[u],0),u=fa[top[u]];
ans+=ask(rt[rnk[pl]],1,1,n,dfn[top[u]],dfn[u],0)-ask(rt[rnk[pr]],1,1,n,dfn[top[u]],dfn[u],0);
return;
}
int main()
{
memset(fir,-1,sizeof(fir));
n=read(),q=read(),A=read();
rep(i,1,n)ag[i]=read(),s[i].ag=ag[i],s[i].id=i;
rep(i,1,n-1){int x=read(),y=read(),z=read();ade(x,y,z),ade(y,x,z);}
sort(s+1,s+n+1,cmp1),ver[0]=-1,getson(1),gettop(1,1),build(1,1,n);
rep(i,1,n)
{
sum[i]=sum[i-1]+dep[s[i].id];
if(i!=1&&s[i].ag==s[i-1].ag){rnk[i]=rnk[i-1],addrd(s[i].id,rnk[i],rnk[i]);}
else{rnk[i]=rnk[i-1]+1,addrd(s[i].id,rnk[i],rnk[i-1]);}
}
while(q--)
{
int u=read();LL a=read(),b=read();
ql=min((a+ans)%A,(b+ans)%A),qr=max((a+ans)%A,(b+ans)%A);
getans(u),write(ans);
}
return 0;
}