题解 (by;zjvarphi)
先建出一棵树,使得 (x,y) 在这棵树上的 (lca) 为原树上 (x ightarrow y) 路径上的最大值,求出这棵树的 (dfs) 序。
再建一棵树,使得其 (lca) 为最大值。
发现合法的路径 ((x,y)) 保证 (x) 在第一棵树上是 (y) 的祖先,(y) 在第二棵树上是 (x) 的祖先。
用树状数组求解即可。
Code
#include<bits/stdc++.h>
#define ri signed
#define pd(i) ++i
#define bq(i) --i
#define func(x) std::function<x>
namespace IO{
char buf[1<<21],*p1,*p2;
#define gc() p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?(-1):*p1++
#define dg1(x) std::cerr << #x"=" << x << ' '
#define dg2(x) std::cerr << #x"=" << x << std::endl
#define Dg(x) assert(x)
struct nanfeng_stream{
template<typename T>inline nanfeng_stream &operator>>(T &x) {
bool f=false;x=0;char ch=gc();
while(!isdigit(ch)) f|=ch=='-',ch=gc();
while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=gc();
return x=f?-x:x,*this;
}
}cin;
}
using IO::cin;
namespace nanfeng{
#define pb emplace_back
#define FI FILE *IN
#define FO FILE *OUT
template<typename T>inline T cmax(T x,T y) {return x>y?x:y;}
template<typename T>inline T cmin(T x,T y) {return x>y?y:x;}
using ll=long long;
static const int N=2e6+7;
struct edge{int v,nxt;}e[N];
int fa[N],first[N],ld[N],rd[N],p[N],tot,n,t=1;
ll ans;
std::vector<int> vc[N];
auto add=[](int u,int v) {e[t]={v,first[u]},first[u]=t++;};
struct BIT{
#define lowbit(x) (x)&-(x)
int c[N];
func(void(int,int)) update=[&](int x,int k) {for (;x<=n;x+=lowbit(x)) c[x]+=k;};
func(int(int)) query=[&](int x) {
int res=0;
for (;x;x-=lowbit(x)) res+=c[x];
return res;
};
}B;
func(int(int)) find=[](int x) {return x==fa[x]?x:fa[x]=find(fa[x]);};
func(void(int)) dfs1=[](int x) {
ld[x]=++tot;
for (ri i(first[x]);i;i=e[i].nxt) dfs1(e[i].v);
rd[x]=tot;
};
func(void(int)) dfs2=[](int x) {
ans+=B.query(rd[x])-B.query(ld[x]-1);
B.update(ld[x],1);
for (ri i(first[x]);i;i=e[i].nxt) dfs2(e[i].v);
B.update(ld[x],-1);
};
inline int main() {
FI=freopen("charity.in","r",stdin);
FO=freopen("charity.out","w",stdout);
cin >> n;
for (ri i(1);i<=n;pd(i))
cin >> p[fa[i]=i],vc[p[i]].pb(i),vc[i].pb(p[i]);
for (ri i(n-1);i>=1;bq(i))
for (auto v:vc[i])
if (v>i) {
int x=find(v);
add(i,x);
fa[x]=i;
}
dfs1(1);
memset(first,0,sizeof(first));
t=1;
for (ri i(1);i<=n;pd(i)) fa[i]=i;
for (ri i(2);i<=n;pd(i))
for (auto v:vc[i])
if (v<i) {
int x=find(v);
add(i,x);
fa[x]=i;
}
dfs2(n);
printf("%lld
",ans);
return 0;
}
}
int main() {return nanfeng::main();}