Code:
#include <cstdio>
#include <initializer_list>
#include <vector>

namespace {

// Number of non-empty sub-intervals of a run of k consecutive positions,
// i.e. k*(k+1)/2. Runs of length <= 0 contribute nothing.
long long triangular(long long k) {
    return k > 0 ? k * (k + 1) / 2 : 0;
}

// Total length of all sub-intervals [l, r] of 1..k:
//   sum_{len=1..k} len * (k - len + 1) = (k+1)*T(k) - sum_{i<=k} i^2
//                                      = k*(k+1)*(k+2)/6.
// The closed form replaces a prefix-sum-of-squares table, so no fixed-size
// array (and no hard maximum on k) is needed.
long long total_interval_length(long long k) {
    return k > 0 ? k * (k + 1) * (k + 2) / 6 : 0;
}

}  // namespace

// Reads n, then n-1 edges "a b" (1-based vertex ids) from stdin; edge e is the
// e-th pair read. For every interval [l, r] of edge indices (1 <= l <= r <= m)
// it takes
//   (n - number of edges in the interval) - (vertices touched by none of them)
// and prints the sum over all intervals, followed by a single space.
// NOTE(review): assuming the edges form a tree (reading exactly n-1 edges
// suggests so -- TODO confirm against the problem statement), n - (r-l+1) is
// the number of components of the sub-forest, so the result counts the
// components that contain at least one edge, summed over all intervals.
int main() {
    int n = 0;
    if (std::scanf("%d", &n) != 1 || n < 1) {
        return 1;  // malformed input; previously n was used uninitialized
    }
    const long long m = n - 1;

    // n * (#intervals) - (sum of interval lengths)
    //   = sum over intervals of (vertices - edges in the interval).
    long long ans = n * triangular(m) - total_interval_length(m);

    // A vertex is isolated in exactly those intervals that lie strictly inside
    // a gap between consecutive incident edge indices (or before its first /
    // after its last incident edge). Edges arrive in increasing index order,
    // so each vertex's gaps can be accumulated on the fly instead of storing
    // and sorting per-vertex adjacency lists.
    std::vector<int> last_edge(static_cast<std::size_t>(n) + 1, 0);  // 0 = none yet
    for (int e = 1; e <= m; ++e) {
        int a = 0;
        int b = 0;
        if (std::scanf("%d%d", &a, &b) != 2 || a < 1 || a > n || b < 1 || b > n) {
            return 1;  // malformed edge; previously indexed out of bounds
        }
        // A self-loop (a == b) yields a zero-length gap the second time,
        // matching the duplicate entry the adjacency-list version produced.
        for (const int v : {a, b}) {
            ans -= triangular(e - last_edge[v] - 1);
            last_edge[v] = e;
        }
    }
    // Trailing gap after each vertex's last incident edge. Vertices with no
    // incident edge (only possible when n == 1) contribute nothing.
    for (int v = 1; v <= n; ++v) {
        if (last_edge[v] > 0) {
            ans -= triangular(m - last_edge[v]);
        }
    }

    std::printf("%lld ", ans);
    return 0;
}