传送门
多几个 $\log$ 似乎不是什么问题，最容易想到的就是单次询问 $O(\log^4 n)$（4 个 $\log$）的树链剖分 + 树套树吧。
衷心希望大家别写线段树套线段树，这个东西真的极容易 MLE（这题不会）以及 TLE（这题在洛谷会 TLE 1 个点，BZOJ 可以 AC）。
这题本质上就是在树上的路径上动态求所经过点权的第 $k$ 大。路径问题用树链剖分处理即可；动态第 $k$ 大用树套树，建议用树状数组套线段树（套平衡树也行）。由于树链剖分会把一条路径划分成多段，各段的信息无法直接合并，所以还需要二分答案。
代码(线段树套线段树):
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<map>
using namespace std;
// Reads one decimal integer from stdin into x.
// Any '-' seen before the first digit makes the result negative.
// If EOF is reached before a digit, x is set to 0 and the function returns
// (the previous `char`-based loop spun forever at end of input).
void read(int &x) {
    int ch = getchar();   // int, not char, so EOF is distinguishable from data
    bool negative = false;
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') negative = true;
        ch = getchar();
    }
    x = 0;
    while (isdigit(ch)) {   // isdigit(EOF) is false, so EOF also ends this loop
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) x = -x;
}
// `register` is ill-formed since C++17; `rg` is kept as a no-op so loops that
// are written as `rg int i` still compile under modern standards.
#define rg
// Arrays below assume n and q stay below maxn (k/x/y/a/id are sized by it).
const int maxn=1e5+10;
map<int,int>mp;            // raw value -> compressed rank (1..num)
int tmp;                   // inner segment-tree nodes allocated so far (0 = null)
int n,q;                   // vertex count, query count
int tot;                   // number of raw values collected into w
int a[maxn];               // current raw value of each vertex
int num;                   // number of distinct compressed values
int now;                   // heavy-light DFS timestamp counter
int nmp[maxn*2];           // compressed rank -> raw value
int id[maxn];              // heavy-light position of each vertex
int cnt;                   // arc counter for the adjacency list
int pre[maxn*2],nxt[maxn*2],h[maxn]; // adjacency list: arc target, next arc, head arc
int w[maxn*2];             // all raw values (initial + updates) for compression
int f[maxn];               // parent of each vertex
int ans;
int k[maxn],x[maxn],y[maxn];         // stored queries
// NOTE: with `using namespace std;` an unqualified `size` is ambiguous with
// std::size in C++17; refer to this array as `::size`.
int top[maxn],size[maxn],dep[maxn];  // chain head, subtree size, depth
int sum[maxn*200],ls[maxn*200],rs[maxn*200]; // inner segment-tree node pool
// Inner-tree root for each outer (value) segment-tree node. The outer tree
// spans num <= n+q < 2*maxn ranks, so node ids can reach ~8*maxn.
int rt[maxn*8];
// Stateless helper for the dynamically allocated inner segment trees (over
// heavy-light positions 1..n). Nodes live in the global pools sum/ls/rs,
// `tmp` is the allocation counter, and node 0 stands for an empty subtree.
struct segment_tree
{
    // Recompute node x's total from its two children.
    void update(int x){sum[x]=sum[ls[x]]+sum[rs[x]];}
    // Add b at position a in the tree rooted at k (covering [l,r]),
    // creating nodes on demand; k is updated if it was 0.
    void change(int &k,int l,int r,int a,int b)
    {
        if(!k)k=++tmp;
        if(l==r){sum[k]+=b;return;}
        int mid=(l+r)>>1;
        if(a<=mid)change(ls[k],l,mid,a,b);
        else change(rs[k],mid+1,r,a,b);
        update(k);
    }
    // Sum over positions [a,b] in the tree rooted at x (covering [l,r]).
    int get(int x,int l,int r,int a,int b)
    {
        if(!x)return 0;
        if(a<=l&&b>=r)return sum[x];
        int mid=(l+r)>>1,ans=0;
        if(a<=mid)ans+=get(ls[x],l,mid,a,b);
        if(b>mid)ans+=get(rs[x],mid+1,r,a,b);
        return ans;
    }
}s[maxn*8]; // one object per outer value-tree node; ids reach ~4*num < 8*maxn
// Store the undirected edge x-y as two directed arcs (x->y, then y->x)
// at the front of each endpoint's adjacency list.
void add(int x,int y)
{
    const int from[2]={x,y},to[2]={y,x};
    for(int t=0;t<2;t++)
    {
        cnt++;
        pre[cnt]=to[t];
        nxt[cnt]=h[from[t]];
        h[from[t]]=cnt;
    }
}
void dfs(int x,int fa)
{
size[x]=1,f[x]=fa;
for(rg int i=h[x];i;i=nxt[i])
if(pre[i]!=fa)dep[pre[i]]=dep[x]+1,dfs(pre[i],x),size[x]+=size[pre[i]];
}
void dfs1(int x,int f)
{
id[x]=++now,top[x]=f;int k=0;
for(rg int i=h[x];i;i=nxt[i])
if(dep[pre[i]]>dep[x]&&size[pre[i]]>size[k])k=pre[i];
if(!k)return ;dfs1(k,f);
for(rg int i=h[x];i;i=nxt[i])
if(dep[pre[i]]>dep[x]&&pre[i]!=k)dfs1(pre[i],pre[i]);
}
// Outer (value-rank) segment-tree update: walk from node x (covering [l,r])
// down to the leaf for rank a, adding c at position b in the inner tree of
// every node on that path.
void change(int x,int l,int r,int a,int b,int c)
{
    for(;;)
    {
        s[x].change(rt[x],1,n,b,c);
        if(l==r)break;
        int mid=(l+r)>>1;
        if(a<=mid){x=x<<1;r=mid;}
        else{x=x<<1|1;l=mid+1;}
    }
}
// Count stored entries whose value rank is in [a,b] and whose heavy-light
// position is in [c,d], starting at outer node x covering ranks [l,r].
int get(int x,int l,int r,int a,int b,int c,int d)
{
    if(a<=l&&r<=b)return s[x].get(rt[x],1,n,c,d);
    int mid=(l+r)>>1;
    int lowPart=(a<=mid)?get(x<<1,l,mid,a,b,c,d):0;
    int highPart=(b>mid)?get(x<<1|1,mid+1,r,a,b,c,d):0;
    return lowPart+highPart;
}
int qsum(int x,int y)
{
int sum=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
sum+=get(1,1,num,1,num,id[top[x]],id[x]);
x=f[top[x]];
}
if(id[x]>id[y])swap(x,y);
sum+=get(1,1,num,1,num,id[x],id[y]);
return sum;
}
bool check(int mid,int x,int y,int v)
{
int sum=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
sum+=get(1,1,num,1,mid,id[top[x]],id[x]);
if(sum>=v)return 1;
x=f[top[x]];
}
if(id[x]>id[y])swap(x,y);
sum+=get(1,1,num,1,mid,id[x],id[y]);
return sum>=v;
}
int main()
{
read(n),read(q);tot=n;
for(rg int i=1;i<=n;i++)read(a[i]),w[i]=a[i];
for(rg int i=1,l,r;i<n;i++)read(l),read(r),add(l,r);
for(rg int i=1;i<=q;i++){read(k[i]),read(x[i]),read(y[i]);if(!k[i])w[++tot]=y[i];}
sort(w+1,w+tot+1);for(rg int i=1;i<=tot;i++)if(w[i]!=w[i-1])mp[w[i]]=++num,nmp[num]=w[i];
dfs(1,0),dfs1(1,1);
for(rg int i=1;i<=n;i++)change(1,1,num,mp[a[i]],id[i],1);
for(rg int i=1;i<=q;i++)
{
if(!k[i])
change(1,1,num,mp[a[x[i]]],id[x[i]],-1),a[x[i]]=y[i],
change(1,1,num,mp[a[x[i]]],id[x[i]],1);
else
{
int l=1,r=num,g=qsum(x[i],y[i]);
if(g<k[i]){printf("invalid request!
");continue;}
while(l<=r)
{
int mid=(l+r)>>1;
if(check(mid,x[i],y[i],g-k[i]+1))r=mid-1;
else l=mid+1;
}
printf("%d
",nmp[l]);
}
}
}
代码(树状数组套线段树):
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<map>
using namespace std;
// Reads one decimal integer from stdin into x.
// Any '-' seen before the first digit makes the result negative.
// If EOF is reached before a digit, x is set to 0 and the function returns
// (the previous `char`-based loop spun forever at end of input).
void read(int &x) {
    int ch = getchar();   // int, not char, so EOF is distinguishable from data
    bool negative = false;
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') negative = true;
        ch = getchar();
    }
    x = 0;
    while (isdigit(ch)) {   // isdigit(EOF) is false, so EOF also ends this loop
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) x = -x;
}
// `register` is ill-formed since C++17; `rg` is kept as a no-op so loops that
// are written as `rg int i` still compile under modern standards.
#define rg
// Lowest set bit of i; the argument is parenthesised so expressions are safe.
#define lowbit(i) ((i)&(-(i)))
// Arrays below assume n and q stay below maxn (k/x/y/a/id are sized by it).
const int maxn=1e5+10;
map<int,int>mp;            // raw value -> compressed rank (1..num)
int tmp;                   // inner segment-tree nodes allocated so far (0 = null)
int n,q;                   // vertex count, query count
int tot;                   // number of raw values collected into w
int a[maxn];               // per-vertex value (main() later stores compressed ranks here)
int num;                   // number of distinct compressed values
int now;                   // heavy-light DFS timestamp counter
int nmp[maxn*2];           // compressed rank -> raw value
int id[maxn];              // heavy-light position of each vertex
int cnt;                   // arc counter for the adjacency list
int pre[maxn*2],nxt[maxn*2],h[maxn]; // adjacency list: arc target, next arc, head arc
int w[maxn*2];             // all raw values (initial + updates) for compression
int f[maxn];               // parent of each vertex
int ans;
int k[maxn],x[maxn],y[maxn];         // stored queries
// NOTE: with `using namespace std;` an unqualified `size` is ambiguous with
// std::size in C++17; refer to this array as `::size`.
int top[maxn],size[maxn],dep[maxn];  // chain head, subtree size, depth
int sum[maxn*100],ls[maxn*100],rs[maxn*100]; // inner segment-tree node pool
// Inner-tree root per Fenwick index 1..num. num <= n+q can exceed maxn,
// so 2*maxn slots are needed (rt[maxn] could be indexed out of bounds).
int rt[maxn*2];
// Stateless helper for the dynamically allocated inner segment trees (over
// heavy-light positions 1..n). Nodes live in the global pools sum/ls/rs,
// `tmp` is the allocation counter, and node 0 stands for an empty subtree.
struct segment_tree
{
    // Recompute node x's total from its two children.
    inline void update(int x){sum[x]=sum[ls[x]]+sum[rs[x]];}
    // Add b at position a in the tree rooted at k (covering [l,r]),
    // creating nodes on demand; k is updated if it was 0.
    inline void change(int &k,int l,int r,int a,int b)
    {
        if(!k)k=++tmp;
        if(l==r){sum[k]+=b;return;}
        int mid=(l+r)>>1;
        if(a<=mid)change(ls[k],l,mid,a,b);
        else change(rs[k],mid+1,r,a,b);
        update(k);
    }
    // Sum over positions [a,b] in the tree rooted at x (covering [l,r]).
    inline int get(int x,int l,int r,int a,int b)
    {
        if(!x)return 0;
        if(a<=l&&b>=r)return sum[x];
        int mid=(l+r)>>1,ans=0;
        if(a<=mid)ans+=get(ls[x],l,mid,a,b);
        if(b>mid)ans+=get(rs[x],mid+1,r,a,b);
        return ans;
    }
}s[maxn*2]; // one object per Fenwick index 1..num, and num can exceed maxn
// Store the undirected edge x-y as two directed arcs at the front of each
// endpoint's adjacency list.
inline void add(int x,int y)
{
    // arc x -> y
    ++cnt;
    pre[cnt]=y;
    nxt[cnt]=h[x];
    h[x]=cnt;
    // arc y -> x
    ++cnt;
    pre[cnt]=x;
    nxt[cnt]=h[y];
    h[y]=cnt;
}
inline void dfs(int x,int fa)
{
size[x]=1,f[x]=fa;
for(rg int i=h[x];i;i=nxt[i])
if(pre[i]!=fa)dep[pre[i]]=dep[x]+1,dfs(pre[i],x),size[x]+=size[pre[i]];
}
inline void dfs1(int x,int f)
{
id[x]=++now,top[x]=f;int k=0;
for(rg int i=h[x];i;i=nxt[i])
if(dep[pre[i]]>dep[x]&&size[pre[i]]>size[k])k=pre[i];
if(!k)return ;dfs1(k,f);
for(rg int i=h[x];i;i=nxt[i])
if(dep[pre[i]]>dep[x]&&pre[i]!=k)dfs1(pre[i],pre[i]);
}
inline int qsum(int x,int y)
{
int sum=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
for(rg int i=num;i;i-=lowbit(i))sum+=s[i].get(rt[i],1,n,id[top[x]],id[x]);
x=f[top[x]];
}
if(id[x]>id[y])swap(x,y);
for(rg int i=num;i;i-=lowbit(i))sum+=s[i].get(rt[i],1,n,id[x],id[y]);
return sum;
}
inline bool check(int mid,int x,int y,int v)
{
int sum=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
for(rg int i=mid;i;i-=lowbit(i))sum+=s[i].get(rt[i],1,n,id[top[x]],id[x]);
if(sum>=v)return 1;
x=f[top[x]];
}
if(id[x]>id[y])swap(x,y);
for(rg int i=mid;i;i-=lowbit(i))sum+=s[i].get(rt[i],1,n,id[x],id[y]);
return sum>=v;
}
int main()
{
read(n),read(q);tot=n;
for(rg int i=1;i<=n;i++)read(a[i]),w[i]=a[i];
for(rg int i=1,l,r;i<n;i++)read(l),read(r),add(l,r);
for(rg int i=1;i<=q;i++){read(k[i]),read(x[i]),read(y[i]);if(!k[i])w[++tot]=y[i];}
sort(w+1,w+tot+1);for(rg int i=1;i<=tot;i++)if(w[i]!=w[i-1])mp[w[i]]=++num,nmp[num]=w[i];
dfs(1,0),dfs1(1,1);
for(rg int i=1;i<=n;i++)
{
a[i]=mp[a[i]];
for(rg int j=a[i];j<=num;j+=lowbit(j))s[j].change(rt[j],1,n,id[i],1);
}
for(rg int i=1;i<=q;i++)
{
if(!k[i])
{
for(rg int j=a[x[i]];j<=num;j+=lowbit(j))s[j].change(rt[j],1,n,id[x[i]],-1);
a[x[i]]=mp[y[i]];
for(rg int j=a[x[i]];j<=num;j+=lowbit(j))s[j].change(rt[j],1,n,id[x[i]],1);
}
else
{
int l=1,r=num,g=qsum(x[i],y[i]);
if(g<k[i]){printf("invalid request!
");continue;}
while(l<=r)
{
int mid=(l+r)>>1;
if(check(mid,x[i],y[i],g-k[i]+1))r=mid-1;
else l=mid+1;
}
printf("%d
",nmp[l]);
}
}
}