题意简述
(n) 个点的无根树。
树上每条边有颜色。共 (m) 种颜色,第 (i) 种颜色有权值 (c_i)。
对于树上一条简单路径,路径上经过的所有边按顺序组成一个颜色序列,序列可以划分成若干个相同颜色段。定义路径权值为颜色序列上每个同颜色段的颜色权值之和。
求长度在 ([l,r]) 的路径权值最大值。
想法
容易想到点分治。然后大体框架就有了。
问题是如何统计长度在 ([l,r]) 的路径的权值。
容易想到在访问子节点时,同一个颜色的挨着访问,将访问完的路径分为“不同颜色”与“相同颜色”分别按长度排序,然后单调队列。
然后TLE了几个点……
问题在于,如果我们访问的第一个颜色的最长路径很长的话,访问后面点时每个点都进行一次单调队列,相当于这个“长路径”访问了很多很多遍。
于是改变一下访问顺序。
仍是同一个颜色的挨着访问,维护两个单调队列。
将颜色按照此颜色的最大深度从小到大排序,将同一颜色中的点按该点可达的最大深度从小到大排序。
每个点访问完后,与该颜色的单调队列跑一次;一个颜色中的点都访问后,将此颜色的单调队列合并至总的单调队列中。
由于有排序,总复杂度 (O(nlog^2n))
还有一种线段树的【暴力】做法:
仍是同一个颜色的挨着访问。维护两个线段树,一个表示相同颜色,一个表示不同颜色。
访问完一个点后,将访问到的各长度的路径在线段树中找区间最大值更新答案。
访问完一个颜色后,两个线段树合并。合并复杂度 (O(nlogn))
总复杂度 (O(nlog^2n)),但是常数较大。
总结
我就是觉得复杂度这个东西太玄妙了!!爱了爱了!
顺序很重要!!
代码
写+调的我要裂开了。。。
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#define INF 1000000000
using namespace std;
int read(){
int x=0,f=1;
char ch=getchar();
while(!isdigit(ch) && ch!='-') ch=getchar();
if(ch=='-') f=-1,ch=getchar();
while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
return x*f;
}
const int N = 200005;
typedef long long ll;
typedef pair<int,int> Pr;
int n,m,L,R,cv[N];
struct edge{
int u,v,c;
bool operator < (const edge &b) const{ return c<b.c; }
}ed[N];
struct node{
int v,c;
node *nxt;
}pool[N*2],*h[N];
int cnt;
void addedge(int u,int v,int c){
node *p=&pool[++cnt],*q=&pool[++cnt];
p->v=v;p->nxt=h[u];h[u]=p;p->c=c;
q->v=u;q->nxt=h[v];h[v]=q;q->c=c;
}
int root,all,sz[N],mx[N],vis[N];
void getrt(int u,int fa){
int v;
sz[u]=1; mx[u]=0;
for(node *p=h[u];p;p=p->nxt)
if((v=p->v)!=fa && !vis[v]){
getrt(v,u);
sz[u]+=sz[v];
mx[u]=max(mx[u],sz[v]);
}
mx[u]=max(mx[u],all-sz[u]);
if(mx[u]<mx[root]) root=u;
}
int dis[N];
void dfs_sz(int u,int fa){
int v;
sz[u]=1;
dis[u]=0;
for(node *p=h[u];p;p=p->nxt)
if((v=p->v)!=fa && !vis[v]){
dfs_sz(v,u);
dis[u]=max(dis[u],dis[v]);
sz[u]+=sz[v];
}
dis[u]++;
}
ll b[N],pre[N],cur[N];
void cal(int u,int fa,int c,ll sum,int len){
cur[len]=max(sum+cv[c],cur[len]);
if(len>=R) return;
int v;
for(node *p=h[u];p;p=p->nxt)
if((v=p->v)!=fa && !vis[v]){
if(p->c==c) cal(v,u,c,sum,len+1);
else cal(v,u,p->c,sum+cv[c],len+1);
}
}
ll ans;
int que[N],hd,tl,son[N],col[N],mxd[N];
bool cmp(int x,int y){
if(col[x]==col[y]) return dis[x]<dis[y]; /**/
if(mxd[col[x]]==mxd[col[y]]) return col[x]<col[y];
return mxd[col[x]]<mxd[col[y]];
}
ll mxb[N];
void work(int u){
int v,lastc=0,dep,tb=0,tp=0;
vis[u]=1;
int sn=0;
for(node *p=h[u];p;p=p->nxt)
if(!vis[v=p->v]){
son[sn++]=v;
col[v]=p->c;
dfs_sz(v,u);
dis[v]=min(dis[v],R);
mxd[col[v]]=max(mxd[col[v]],dis[v]);
}
sort(son,son+sn,cmp);
for(int i=0;i<sn;i++) mxd[col[son[i]]]=0;
for(int k=0;k<sn;k++){
v=son[k];
cal(v,u,col[v],0,1);
for(int i=L;i<=dis[v];i++) if(i<=R) ans=max(ans,cur[i]);
if(col[v]!=lastc){
hd=tl=0;
for(int i=tp,j=1;i>0;i--){
while(j<=tb && j+i<=R){
while(hd<tl && b[que[tl-1]]<=b[j]) tl--;
que[tl++]=j++;
}
while(hd<tl && que[hd]+i<L) hd++;
if(hd<tl) ans=max(ans,b[que[hd]]+pre[i]);
}
tb=tp;
for(int i=1;i<=tp;i++) b[i]=max(b[i],pre[i]),pre[i]=-INF;/**/
tp=dis[v];
for(int i=1;i<=dis[v];i++) pre[i]=cur[i];
}
else{
hd=tl=0;
for(int i=dis[v],j=1;i>0;i--){
while(j<=tp && j+i<=R){
while(hd<tl && pre[que[tl-1]]<=pre[j]) tl--;
que[tl++]=j++;
}
while(hd<tl && que[hd]+i<L) hd++;
if(hd<tl) ans=max(ans,pre[que[hd]]+cur[i]-cv[col[v]]);
}
for(int i=1;i<=dis[v];i++) pre[i]=max(pre[i],cur[i]);
tp=dis[v];
}
lastc=col[v];
for(int i=1;i<=dis[v];i++) cur[i]=-INF;
}
hd=tl=0;
for(int i=tp,j=1;i>0;i--){
while(j<=tb && j+i<=R){
while(hd<tl && b[que[tl-1]]<=b[j]) tl--;
que[tl++]=j++;
}
while(hd<tl && que[hd]+i<L) hd++;
if(hd<tl) ans=max(ans,b[que[hd]]+pre[i]);
}
//clear
for(int i=0;i<=tb;i++) b[i]=-INF;
for(int i=0;i<=tp;i++) pre[i]=-INF;
for(node *p=h[u];p;p=p->nxt)
if(!vis[v=p->v]){
root=0; all=sz[v]; getrt(v,u);
work(root);
}
}
int main()
{
n=read(); m=read(); L=read(); R=read();
for(int i=1;i<=m;i++) cv[i]=read();
for(int i=1;i<n;i++){ ed[i].u=read(); ed[i].v=read(); ed[i].c=read(); }
sort(ed+1,ed+n);
for(int i=1;i<n;i++) addedge(ed[i].u,ed[i].v,ed[i].c);
sz[0]=mx[0]=n+1;
root=0; all=n; getrt(1,0);
for(int i=0;i<=n;i++) b[i]=-INF,pre[i]=-INF,cur[i]=-INF;
ans=-INF;work(root);
printf("%lld
",ans);
return 0;
}