• [BJOI2017]树的难题


    IX.[BJOI2017]树的难题

    debug三天,精神崩溃

    论一行if(vis[v[x][r].second]){r++;continue;}忘记加上后所有代码全都莫名其妙TLE且查不出锅的痛苦

    首先,我们考虑常规淀粉质。

    我们考虑一条路径,它会被(淀粉质的分治根)截成两段。如果我们对于分治树中的每一个节点,预处理出来它到树根的路径权值,记为\(sum_x\),则一条完整路径的权值则为\(sum_x+sum_y\)

    稍等,我们好像忘记了一种情况——如果这两条路径顶端的边的颜色相同怎么办

    换句话说,假如某一半路径的颜色段为(顺序为从根节点往下)ABABC,另一半为ABAC,两半拼一起,我们得到CBABAABAC。显然,这个A就被算了两次,应该被减掉。

    因此,这种情况的权值则为\(sum_x+sum_y-val_c\),其中\(val_c\)\(c\)颜色的权值,而\(c\)为两条路径顶端的颜色。

    很明显这两者要分开考虑。

    然后就是求值了。显然,对于半条路径,与它可以拼成完整路径的另一半的长度,是一个区间。所以,我们可以建一棵线段树,以深度为下标,存储所有当前深度的点中,\(sum_x\)的最大值。

    我们需要两棵线段树——一棵用于同种颜色的储存,而另一棵用于全局颜色的储存。在一种颜色全部处理完之后,将当前颜色线段树与全局线段树合并,并清空当前线段树即可。

    明显复杂度为\(O(n\log^2n)\)。由于人傻常数大,它TLE50,即使开O3也一样。

    TLE的线段树代码:

    #pragma GCC optimize(3)
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<vector>
    using namespace std;
    typedef long long ll;
    const ll fni=-1e18;
    #define lson x<<1
    #define rson x<<1|1
    #define mid ((l+r)>>1)
    struct SegTree{
    	ll seg[800100];
    	void init(){
    		for(int i=1;i<=800000;i++)seg[i]=fni;
    	}
    	void pushup(int x){
    		seg[x]=max(seg[lson],seg[rson]);
    	}
    	void modify(int x,int l,int r,int P,ll vl){
    		if(l>P||r<P)return;
    		if(l==r){seg[x]=max(seg[x],vl);return;} 
    		modify(lson,l,mid,P,vl),modify(rson,mid+1,r,P,vl),pushup(x);
    	}
    	void setzero(int x,int l,int r){
    		if(seg[x]==fni)return;
    		seg[x]=fni;
    		if(l!=r)setzero(lson,l,mid),setzero(rson,mid+1,r);
    	}
    	ll query(int x,int l,int r,int L,int R){
    		if(l>R||r<L)return fni;
    		if(L<=l&&r<=R)return seg[x];
    		return max(query(lson,l,mid,L,R),query(rson,mid+1,r,L,R));
    	}
    }all,same;
    void merge(int x,int l,int r){
    	if(same.seg[x]==fni)return;
    	all.seg[x]=max(all.seg[x],same.seg[x]);
    	same.seg[x]=fni;
    	if(l!=r)merge(lson,l,mid),merge(rson,mid+1,r);
    }
    int n,m,L,R,val[200100],ROOT,SZ,sz[200100],msz[200100];
    ll mx=fni;
    vector<pair<int,int> >v[200100];
    bool vis[200100];
    void getsz(int x,int fa){
    	sz[x]=1;
    	for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)getsz(i.second,x),sz[x]+=sz[i.second];
    }
    void getroot(int x,int fa){
    	sz[x]=1,msz[x]=0;
    	for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)getroot(i.second,x),sz[x]+=sz[i.second],msz[x]=max(msz[x],sz[i.second]);
    	msz[x]=max(msz[x],SZ-sz[x]);
    	if(msz[x]<msz[ROOT])ROOT=x;
    }
    void write(int x,int fa,SegTree &sg,int las,int dep,ll sum){
    	if(dep>R)return;
    	sg.modify(1,0,R,dep,sum);
    	for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)write(i.second,x,sg,i.first,dep+1,sum+(i.first==las?0:val[i.first]));
    }
    void read(int x,int fa,SegTree &sg,int las,int dep,ll sum){
    	if(dep>R)return;
    	mx=max(mx,sum+sg.query(1,0,R,L-dep,R-dep));
    	for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)read(i.second,x,sg,i.first,dep+1,sum+(i.first==las?0:val[i.first]));
    }
    void calc(int x){
    	all.modify(1,0,R,0,0);
    	for(int l=0,r=0;r<v[x].size();l=r){
    		while(r<v[x].size()&&v[x][r].first==v[x][l].first){
    			if(vis[v[x][r].second]){r++;continue;}
    			int i=v[x][r].second,j=v[x][r].first;
    			read(i,x,same,j,1,0);
    			write(i,x,same,j,1,val[j]);
    			read(i,x,all,j,1,val[j]);
    			r++;
    		}
    		merge(1,0,R);
    	}
    	all.setzero(1,0,R);
    }
    void solve(int x){
    	calc(x);
    	getsz(x,0); 
    	vis[x]=true;
    	for(auto i:v[x])if(!vis[i.second])ROOT=0,SZ=sz[i.second],getroot(i.second,0),solve(ROOT);
    }
    void read(int &x){
    	x=0;
    	char c=getchar();
    	int fl=1;
    	while(c>'9'||c<'0')fl=(c=='-'?-fl:fl),c=getchar();
    	while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
    	x*=fl;
    }
    int main(){
    	read(n),read(m),read(L),read(R);
    	for(int i=1;i<=m;i++)read(val[i]);
    	for(int i=1,x,y,z;i<n;i++)read(x),read(y),read(z),v[x].push_back(make_pair(z,y)),v[y].push_back(make_pair(z,x));
    	for(int i=1;i<=n;i++)sort(v[i].begin(),v[i].end());
    	all.init(),same.init();
    	msz[0]=n+1,SZ=n,getroot(1,0),solve(ROOT);
    	printf("%lld\n",mx);
    	return 0;
    }
    

    然后就是正解了——一个名叫单调队列按秩合并的trick。

    显然,如果我们把所有半路径按照深度排序,它们合法的转移区间是单调递减的。

    比如说,如果设深度为\(dep_x\)的话,则合法深度区间则为\([L-dep_x,R-dep_x]\)。当\(dep_x\nearrow\)时,整个区间\(\searrow\)

    这不是经典老题滑动窗口吗?使用单调队列维护即可。

    我们考虑颜色相同的情况。我们可以用来维护相同深度时的\(sum\)的最大值。对于每一棵子树,我们按照节点深度处理,在桶上跑滑动窗口。在整棵子树跑完后,用它们的值更新桶即可。

    然后颜色不同的情况类似,只不过是对于每一种颜色一起处理,不需要关心具体从哪棵子树过来罢了。

    稍等,这个算法是假的。很明显,这个算法的复杂度为桶的大小。该大小最大可以到直径,即\(n\)级别。如果开门见喜,一上来就遇到了直径,则之后每一次滑动窗口都要完整跑一遍直径。如果儿子数量很多的话,复杂度是会退化成\(O(n^2)\)的。

    那怎么办呢?

    对于颜色相同时,我们按照子树内最大深度递增的方式处理子树,先处理深度浅的子树,再处理深度深的。这就保证了单调队列的复杂度是严格\(O(\sum dep)\),即\(O(n)\)的。

    在颜色不同时,我们仍然这样做,先处理深度浅的颜色,再处理深度深的。

    这种trick,就是单调队列按秩合并——按照长度递增的顺序处理多条单调队列

    最后还有一件事——排序。显然,如果直接排序,总复杂度是\(O(n\log^2n)\)的。当然,常数比之前小很多(因为每一层的排序复杂度都跑不满),因此可以过去。当然,因为深度范围小,使用桶排即可。

    或者放弃dfs,使用bfs,毕竟bfs本身就保证了按照深度排序。

    这里两种代码都能AC,\(n\log n\)的bfs或者\(n\log^2n\)的dfs。

    bfs:

    #pragma GCC optimize(3)
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<vector>
    #include<queue>
    using namespace std;
    int n,m,L,R,val[200100],ROOT,SZ,sz[200100],msz[200100],dep[200100],sum[200100],mdp[200100],cdp[200100],mx=0x80808080,Glo[200100],Loc[200100],glo,loc;
    vector<pair<int,int> >v[200100];
    bool vis[200100];
    void getsz(int x,int fa){
    	sz[x]=1;
    	for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)getsz(i.second,x),sz[x]+=sz[i.second];
    }
    void getroot(int x,int fa){
    	sz[x]=1,msz[x]=0;
    	for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)getroot(i.second,x),sz[x]+=sz[i.second],msz[x]=max(msz[x],sz[i.second]);
    	msz[x]=max(msz[x],SZ-sz[x]);
    	if(msz[x]<msz[ROOT])ROOT=x;
    }
    void getdep(int x,int fa,int las){
    	mdp[x]=dep[x]=dep[fa]+1;
    //	printf("%d:%d %d\n",x,dep[x],sum[x]);
    	for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)sum[i.second]=sum[x]+(i.first==las?0:val[i.first]),getdep(i.second,x,i.first),mdp[x]=max(mdp[x],mdp[i.second]);
    }
    bool cmp(pair<int,int>x,pair<int,int>y){
    	if(x.first==y.first)return mdp[x.second]<mdp[y.second];
    	return cdp[x.first]==cdp[y.first]?x.first<y.first:cdp[x.first]<cdp[y.first];
    }
    deque<int>dq;
    queue<int>q;
    void bfswrite(int *arr,int &lim){
    	while(!q.empty()){
    		int x=q.front();q.pop();
    		if(dep[x]>R)break;
    		arr[dep[x]]=max(arr[dep[x]],sum[x]),lim=max(lim,dep[x]);
    		for(auto i:v[x])if(dep[i.second]>dep[x]&&!vis[i.second])q.push(i.second);
    	}
    }
    void bfsread(int *arr,int lim,int delta){
    	while(!q.empty()){
    		int x=q.front();q.pop();
    		if(dep[x]>R)break;
    		while(lim>=0&&lim+dep[x]>=L){
    			while(!dq.empty()&&arr[dq.back()]<=arr[lim])dq.pop_back();
    			dq.push_back(lim--);
    		}
    		while(!dq.empty()&&dq.front()+dep[x]>R)dq.pop_front();
    		if(!dq.empty())mx=max(1ll*mx,0ll+arr[dq.front()]+sum[x]-delta);
    		for(auto i:v[x])if(dep[i.second]>dep[x]&&!vis[i.second])q.push(i.second);
    	}
    	dq.clear();
    }
    void calc(int x){
    //	printf("ROOT:%d:\n",x);
    	dep[0]=-1,sum[x]=0;
    	getdep(x,0,0);
    	for(auto i:v[x])if(!vis[i.second])cdp[i.first]=max(cdp[i.first],mdp[i.second]);
    	sort(v[x].begin(),v[x].end(),cmp);
    	Glo[0]=0;
    	for(int l=0,r=0;r<v[x].size();l=r){
    		while(r<v[x].size()&&v[x][r].first==v[x][l].first){
    			if(!vis[v[x][r].second])q.push(v[x][r].second),bfsread(Loc,loc,val[v[x][r].first]),q.push(v[x][r].second),bfswrite(Loc,loc);
    			r++;
    		}
    		for(int k=l;k<r;k++)if(!vis[v[x][k].second])q.push(v[x][k].second);
    		bfsread(Glo,glo,0);
    		for(int k=0;k<=loc;k++)Glo[k]=max(Glo[k],Loc[k]),Loc[k]=0x80808080;
    		glo=max(glo,loc),loc=0;
    	}
    	for(int k=0;k<=glo;k++)Glo[k]=0x80808080;
    	glo=0;
    	for(auto i:v[x])if(!vis[i.second])cdp[i.first]=0;
    }
    void solve(int x){
    	calc(x);
    	getsz(x,0); 
    	vis[x]=true;
    	for(auto i:v[x])if(!vis[i.second])ROOT=0,SZ=sz[i.second],getroot(i.second,0),solve(ROOT);
    }
    void read(int &x){
    	x=0;
    	char c=getchar();
    	int fl=1;
    	while(c>'9'||c<'0')fl=(c=='-'?-fl:fl),c=getchar();
    	while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
    	x*=fl;
    }
    int main(){
    	read(n),read(m),read(L),read(R),memset(Glo,0x80,sizeof(Glo)),memset(Loc,0x80,sizeof(Loc));
    	for(int i=1;i<=m;i++)read(val[i]);
    	for(int i=1,x,y,z;i<n;i++)read(x),read(y),read(z),v[x].push_back(make_pair(z,y)),v[y].push_back(make_pair(z,x));
    	msz[0]=n+1,SZ=n,getroot(1,0),solve(ROOT);
    	printf("%d\n",mx);
    	return 0;
    }
    

    dfs:

    #pragma GCC optimize(3)
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<vector>
    #include<queue>
    using namespace std;
    int n,m,L,R,head[200100],cnt,val[200100],mx=0x80808080;
    int ROOT,SZ,sz[200100],msz[200100];
    int cdp[200100],ddp[200100],buc[200100];
    struct Edge{
    	int to,next,val;
    }edge[400100];
    void ae(int u,int v,int w){
    	edge[cnt].next=head[u],edge[cnt].to=v,edge[cnt].val=w,head[u]=cnt++;
    	edge[cnt].next=head[v],edge[cnt].to=u,edge[cnt].val=w,head[v]=cnt++;
    }
    struct node{
    	int dep,sum,col,frm;
    	node(int A,int B,int C,int D){dep=A,sum=B,col=C,frm=D;}
    };
    bool cmp1(const node &x,const node &y){
    	if(x.col!=y.col){
    		if(cdp[x.col]!=cdp[y.col])return cdp[x.col]<cdp[y.col];
    		return x.col<y.col;
    	}
    	if(x.frm!=y.frm){
    		if(ddp[x.frm]!=ddp[y.frm])return ddp[x.frm]<ddp[y.frm];
    		return x.frm<y.frm;
    	}
    	return x.dep<y.dep;
    }
    bool cmp2(const node &x,const node &y){
    	return x.dep<y.dep;
    }
    vector<node>arr;
    bool vis[200100];
    void getsz(int x,int fa){
    	sz[x]=1;
    	for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to]&&edge[i].to!=fa)getsz(edge[i].to,x),sz[x]+=sz[edge[i].to];
    }
    void getroot(int x,int fa){
    	sz[x]=1,msz[x]=0;
    	for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to]&&edge[i].to!=fa)getroot(edge[i].to,x),sz[x]+=sz[edge[i].to],msz[x]=max(msz[x],sz[edge[i].to]);
    	msz[x]=max(msz[x],SZ-sz[x]);
    	if(msz[x]<msz[ROOT])ROOT=x;
    }
    void getdep(int x,int fa,int dep,int las,int sum,int col,int frm){
    	if(dep>R)return;
    	cdp[col]=max(cdp[col],dep);
    	ddp[frm]=max(ddp[frm],dep);
    	arr.push_back(node(dep,sum,col,frm));
    	for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to]&&edge[i].to!=fa)getdep(edge[i].to,x,dep+1,edge[i].val,sum+(las==edge[i].val?0:val[edge[i].val]),col,frm);
    }
    deque<int>dq;
    void calc(int x){
    	arr.clear();
    	for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to])getdep(edge[i].to,x,1,edge[i].val,val[edge[i].val],edge[i].val,edge[i].to);
    	sort(arr.begin(),arr.end(),cmp1);
    	for(int l=0,r=0,tmp=0,lim=0;r<arr.size();l=r){
    		dq.clear(),tmp=lim;
    		while(r<arr.size()&&arr[r].frm==arr[l].frm){
    			while(tmp>=0&&tmp+arr[r].dep>=L){
    				while(!dq.empty()&&buc[dq.back()]<=buc[tmp])dq.pop_back();
    				dq.push_back(tmp--);
    			}
    			while(!dq.empty()&&dq.front()+arr[r].dep>R)dq.pop_front();
    			if(!dq.empty())mx=max(1ll*mx,0ll+buc[dq.front()]+arr[r].sum-val[arr[r].col]);
    			r++;
    		}
    		if(r==arr.size()||arr[r].col!=arr[l].col){
    			for(int k=0;k<=lim;k++)buc[k]=0x80808080;
    			lim=0;
    		}else for(int k=l;k<r;k++)buc[arr[k].dep]=max(buc[arr[k].dep],arr[k].sum),lim=max(lim,arr[k].dep);
    	}
    	buc[0]=0;
    	for(int l=0,r=0,tmp=0,lim=0;r<arr.size();l=r){
    		dq.clear();tmp=lim;
    		while(r<arr.size()&&arr[r].col==arr[l].col)r++;
    		sort(arr.begin()+l,arr.begin()+r,cmp2);
    		for(int k=l;k<r;k++){
    			while(tmp>=0&&tmp+arr[k].dep>=L){
    				while(!dq.empty()&&buc[dq.back()]<=buc[tmp])dq.pop_back();
    				dq.push_back(tmp);
    				tmp--;
    			}
    			while(!dq.empty()&&dq.front()+arr[k].dep>R)dq.pop_front();
    			if(!dq.empty())mx=max(1ll*mx,0ll+buc[dq.front()]+arr[k].sum);
    		}
    		if(r==arr.size()){
    			for(int k=0;k<=lim;k++)buc[k]=0x80808080;
    			lim=0;
    		}else for(int k=l;k<r;k++)buc[arr[k].dep]=max(buc[arr[k].dep],arr[k].sum),lim=max(lim,arr[k].dep);
    	}
    	buc[0]=0x80808080;
    	for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to])cdp[edge[i].val]=ddp[edge[i].to]=0;
    }
    void solve(int x){
    	calc(x);
    	getsz(x,0); 
    	vis[x]=true;
    	for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to])ROOT=0,SZ=sz[edge[i].to],getroot(edge[i].to,0),solve(ROOT);
    }
    void read(int &x){
    	x=0;
    	char c=getchar();
    	int fl=1;
    	while(c>'9'||c<'0')fl=(c=='-'?-fl:fl),c=getchar();
    	while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
    	x*=fl;
    }
    int main(){
    	read(n),read(m),read(L),read(R),memset(head,-1,sizeof(head)),memset(buc,0x80,sizeof(buc));
    	for(int i=1;i<=m;i++)read(val[i]);
    	for(int i=1,x,y,z;i<n;i++)read(x),read(y),read(z),ae(x,y,z);
    	msz[0]=n+1,SZ=n,getroot(1,0),solve(ROOT);
    	printf("%d\n",mx);
    	return 0;
    }
    

  • 相关阅读:
    Javascript&Html-系统对话框
    Javascript&Html-延迟调用和间歇调用
    Javascript&Html-弹出窗口的屏蔽程序
    iPhone屏幕旋转
    iPhone深度学习-ARM
    xCode控制台的使用-应用闪退原因跟踪
    IOS-内存检测以及优化
    Javascript-Array
    http与https的区别
    Nginx:处理HTTP请求
  • 原文地址:https://www.cnblogs.com/Troverld/p/14605822.html
Copyright © 2020-2023  润新知