题意
树链上带修改第K大
思路
树链剖分+线段树套平衡树+二分
具体就是先树剖把树映射到区间,然后对区间建线段树,线段树的每一个节点是一颗权值平衡树,修改就是在线段树上跑一遍,对于经过的线段树节点,其实就是平衡树,做一次删除和一次添加。查询就是二分最终的答案ans,然后计算链上大于ans的节点有多少,这一计算只需要在线段树上跑一遍,对于经过的线段树节点计算对应的平衡树里有多少大于ans的节点再加起来就可以了。
二分一个(logn),跑树链一个(logn),跑线段树一个(logn),查询平衡树一个(logn),一共(O(mlog^4n))。我的代码跑了将近46s
树链剖分+树状数组套主席树
类似query on a tree,借助lca的性质来维护树链信息,然后由于带修改,所以在外套一个树状数组来处理修改。每次询问需要对(logn)颗主席树进行修改,一次修改(logn),一共(O(mlog^2n))。
WA到自闭,对着空气改了半天,最后发现是数组开小了
跑了大概8s
整体二分
emmm,好像也没有多难,就是整体二分的板子改改。。。
跑了4s
AC代码:线段树套平衡树
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=4e5+5;
const int INF=1e9+7;
inline int read()
{
int x = 0, f = 1;
char ch = getchar();
while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = (x<<3) + (x<<1) + ch - 48; ch = getchar(); }
return x * f;
}
struct Treap{
const static int N=maxn<<2;
int L[N],R[N],v[N],p[N],A[N],C[N],tot;
void init(){A[0]=L[0]=R[0]=C[0]=0;tot=1;}
int newnode(int V,int P)
{
L[tot]=R[tot]=0;
v[tot]=V;p[tot]=P;
A[tot]=C[tot]=1;
return tot++;
}
void Count(int x){C[x]=A[x]+C[L[x]]+C[R[x]];}
void rotate_right(int &x)
{
int y=L[x];L[x]=R[y];R[y]=x;C[y]=C[x];Count(x);x=y;
}
void rotate_left(int &x)
{
int y=R[x];R[x]=L[y];L[y]=x;C[y]=C[x];Count(x);x=y;
}
void insert(int &x,int V,int P)
{
if(!x){x=newnode(V,P);return;}
if(v[x]==V)++A[x];
else if(V<v[x])
{
insert(L[x],V,P);
if(p[x]>p[L[x]])rotate_right(x);
}
else
{
insert(R[x],V,P);
if(p[x]>p[R[x]])rotate_left(x);
}
Count(x);
}
void Delete(int &x,int V)
{
if(!x)return;
if(V<v[x])Delete(L[x],V);
else if(V>v[x])Delete(R[x],V);
else if(A[x]>1)--A[x];
else if(!L[x] || !R[x])x=L[x]+R[x];
else if(p[L[x]]<p[R[x]]){rotate_right(x);Delete(R[x],V);}
else{rotate_left(x);Delete(L[x],V);}
Count(x);
}
void del(int &x,int V){Delete(x,V);}
void add(int &x,int V){insert(x,V,rand());}
int getrank(int x,int V)
{
int ans=0;
while(x)
{
if(V==v[x]){
ans+=C[R[x]];
break;
}
else if(V<v[x]){
ans+=(C[R[x]]+A[x]);
x=L[x];
}
else x=R[x];
}
return ans;
}
}treap;
struct Segment_Tree{
int rt[maxn<<2];
void init(){
memset(rt,0,sizeof(rt));
}
void update(int x,int l,int r,int pos,int preV,int V){
treap.del(rt[x],preV);
treap.add(rt[x],V);
if(l==r)return;
int mid=(l+r)/2;
if(pos<=mid)update(x<<1,l,mid,pos,preV,V);
else update(x<<1|1,mid+1,r,pos,preV,V);
}
int query_rank(int x,int l,int r,int L,int R,int V){
if(L==l && R==r){
return treap.getrank(rt[x],V);
}
int mid=(l+r)/2;
if(R<=mid)return query_rank(x<<1,l,mid,L,R,V);
else if(L>mid)return query_rank(x<<1|1,mid+1,r,L,R,V);
else return query_rank(x<<1,l,mid,L,mid,V)+query_rank(x<<1|1,mid+1,r,mid+1,R,V);
}
}segtree;
int w[maxn];
int n,q;
int tot,head[maxn];
struct Edge{
int v,nxt;
}e[maxn<<1];
inline void init(){
tot=0;
memset(head,-1,sizeof(head));
}
inline void addedge(int u,int v){
e[tot].v=v;e[tot].nxt=head[u];
head[u]=tot++;
e[tot].v=u;e[tot].nxt=head[v];
head[v]=tot++;
}
int sz[maxn],son[maxn],fa[maxn],h[maxn],A[maxn],pos[maxn],top[maxn],cnt;
void dfs1(int u,int f){
int v;
sz[u]=1;son[u]=0;fa[u]=f;h[u]=h[f]+1;
for(int i=head[u];i!=-1;i=e[i].nxt){
v=e[i].v;
if(v==f)continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[son[u]]<sz[v])son[u]=v;
}
}
void dfs2(int u,int f,int k){
int v;
top[u]=k;
pos[u]=++cnt;
A[cnt]=w[u];
if(son[u])dfs2(son[u],u,k);
for(int i=head[u];i!=-1;i=e[i].nxt){
v=e[i].v;
if(v==f)continue;
if(v==son[u])continue;
dfs2(v,u,v);
}
}
int LCA(int u,int v){
while(top[u]!=top[v]){
if(h[top[u]]<h[top[v]])swap(u,v);
u=fa[top[u]];
}
if(h[u]>h[v])swap(u,v);
return u;
}
int query(int u,int v,int ww){
int ans=0;
while(top[u]!=top[v]){
if(h[top[u]]<h[top[v]])swap(u,v);
ans+=segtree.query_rank(1,1,n,pos[top[u]],pos[u],ww);
u=fa[top[u]];
}
if(h[u]>h[v])swap(u,v);
ans+=segtree.query_rank(1,1,n,pos[u],pos[v],ww);
return ans;
}
int t[maxn<<1],m;
int k[maxn],a[maxn],b[maxn];
void unik(){
sort(t+1,t+1+m);
m=unique(t+1,t+1+m)-(t+1);
}
int Hash(int x){
return lower_bound(t+1,t+1+m,x)-t;
}
int solve(int u,int v,int k){
int lca=LCA(u,v);
int sz=h[u]+h[v]-2*h[lca]+1;
if(sz<k)return -1;
int l=1,r=m,mid,ans;
while(l<=r){
int mid=(l+r)/2;
int rk=query(u,lca,mid)+query(v,lca,mid);
if(w[lca]>mid)rk--;
if(rk<=k-1){
r=mid-1;
ans=mid;
}else{
l=mid+1;
}
}
return t[ans];
}
int main()
{
n=read();q=read();
init();
int u,v,ww;
for(int i=1;i<=n;i++)w[i]=read(),t[++m]=w[i];
for(int i=1;i<n;i++){
u=read();v=read();
addedge(u,v);
}
cnt=0;
dfs1(1,0);
dfs2(1,0,1);
for(int i=1;i<=q;i++){
k[i]=read();a[i]=read();b[i]=read();
if(k[i]==0)t[++m]=b[i];
}
unik();
for(int i=1;i<=n;i++)w[i]=Hash(w[i]);
segtree.init();treap.init();
for(int i=1;i<=n;i++)
segtree.update(1,1,n,pos[i],0,w[i]);
int ans;
for(int i=1;i<=q;i++){
if(k[i]==0){
int V=Hash(b[i]);
segtree.update(1,1,n,pos[a[i]],w[a[i]],V);
w[a[i]]=V;
}
else{
ans=solve(a[i],b[i],k[i]);
if(ans!=-1)printf("%d
",ans);
else printf("invalid request!
");
}
}
return 0;
}
AC代码:树状数组套主席树
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=16e4+5;
const int maxn=N;
int n,m;
int w[maxn],t[maxn],vnum;
struct Query{
int k,a,b;
}q[maxn];
int addT[maxn],anum,subT[maxn],snum;
int T[maxn],L[maxn<<7],R[maxn<<7],sum[maxn<<7],tnum;
void update(int &rt,int l,int r,int x,int d){
if(!rt)rt=++tnum;
sum[rt]+=d;
if(l<r){
int mid=(l+r)/2;
if(x<=mid)update(L[rt],l,mid,x,d);
else update(R[rt],mid+1,r,x,d);
}
}
int query(int l,int r,int k){
if(l==r)return l;
int tmp=0;
for(int i=1;i<=anum;i++)tmp+=sum[R[addT[i]]];
for(int i=1;i<=snum;i++)tmp-=sum[R[subT[i]]];
int mid=(l+r)/2;
if(k<=tmp){
for(int i=1;i<=anum;i++)addT[i]=R[addT[i]];
for(int i=1;i<=snum;i++)subT[i]=R[subT[i]];
return query(mid+1,r,k);
}
else{
for(int i=1;i<=anum;i++)addT[i]=L[addT[i]];
for(int i=1;i<=snum;i++)subT[i]=L[subT[i]];
return query(l,mid,k-tmp);
}
}
int lowbit(int x){return x&(-x);}
void add(int x,int val,int d){
for(;x<=n;x+=lowbit(x))
update(T[x],1,vnum,val,d);
}
int head[maxn],tot=1;
struct Edge{
int v,nxt;
}e[maxn<<1];
void addedge(int u,int v){
e[tot].v=v;e[tot].nxt=head[u];head[u]=tot++;
e[tot].v=u;e[tot].nxt=head[v];head[v]=tot++;
}
int sz[maxn],son[maxn],fa[maxn],h[maxn],pos[maxn],top[maxn],low[maxn],cnt;
void dfs1(int u,int f){
sz[u]=1;son[u]=0;fa[u]=f;h[u]=h[f]+1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==f)continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[son[u]]<sz[v])son[u]=v;
}
}
void dfs2(int u,int f,int k){
top[u]=k;pos[u]=++cnt;
if(son[u])dfs2(son[u],u,k);
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==f || v==son[u])continue;
dfs2(v,u,v);
}
low[u]=cnt;
}
int LCA(int u,int v){
while(top[u]!=top[v]){
if(h[top[u]]<h[top[v]])swap(u,v);
u=fa[top[u]];
}
if(h[u]>h[v])swap(u,v);
return u;
}
int main()
{
scanf("%d %d",&n,&m);
for(int i=1;i<=n;i++)scanf("%d",&w[i]),t[++vnum]=w[i];
int u,v,k;
for(int i=1;i<=n-1;i++){
scanf("%d %d",&u,&v);
addedge(u,v);
}
for(int i=1;i<=m;i++){
scanf("%d %d %d",&q[i].k,&q[i].a,&q[i].b);
if(!q[i].k)t[++vnum]=q[i].b;
}
sort(t+1,t+1+vnum);
vnum=unique(t+1,t+1+vnum)-(t+1);
for(int i=1;i<=n;i++)w[i]=lower_bound(t+1,t+1+vnum,w[i])-t;
for(int i=1;i<=m;i++)if(!q[i].k)q[i].b=lower_bound(t+1,t+1+vnum,q[i].b)-t;
dfs1(1,0);dfs2(1,0,1);
for(int i=1;i<=n;i++)add(pos[i],w[i],1),add(low[i]+1,w[i],-1);
for(int i=1;i<=m;i++){
u=q[i].a;v=q[i].b;k=q[i].k;
if(k){
int lca=LCA(u,v);
if(h[u]+h[v]-h[lca]+-h[fa[lca]]<k){
printf("invalid request!
");
continue;
}
anum=snum=0;
for(int i=pos[u];i;i-=lowbit(i))addT[++anum]=T[i];
for(int i=pos[v];i;i-=lowbit(i))addT[++anum]=T[i];
for(int i=pos[lca];i;i-=lowbit(i))subT[++snum]=T[i];
for(int i=pos[fa[lca]];i;i-=lowbit(i))subT[++snum]=T[i];
printf("%d
",t[query(1,vnum,k)]);
}
else{
add(pos[u],w[u],-1);add(low[u]+1,w[u],1);
w[u]=v;
add(pos[u],w[u],1);add(low[u]+1,w[u],-1);
}
}
return 0;
}
AC代码:整体二分
#include<bits/stdc++.h>
using namespace std;
const int maxn=8e4+5;
int n,q;
int w[maxn],ans[maxn];
struct Oprate{
int id,k,a,b,val,type;
}op[maxn<<1],tmp1[maxn<<1],tmp2[maxn<<1];
struct BIT{
int c[maxn];
inline int lb(int x){return x&(-x);}
inline void add(int x,int d){for(;x<=n;x+=lb(x))c[x]+=d;}
inline int getsum(int x){int r=0;for(;x;x-=lb(x))r+=c[x];return r;}
inline int getsum(int l,int r){return getsum(r)-getsum(l-1);}
}bit;
//graph
struct Edge{
int v,nxt;
}e[maxn<<1];
int head[maxn],tot;
void init(){
tot=1;
memset(head,0,sizeof(head));
}
void addedge(int u,int v){
e[tot].v=v;e[tot].nxt=head[u];head[u]=tot++;
e[tot].v=u;e[tot].nxt=head[v];head[v]=tot++;
}
//heavy-light
int sz[maxn],son[maxn],fa[maxn],h[maxn],pos[maxn],top[maxn],cnt;
void dfs1(int u,int f){
sz[u]=1;son[u]=0;fa[u]=f;h[u]=h[f]+1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==f)continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[son[u]]<sz[v])son[u]=v;
}
}
void dfs2(int u,int f,int k){
top[u]=k;
pos[u]=++cnt;
if(son[u])dfs2(son[u],u,k);
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==f || v==son[u])continue;
dfs2(v,u,v);
}
}
int LCA(int u,int v){
while(top[u]!=top[v]){
if(h[top[u]]<h[top[v]])swap(u,v);
u=fa[top[u]];
}
if(h[u]>h[v])swap(u,v);
return u;
}
int query(int u,int v){
int res=0;
while(top[u]!=top[v]){
if(h[top[u]]<h[top[v]])swap(u,v);
res+=bit.getsum(pos[top[u]],pos[u]);
u=fa[top[u]];
}
if(h[u]>h[v])swap(u,v);
res+=bit.getsum(pos[u],pos[v]);
return res;
}
//
void solve(int l,int r,int L,int R)
{
if(l>r || L>R)return;
if(l==r){
for(int i=L;i<=R;i++)
if(op[i].type==0 && ans[op[i].id]!=-1)ans[op[i].id]=l;
return;
}
int mid=(l+r)>>1,t1=0,t2=0;
for(int i=L;i<=R;i++){
if(op[i].type==0){
int tmp=query(op[i].a,op[i].b);
if(tmp>=op[i].k)tmp2[++t2]=op[i];
else op[i].k-=tmp,tmp1[++t1]=op[i];
}
else{
if(op[i].b>mid)bit.add(pos[op[i].a],op[i].val),tmp2[++t2]=op[i];
else tmp1[++t1]=op[i];
}
}
for(int i=1;i<=t2;i++)if(tmp2[i].type==1 && tmp2[i].b>mid)
bit.add(pos[tmp2[i].a],-tmp2[i].val);
for(int i=1;i<=t1;i++)op[L+i-1]=tmp1[i];
for(int i=1;i<=t2;i++)op[L+t1+i-1]=tmp2[i];
solve(l,mid,L,L+t1-1); solve(mid+1,r,L+t1,R);
}
int main()
{
// freopen("in.txt","r",stdin);
// freopen("out.txt","w",stdout);
int nn=0;
scanf("%d %d",&n,&q);
for(int i=1;i<=n;i++){
scanf("%d",&w[i]);
op[++nn]=(Oprate){0,0,i,w[i],1,1};
}
int u,v;
init();
for(int i=1;i<=n-1;i++){
scanf("%d %d",&u,&v);
addedge(u,v);
}
dfs1(1,0);dfs2(1,0,1);
int k,a,b,id=0;
for(int i=1;i<=q;i++){
scanf("%d %d %d",&k,&a,&b);
if(k){
op[++nn]=(Oprate){++id,k,a,b,0,0};
int lca=LCA(a,b);
int len=h[a]+h[b]-h[lca]-h[fa[lca]];
if(len<k)ans[id]=-1;
}
else{
op[++nn]=(Oprate){0,k,a,w[a],-1,1};
w[a]=b;
op[++nn]=(Oprate){0,k,a,w[a],1,1};
}
}
solve(1,1e8,1,nn);
for(int i=1;i<=id;i++){
if(ans[i]==-1)printf("invalid request!
");
else printf("%d
",ans[i]);
}
return 0;
}