差点自闭,感谢大佬帮忙找bug
题目:https://codeforces.com/gym/101968/problem/A
找树的重心+思维
找到树的重心,如果重心只有一个,以重心为根节点dfs,求各节点深度,那么任意一对节点都符合题意,只要让先手mark深度小的节点即可,其中同样深度的节点,交换位置扔符合题意,要加上dep[i]*(dep[i]-1)/2
如果重心有两个,分别以两个重心为根节点dfs,以任意一对节点都符合题意为基础(ans=n*(n-1)/2),如果两个深度相同的节点分别挂在两个不同的重心上,那么他们无法作为符合题意的一对节点,要减去这一部分,对每个重心深度相同的节点,分别加上交换节点扔符合题意的一部分,嗯就这样
#include<iostream> #include<cstdio> #include<cmath> #include<queue> #include<vector> #include<string.h> #include<cstring> #include<algorithm> #include<set> #include<stack> #include<map> #include<fstream> #include<cstdlib> #include<ctime> #include<list> #include<climits> #include<bitset> #include<random> using namespace std; #define fopen freopen("input.in", "r", stdin);freopen("output.in", "w", stdout); #define left asfdasdasdfasdfsdfasfsdfasfdas1 #define right asfdasdasdfasdfsdfasfsdfasfdas2 #define y1 asfdasdasdfasdfsdfasfsdfasfdas3 typedef long long ll; typedef unsigned ui; typedef long double ld; const int dell[8][2]={{1,2},{1,-2},{2,1},{2,-1},{-1,2},{-1,-2},{-2,1},{-2,-1}}; ll mod=1e9+7; const ll inf=(1LL<<31)-1; const int maxn=1e6+7; const int maxm=1e6+7; const double eps=1e-8; const double pi=acos(-1.0); const int csize=22; int n,k,m,ar[maxn]; struct node{ int b,nex; }no[maxn*2]; int head[maxn],sz,mx,root; int pre[maxn]; ll sspre[maxn]; int pre2[maxn]; ll sspre2[maxn]; void init(){ for(int i=0;i<=n;i++)head[i]=-1; sz=0; } void add(int a,int b){ no[sz].b=b; no[sz].nex=head[a]; head[a]=sz++; } int dep[maxn],num[maxn]; void dfs(int u,int fa) { num[u]=1; if(dep[u]>mx){ mx=dep[u]; root=u; } for(int i=head[u];i!=-1;i=no[i].nex){ int v=no[i].b; if(v==fa)continue; dep[v]=dep[u]+1; dfs(v,u); num[u] += num[v]; } } bool findmid(int u,int& mid){ bool can=1; for(int i=head[u];i!=-1;i=no[i].nex){ int v=no[i].b; if(dep[v]==dep[u]+1){ if(findmid(v,mid)){ return 1; } can &= num[v]<(n+1)/2; } } if(can && num[u]>=(n+1)/2){ mid=u; return 1; } else return 0; } int main() { //fopen //freopen("input.in","r",stdin); int t;scanf("%d",&t); while(t--){ scanf("%d",&n); init(); for(int i=2;i<=n;i++){ int x;scanf("%d",&x); add(i,x); add(x,i); } mx=0;root=0; dep[1]=1; dfs(1,-1); int mid=0; findmid(1,mid); ll ans=1LL*n*(n-1)/2; if(n%2==0 && num[mid]==n/2){ int mid2=0; for(int i=head[mid];i!=-1;i=no[i].nex){ if(dep[no[i].b]+1==dep[mid]){ mid2=no[i].b; break; } } for(int i=1;i<=n;i++)dep[i]=0; dep[mid2]=1; dfs(mid2,mid); for(int i=0;i<=n+1;i++)pre[i]=pre2[i]=0; for(int i=1;i<=n;i++){ if(dep[i]>0)pre2[dep[i]]++; } for(int i=1;i<=n;i++)dep[i]=0; dep[mid]=1; dfs(mid,mid2); for(int i=1;i<=n;i++){ if(dep[i]>0){ pre[dep[i]]++; } } for(int i=1;i<=n;i++){ ans -= (ll)pre[i]*pre2[i]; } for(int i=1;i<=n;i++){ ans += (ll)pre[i]*(pre[i]-1)/2; ans += (ll)pre2[i]*(pre2[i]-1)/2; } } else{ //if(fir==1609)while(1); dep[mid]=1; dfs(mid,-1); for(int i=0;i<=n+1;i++)pre[i]=0; for(int i=1;i<=n;i++)pre[dep[i]]++; for(int i=1;i<=n;i++){ ans += (ll)pre[i]*(pre[i]-1)/2; } } printf("%lld ",ans); } return 0; }