XV.[Codeforces GYM 101002K] YATP
(没有单独的页面,就放个到大页面的连接罢)
我们考虑先套一个点分治。点分治后,考虑计算所有LCA为根节点的对中,最优的那些对。
我们考虑就算某两个点它们位于同一棵子树内也不要紧——这里它的权值被表示成 \(dep_i+dep_j+a_ia_j\),但实际上再往子树内分治时该权值会更小,故实际上在此处计算也不会对最终答案有影响。故我们现在就要对于某个\(i\),计算对于所有\(j\)(不管它是否与\(i\)在同一子树内)中,\(dep_i+dep_j+a_ia_j\)的\(\min\)。
对于同一个\(i\),显然\(dep_i\)是可以最后再考虑的。故我们只需要考虑\(a_ia_j+dep_j\)即可。它实际上是一个一次函数\(a_jx+dep_j\),而所有的\(j\)画到平面直角坐标系中就是一条条的直线。对于不同的\(x\)(即不同的\(a_i\)),它可能需要不同的直线,故我们只需要求出上述直线的下凸壳即可。
\(O(n)\)扫一遍所有节点后,\(O(n\log n)\)排序并求凸壳。然后再扫一遍所有节点,在凸壳上二分出当前\(a_i\)对应哪条直线最优即可。
时间复杂度\(O(n\log^2n)\)。
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,a[200100],head[200100],cnt,sz[200100],msz[200100],rt,SZ,stk[200100],tp;
ll f[200100],res,dep[200100];
struct node{
int to,next,val;
}edge[400100];
void ae(int u,int v,int w){
edge[cnt].next=head[u],edge[cnt].val=w,edge[cnt].to=v,head[u]=cnt++;
edge[cnt].next=head[v],edge[cnt].val=w,edge[cnt].to=u,head[v]=cnt++;
}
bool vis[200100];
void getroot(int x,int fa){
sz[x]=1,msz[x]=0;
for(int i=head[x];i!=-1;i=edge[i].next)if(edge[i].to!=fa&&!vis[edge[i].to])getroot(edge[i].to,x),sz[x]+=sz[edge[i].to],msz[x]=max(msz[x],sz[edge[i].to]);
msz[x]=max(msz[x],SZ-sz[x]);
if(msz[x]<msz[rt])rt=x;
}
void getsz(int x,int fa){
sz[x]=1;
for(int i=head[x];i!=-1;i=edge[i].next)if(edge[i].to!=fa&&!vis[edge[i].to])getsz(edge[i].to,x),sz[x]+=sz[edge[i].to];
}
vector<int>v;
void getdep(int x,int fa){
v.push_back(x);
for(int i=head[x];i!=-1;i=edge[i].next)if(edge[i].to!=fa&&!vis[edge[i].to])dep[edge[i].to]=dep[x]+edge[i].val,getdep(edge[i].to,x);
}
int calc(int x){
int l=1,r=tp;
while(l<r){
int mid=(l+r)>>1;
if(1ll*a[stk[mid]]*a[x]+dep[stk[mid]]<=1ll*a[stk[mid+1]]*a[x]+dep[stk[mid+1]])r=mid;
else l=mid+1;
}
f[x]=min(f[x],1ll*a[x]*a[stk[r]]+dep[stk[r]]+dep[x]);
}
void getroute(int x,int fa){
calc(x);
for(int i=head[x];i!=-1;i=edge[i].next)if(edge[i].to!=fa&&!vis[edge[i].to])getroute(edge[i].to,x);
}
void getans(int x){
dep[x]=0;
getdep(x,0);
sort(v.begin(),v.end(),[](int u,int v){return a[u]==a[v]?dep[u]>dep[v]:a[u]>a[v];});
tp=0;
for(auto i:v){
while(tp>=2&&1ll*(dep[stk[tp-1]]-dep[stk[tp]])*(a[i]-a[stk[tp]])>=1ll*(dep[i]-dep[stk[tp]])*(a[stk[tp-1]]-a[stk[tp]]))tp--;
stk[++tp]=i;
}
getroute(x,0);
v.clear();
}
void solve(int x){
getans(x),getsz(x,0),vis[x]=true;
for(int i=head[x];i!=-1;i=edge[i].next)if(!vis[edge[i].to])rt=0,SZ=sz[edge[i].to],getroot(edge[i].to,x),solve(rt);
}
int main(){
scanf("%d",&n),memset(head,-1,sizeof(head)),memset(f,0x3f,sizeof(f));
for(int i=1;i<=n;i++)scanf("%d",&a[i]);
for(int i=1,x,y,z;i<n;i++)scanf("%d%d%d",&x,&y,&z),ae(x,y,z);
msz[0]=0x3f3f3f3f,rt=0,SZ=n,getroot(1,0),solve(rt);
for(int i=1;i<=n;i++)res+=f[i];
printf("%lld\n",res);
return 0;
}