题目:http://www.lydsy.com/JudgeOnline/problem.php?id=3124
首先第一问直接求直径就可以了。
然后对于第二问,因为直径一定过树的重心,于是可以把重心找出来。
如果是菊花图的话,那就输出0。。
如果不是菊花图,对于很多条直径的两个端点,一定可以把它们分成两个点集使得每个点到对面点集的任一个点的路径都是直径。。
于是这两个点集分别求lca,两个lca之间的距离就是答案了。
#include<cstring> #include<iostream> #include<algorithm> #include<cstdio> #include<queue> #include<cmath> #define rep(i,l,r) for (int i=l;i<=r;i++) #define down(i,l,r) for (int i=l;i>=r;i--) #define clr(x,y) memset(x,y,sizeof(x)) #define inf 1000000007 #define ll long long #define maxn 200500 #define eps 1e-5 #define mm 2147483648 #define low(x) (x&(-x)) #define f(x,y,z) g[x][y][z+8] using namespace std; struct data{int obj,pre; ll c; }e[maxn*2]; ll diss[maxn],dist[maxn],len; int mx[maxn],head[maxn],qs[maxn],qt[maxn],sz[maxn],dep[maxn],bin[30],fa[maxn][30]; int s,t,n,cnt1,cnt2,tot,root; ll read(){ ll x=0,f=1; char ch=getchar(); while (!isdigit(ch)){if (ch=='-') f=-1; ch=getchar();} while (isdigit(ch)){x=x*10+ch-'0'; ch=getchar();} return x*f; } void insert(int x,int y,ll z){ e[++tot].obj=y; e[tot].pre=head[x]; e[tot].c=z; head[x]=tot; } int dfs(int s,ll dis[]){ queue<int > q; q.push(s); rep(i,1,n) dis[i]=inf; dis[s]=0; int mx=0,ans=0; while (!q.empty()){ int u=q.front(); q.pop(); for (int j=head[u];j;j=e[j].pre){ int v=e[j].obj; if (dis[v]==inf) { dis[v]=dis[u]+e[j].c; q.push(v); if (dis[v]>mx) mx=dis[v],ans=v; } } } return ans; } void get(int u,int f){ sz[u]++; for (int j=head[u];j;j=e[j].pre){ int v=e[j].obj; if (v!=f) { get(v,u); sz[u]+=sz[v]; mx[u]=max(mx[u],sz[v]); } } mx[u]=max(mx[u],n-sz[u]); } void dfs(int u){ rep(i,1,20) if (dep[u]>=bin[i]) fa[u][i]=fa[fa[u][i-1]][i-1]; for (int j=head[u];j;j=e[j].pre){ int v=e[j].obj; if (v!=fa[u][0]){ fa[v][0]=u; dep[v]=dep[u]+1; dfs(v); } } } int lca(int x,int y){ if (dep[x]<dep[y]) swap(x,y); int t=dep[x]-dep[y]; rep(i,0,20) if (t&bin[i]) x=fa[x][i]; down(i,20,0) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; if (x==y) return x; return fa[x][0]; } int main(){ // freopen("diameter.in","r",stdin); // freopen("diameter.out","w",stdout); bin[0]=1; rep(i,1,20) bin[i]=bin[i-1]*2; n=read(); rep(i,1,n-1){ int x=read(),y=read();ll z=read(); insert(x,y,z); insert(y,x,z); } s=dfs(1,diss); t=dfs(s,diss); dfs(t,dist); len=diss[t]; printf("%lld ",len); rep(i,1,n){ if (diss[i]==len&&dist[i]==len) {puts("0"); return 0;} if (diss[i]==len) qs[++cnt1]=i; if (dist[i]==len) qt[++cnt2]=i; } root=0; mx[0]=inf; get(s,0); rep(u,1,n) if (mx[root]>mx[u]) root=u; // printf("%d ",root); dfs(root); int a,b; if (cnt1==1) a=qs[cnt1]; else { a=lca(qs[1],qs[2]); rep(i,3,cnt1) a=lca(a,qs[i]); } if (cnt2==1) b=qt[cnt2]; else { b=lca(qt[1],qt[2]); rep(i,3,cnt2) b=lca(b,qt[i]); } printf("%d ",dep[a]+dep[b]-2*dep[lca(a,b)]); return 0; }