题目大意:给定一棵 1~n 标号的树。Tree[L,R]表示最少需要选择的边的数量使得 L~R 号点两两连通。求:
[sum_{L=1}^{n} sum_{R=L}^{n} operatorname{Tree}[L, R]
]
题解:
要求的是经过边的数量,可以考虑每条边对答案的贡献。
对于边 (u,v) ,定义一个特征序列 a[1...n],若 i 在以 u 为根的子树中,则 a[i]=1,否则 a[i]=0。发现若要求该边的贡献为 1,则选择的范围 a[l]-a[r] 中必须要既有 0 又有 1,这样才能跨过这条边,到达子树外面的节点。问题转化成了求对于树上每个节点的特征序列来说,既有 0 又有 1 的区间的个数和是多少。但是这个问题并不好求,可以转化为求特征序列中,只有 0 或 1 的区间个数,再用总共的区间个数减掉这些即可。因此,可以采用线段树来维护区间最长前后缀 0/1 的长度,区间合并的时候只需处理左区间的后缀和右区间的前缀即可完成。树上的问题还需要进行线段树合并来完成,时间复杂度为 (O(nlogn))。
代码如下
#include <bits/stdc++.h>
#define pb push_back
using namespace std;
const int maxn=1e5+10;
typedef long long LL;
int n; LL ans;
vector<int> G[maxn];
struct node{
#define ls(o) t[o].lc
#define rs(o) t[o].rc
int lc,rc,lmx0,lmx1,rmx0,rmx1;
LL sum0,sum1;
}t[maxn*20];
int tot,rt[maxn];
inline void pushup(int o,int l,int r){
int mid=l+r>>1;
if(!ls(o))t[ls(o)].sum0=(LL)(mid-l+1)*(mid-l+2)/2,t[ls(o)].lmx0=t[ls(o)].rmx0=mid-l+1;
if(!rs(o))t[rs(o)].sum0=(LL)(r-mid)*(r-mid+1)/2,t[rs(o)].lmx0=t[rs(o)].rmx0=r-mid;
t[o].sum0=t[ls(o)].sum0+t[rs(o)].sum0+(LL)t[ls(o)].rmx0*t[rs(o)].lmx0;
t[o].sum1=t[ls(o)].sum1+t[rs(o)].sum1+(LL)t[ls(o)].rmx1*t[rs(o)].lmx1;
t[o].lmx0=t[ls(o)].lmx0==mid-l+1?t[ls(o)].lmx0+t[rs(o)].lmx0:t[ls(o)].lmx0;
t[o].lmx1=t[ls(o)].lmx1==mid-l+1?t[ls(o)].lmx1+t[rs(o)].lmx1:t[ls(o)].lmx1;
t[o].rmx0=t[rs(o)].rmx0==r-mid?t[rs(o)].rmx0+t[ls(o)].rmx0:t[rs(o)].rmx0;
t[o].rmx1=t[rs(o)].rmx1==r-mid?t[rs(o)].rmx1+t[ls(o)].rmx1:t[rs(o)].rmx1;
}
void insert(int &o,int l,int r,int pos){
if(!o)o=++tot;
if(l==r){t[o].lmx1=t[o].rmx1=t[o].sum1=1;return;}
int mid=l+r>>1;
if(pos<=mid)insert(ls(o),l,mid,pos);
else insert(rs(o),mid+1,r,pos);
pushup(o,l,r);
}
int merge(int x,int y,int l,int r){
if(!x||!y)return x+y;
if(l==r)return t[x].sum1?x:y;
int mid=l+r>>1;
ls(x)=merge(ls(x),ls(y),l,mid);
rs(x)=merge(rs(x),rs(y),mid+1,r);
return pushup(x,l,r),x;
}
void dfs(int u,int fa){
for(auto v:G[u]){
if(v==fa)continue;
dfs(v,u);
rt[u]=merge(rt[u],rt[v],1,n);
}
insert(rt[u],1,n,u);
LL ret=(LL)n*(n+1)/2-t[rt[u]].sum0-t[rt[u]].sum1;
if(u!=1)ans+=ret;
}
void read_and_parse(){
scanf("%d",&n);
for(int i=1;i<n;i++){
int x,y;
scanf("%d%d",&x,&y);
G[x].pb(y),G[y].pb(x);
}
}
void solve(){
dfs(1,0);
printf("%lld
",ans);
}
int main(){
read_and_parse();
solve();
return 0;
}