IX.[BJOI2017]树的难题
debug三天,精神崩溃
论一行if(vis[v[x][r].second]){r++;continue;}
忘记加上后所有代码全都莫名其妙TLE且查不出锅的痛苦
首先,我们考虑常规淀粉质。
我们考虑一条路径,它会被(淀粉质的分治根)截成两段。如果我们对于分治树中的每一个节点,预处理出来它到树根的路径权值,记为\(sum_x\),则一条完整路径的权值则为\(sum_x+sum_y\)。
稍等,我们好像忘记了一种情况——如果这两条路径顶端的边的颜色相同怎么办?
换句话说,假如某一半路径的颜色段为(顺序为从根节点往下)ABABC
,另一半为ABAC
,两半拼一起,我们得到CBABAABAC
。显然,这个A
就被算了两次,应该被减掉。
因此,这种情况的权值则为\(sum_x+sum_y-val_c\),其中\(val_c\)为\(c\)颜色的权值,而\(c\)为两条路径顶端的颜色。
很明显这两者要分开考虑。
然后就是求值了。显然,对于半条路径,与它可以拼成完整路径的另一半的长度,是一个区间。所以,我们可以建一棵线段树,以深度为下标,存储所有当前深度的点中,\(sum_x\)的最大值。
我们需要两棵线段树——一棵用于同种颜色的储存,而另一棵用于全局颜色的储存。在一种颜色全部处理完之后,将当前颜色线段树与全局线段树合并,并清空当前线段树即可。
明显复杂度为\(O(n\log^2n)\)。由于人傻常数大,它TLE50,即使开O3也一样。
TLE的线段树代码:
#pragma GCC optimize(3)
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long ll;
const ll fni=-1e18;
#define lson x<<1
#define rson x<<1|1
#define mid ((l+r)>>1)
struct SegTree{
ll seg[800100];
void init(){
for(int i=1;i<=800000;i++)seg[i]=fni;
}
void pushup(int x){
seg[x]=max(seg[lson],seg[rson]);
}
void modify(int x,int l,int r,int P,ll vl){
if(l>P||r<P)return;
if(l==r){seg[x]=max(seg[x],vl);return;}
modify(lson,l,mid,P,vl),modify(rson,mid+1,r,P,vl),pushup(x);
}
void setzero(int x,int l,int r){
if(seg[x]==fni)return;
seg[x]=fni;
if(l!=r)setzero(lson,l,mid),setzero(rson,mid+1,r);
}
ll query(int x,int l,int r,int L,int R){
if(l>R||r<L)return fni;
if(L<=l&&r<=R)return seg[x];
return max(query(lson,l,mid,L,R),query(rson,mid+1,r,L,R));
}
}all,same;
void merge(int x,int l,int r){
if(same.seg[x]==fni)return;
all.seg[x]=max(all.seg[x],same.seg[x]);
same.seg[x]=fni;
if(l!=r)merge(lson,l,mid),merge(rson,mid+1,r);
}
int n,m,L,R,val[200100],ROOT,SZ,sz[200100],msz[200100];
ll mx=fni;
vector<pair<int,int> >v[200100];
bool vis[200100];
void getsz(int x,int fa){
sz[x]=1;
for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)getsz(i.second,x),sz[x]+=sz[i.second];
}
void getroot(int x,int fa){
sz[x]=1,msz[x]=0;
for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)getroot(i.second,x),sz[x]+=sz[i.second],msz[x]=max(msz[x],sz[i.second]);
msz[x]=max(msz[x],SZ-sz[x]);
if(msz[x]<msz[ROOT])ROOT=x;
}
void write(int x,int fa,SegTree &sg,int las,int dep,ll sum){
if(dep>R)return;
sg.modify(1,0,R,dep,sum);
for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)write(i.second,x,sg,i.first,dep+1,sum+(i.first==las?0:val[i.first]));
}
void read(int x,int fa,SegTree &sg,int las,int dep,ll sum){
if(dep>R)return;
mx=max(mx,sum+sg.query(1,0,R,L-dep,R-dep));
for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)read(i.second,x,sg,i.first,dep+1,sum+(i.first==las?0:val[i.first]));
}
void calc(int x){
all.modify(1,0,R,0,0);
for(int l=0,r=0;r<v[x].size();l=r){
while(r<v[x].size()&&v[x][r].first==v[x][l].first){
if(vis[v[x][r].second]){r++;continue;}
int i=v[x][r].second,j=v[x][r].first;
read(i,x,same,j,1,0);
write(i,x,same,j,1,val[j]);
read(i,x,all,j,1,val[j]);
r++;
}
merge(1,0,R);
}
all.setzero(1,0,R);
}
void solve(int x){
calc(x);
getsz(x,0);
vis[x]=true;
for(auto i:v[x])if(!vis[i.second])ROOT=0,SZ=sz[i.second],getroot(i.second,0),solve(ROOT);
}
void read(int &x){
x=0;
char c=getchar();
int fl=1;
while(c>'9'||c<'0')fl=(c=='-'?-fl:fl),c=getchar();
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
x*=fl;
}
int main(){
read(n),read(m),read(L),read(R);
for(int i=1;i<=m;i++)read(val[i]);
for(int i=1,x,y,z;i<n;i++)read(x),read(y),read(z),v[x].push_back(make_pair(z,y)),v[y].push_back(make_pair(z,x));
for(int i=1;i<=n;i++)sort(v[i].begin(),v[i].end());
all.init(),same.init();
msz[0]=n+1,SZ=n,getroot(1,0),solve(ROOT);
printf("%lld\n",mx);
return 0;
}
然后就是正解了——一个名叫单调队列按秩合并的trick。
显然,如果我们把所有半路径按照深度排序,它们合法的转移区间是单调递减的。
比如说,如果设深度为\(dep_x\)的话,则合法深度区间则为\([L-dep_x,R-dep_x]\)。当\(dep_x\nearrow\)时,整个区间\(\searrow\)。
这不是经典老题滑动窗口吗?使用单调队列维护即可。
我们考虑颜色相同的情况。我们可以用桶来维护相同深度时的\(sum\)的最大值。对于每一棵子树,我们按照节点深度处理,在桶上跑滑动窗口。在整棵子树跑完后,用它们的值更新桶即可。
然后颜色不同的情况类似,只不过是对于每一种颜色一起处理,不需要关心具体从哪棵子树过来罢了。
稍等,这个算法是假的。很明显,这个算法的复杂度为桶的大小。该大小最大可以到直径,即\(n\)级别。如果开门见喜,一上来就遇到了直径,则之后每一次滑动窗口都要完整跑一遍直径。如果儿子数量很多的话,复杂度是会退化成\(O(n^2)\)的。
那怎么办呢?
对于颜色相同时,我们按照子树内最大深度递增的方式处理子树,先处理深度浅的子树,再处理深度深的。这就保证了单调队列的复杂度是严格\(O(\sum dep)\),即\(O(n)\)的。
在颜色不同时,我们仍然这样做,先处理深度浅的颜色,再处理深度深的。
这种trick,就是单调队列按秩合并——按照长度递增的顺序处理多条单调队列。
最后还有一件事——排序。显然,如果直接排序,总复杂度是\(O(n\log^2n)\)的。当然,常数比之前小很多(因为每一层的排序复杂度都跑不满),因此可以过去。当然,因为深度范围小,使用桶排即可。
或者放弃dfs,使用bfs,毕竟bfs本身就保证了按照深度排序。
这里两种代码都能AC,\(n\log n\)的bfs或者\(n\log^2n\)的dfs。
bfs:
#pragma GCC optimize(3)
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<queue>
using namespace std;
int n,m,L,R,val[200100],ROOT,SZ,sz[200100],msz[200100],dep[200100],sum[200100],mdp[200100],cdp[200100],mx=0x80808080,Glo[200100],Loc[200100],glo,loc;
vector<pair<int,int> >v[200100];
bool vis[200100];
void getsz(int x,int fa){
sz[x]=1;
for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)getsz(i.second,x),sz[x]+=sz[i.second];
}
void getroot(int x,int fa){
sz[x]=1,msz[x]=0;
for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)getroot(i.second,x),sz[x]+=sz[i.second],msz[x]=max(msz[x],sz[i.second]);
msz[x]=max(msz[x],SZ-sz[x]);
if(msz[x]<msz[ROOT])ROOT=x;
}
void getdep(int x,int fa,int las){
mdp[x]=dep[x]=dep[fa]+1;
// printf("%d:%d %d\n",x,dep[x],sum[x]);
for(auto i:v[x])if(!vis[i.second]&&i.second!=fa)sum[i.second]=sum[x]+(i.first==las?0:val[i.first]),getdep(i.second,x,i.first),mdp[x]=max(mdp[x],mdp[i.second]);
}
bool cmp(pair<int,int>x,pair<int,int>y){
if(x.first==y.first)return mdp[x.second]<mdp[y.second];
return cdp[x.first]==cdp[y.first]?x.first<y.first:cdp[x.first]<cdp[y.first];
}
deque<int>dq;
queue<int>q;
void bfswrite(int *arr,int &lim){
while(!q.empty()){
int x=q.front();q.pop();
if(dep[x]>R)break;
arr[dep[x]]=max(arr[dep[x]],sum[x]),lim=max(lim,dep[x]);
for(auto i:v[x])if(dep[i.second]>dep[x]&&!vis[i.second])q.push(i.second);
}
}
void bfsread(int *arr,int lim,int delta){
while(!q.empty()){
int x=q.front();q.pop();
if(dep[x]>R)break;
while(lim>=0&&lim+dep[x]>=L){
while(!dq.empty()&&arr[dq.back()]<=arr[lim])dq.pop_back();
dq.push_back(lim--);
}
while(!dq.empty()&&dq.front()+dep[x]>R)dq.pop_front();
if(!dq.empty())mx=max(1ll*mx,0ll+arr[dq.front()]+sum[x]-delta);
for(auto i:v[x])if(dep[i.second]>dep[x]&&!vis[i.second])q.push(i.second);
}
dq.clear();
}
void calc(int x){
// printf("ROOT:%d:\n",x);
dep[0]=-1,sum[x]=0;
getdep(x,0,0);
for(auto i:v[x])if(!vis[i.second])cdp[i.first]=max(cdp[i.first],mdp[i.second]);
sort(v[x].begin(),v[x].end(),cmp);
Glo[0]=0;
for(int l=0,r=0;r<v[x].size();l=r){
while(r<v[x].size()&&v[x][r].first==v[x][l].first){
if(!vis[v[x][r].second])q.push(v[x][r].second),bfsread(Loc,loc,val[v[x][r].first]),q.push(v[x][r].second),bfswrite(Loc,loc);
r++;
}
for(int k=l;k<r;k++)if(!vis[v[x][k].second])q.push(v[x][k].second);
bfsread(Glo,glo,0);
for(int k=0;k<=loc;k++)Glo[k]=max(Glo[k],Loc[k]),Loc[k]=0x80808080;
glo=max(glo,loc),loc=0;
}
for(int k=0;k<=glo;k++)Glo[k]=0x80808080;
glo=0;
for(auto i:v[x])if(!vis[i.second])cdp[i.first]=0;
}
void solve(int x){
calc(x);
getsz(x,0);
vis[x]=true;
for(auto i:v[x])if(!vis[i.second])ROOT=0,SZ=sz[i.second],getroot(i.second,0),solve(ROOT);
}
void read(int &x){
x=0;
char c=getchar();
int fl=1;
while(c>'9'||c<'0')fl=(c=='-'?-fl:fl),c=getchar();
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
x*=fl;
}
int main(){
read(n),read(m),read(L),read(R),memset(Glo,0x80,sizeof(Glo)),memset(Loc,0x80,sizeof(Loc));
for(int i=1;i<=m;i++)read(val[i]);
for(int i=1,x,y,z;i<n;i++)read(x),read(y),read(z),v[x].push_back(make_pair(z,y)),v[y].push_back(make_pair(z,x));
msz[0]=n+1,SZ=n,getroot(1,0),solve(ROOT);
printf("%d\n",mx);
return 0;
}
dfs:
#pragma GCC optimize(3)
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<queue>
using namespace std;
int n,m,L,R,head[200100],cnt,val[200100],mx=0x80808080;
int ROOT,SZ,sz[200100],msz[200100];
int cdp[200100],ddp[200100],buc[200100];
struct Edge{
int to,next,val;
}edge[400100];
void ae(int u,int v,int w){
edge[cnt].next=head[u],edge[cnt].to=v,edge[cnt].val=w,head[u]=cnt++;
edge[cnt].next=head[v],edge[cnt].to=u,edge[cnt].val=w,head[v]=cnt++;
}
struct node{
int dep,sum,col,frm;
node(int A,int B,int C,int D){dep=A,sum=B,col=C,frm=D;}
};
bool cmp1(const node &x,const node &y){
if(x.col!=y.col){
if(cdp[x.col]!=cdp[y.col])return cdp[x.col]<cdp[y.col];
return x.col<y.col;
}
if(x.frm!=y.frm){
if(ddp[x.frm]!=ddp[y.frm])return ddp[x.frm]<ddp[y.frm];
return x.frm<y.frm;
}
return x.dep<y.dep;
}
bool cmp2(const node &x,const node &y){
return x.dep<y.dep;
}
vector<node>arr;
bool vis[200100];
void getsz(int x,int fa){
sz[x]=1;
for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to]&&edge[i].to!=fa)getsz(edge[i].to,x),sz[x]+=sz[edge[i].to];
}
void getroot(int x,int fa){
sz[x]=1,msz[x]=0;
for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to]&&edge[i].to!=fa)getroot(edge[i].to,x),sz[x]+=sz[edge[i].to],msz[x]=max(msz[x],sz[edge[i].to]);
msz[x]=max(msz[x],SZ-sz[x]);
if(msz[x]<msz[ROOT])ROOT=x;
}
void getdep(int x,int fa,int dep,int las,int sum,int col,int frm){
if(dep>R)return;
cdp[col]=max(cdp[col],dep);
ddp[frm]=max(ddp[frm],dep);
arr.push_back(node(dep,sum,col,frm));
for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to]&&edge[i].to!=fa)getdep(edge[i].to,x,dep+1,edge[i].val,sum+(las==edge[i].val?0:val[edge[i].val]),col,frm);
}
deque<int>dq;
void calc(int x){
arr.clear();
for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to])getdep(edge[i].to,x,1,edge[i].val,val[edge[i].val],edge[i].val,edge[i].to);
sort(arr.begin(),arr.end(),cmp1);
for(int l=0,r=0,tmp=0,lim=0;r<arr.size();l=r){
dq.clear(),tmp=lim;
while(r<arr.size()&&arr[r].frm==arr[l].frm){
while(tmp>=0&&tmp+arr[r].dep>=L){
while(!dq.empty()&&buc[dq.back()]<=buc[tmp])dq.pop_back();
dq.push_back(tmp--);
}
while(!dq.empty()&&dq.front()+arr[r].dep>R)dq.pop_front();
if(!dq.empty())mx=max(1ll*mx,0ll+buc[dq.front()]+arr[r].sum-val[arr[r].col]);
r++;
}
if(r==arr.size()||arr[r].col!=arr[l].col){
for(int k=0;k<=lim;k++)buc[k]=0x80808080;
lim=0;
}else for(int k=l;k<r;k++)buc[arr[k].dep]=max(buc[arr[k].dep],arr[k].sum),lim=max(lim,arr[k].dep);
}
buc[0]=0;
for(int l=0,r=0,tmp=0,lim=0;r<arr.size();l=r){
dq.clear();tmp=lim;
while(r<arr.size()&&arr[r].col==arr[l].col)r++;
sort(arr.begin()+l,arr.begin()+r,cmp2);
for(int k=l;k<r;k++){
while(tmp>=0&&tmp+arr[k].dep>=L){
while(!dq.empty()&&buc[dq.back()]<=buc[tmp])dq.pop_back();
dq.push_back(tmp);
tmp--;
}
while(!dq.empty()&&dq.front()+arr[k].dep>R)dq.pop_front();
if(!dq.empty())mx=max(1ll*mx,0ll+buc[dq.front()]+arr[k].sum);
}
if(r==arr.size()){
for(int k=0;k<=lim;k++)buc[k]=0x80808080;
lim=0;
}else for(int k=l;k<r;k++)buc[arr[k].dep]=max(buc[arr[k].dep],arr[k].sum),lim=max(lim,arr[k].dep);
}
buc[0]=0x80808080;
for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to])cdp[edge[i].val]=ddp[edge[i].to]=0;
}
void solve(int x){
calc(x);
getsz(x,0);
vis[x]=true;
for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to])ROOT=0,SZ=sz[edge[i].to],getroot(edge[i].to,0),solve(ROOT);
}
void read(int &x){
x=0;
char c=getchar();
int fl=1;
while(c>'9'||c<'0')fl=(c=='-'?-fl:fl),c=getchar();
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
x*=fl;
}
int main(){
read(n),read(m),read(L),read(R),memset(head,-1,sizeof(head)),memset(buc,0x80,sizeof(buc));
for(int i=1;i<=m;i++)read(val[i]);
for(int i=1,x,y,z;i<n;i++)read(x),read(y),read(z),ae(x,y,z);
msz[0]=n+1,SZ=n,getroot(1,0),solve(ROOT);
printf("%d\n",mx);
return 0;
}