1131: [POI2008]Sta
Time Limit: 10 Sec Memory Limit: 162 MBSubmit: 713 Solved: 196
[Submit][Status]
Description
给出一个N个点的树,找出一个点来,以这个点为根的树时,所有点的深度之和最大
Input
给出一个数字N,代表有N个点.N<=1000000 下面N-1条边.
Output
输出你所找到的点,如果具有多个解,请输出编号最小的那个.
Sample Input
8
1 4
5 6
4 5
6 7
6 8
2 4
3 4
1 4
5 6
4 5
6 7
6 8
2 4
3 4
Sample Output
7
HINT
Source
题解:
总算自己yy出一题。。。
刚开始想了一个贪心的算法,每次选取子树大小最小的子树递归下去,这样sum一定会变大?可是怎么证明正确性呢?看了一下样例发现这样做是错的。。。
然后发现sum传到儿子是可以O(1)更新的,于是就两次dfs水过了。。。
从父亲传到儿子,所有儿子的子树内的深度-1,其余深度+1,O(1)
代码:
1 #include<cstdio> 2 3 #include<cstdlib> 4 5 #include<cmath> 6 7 #include<cstring> 8 9 #include<algorithm> 10 11 #include<iostream> 12 13 #include<vector> 14 15 #include<map> 16 17 #include<set> 18 19 #include<queue> 20 21 #include<string> 22 23 #define inf 1000000000 24 25 #define maxn 1000000+1000 26 27 #define maxm 500+100 28 29 #define eps 1e-10 30 31 #define ll long long 32 33 #define pa pair<int,int> 34 35 #define for0(i,n) for(int i=0;i<=(n);i++) 36 37 #define for1(i,n) for(int i=1;i<=(n);i++) 38 39 #define for2(i,x,y) for(int i=(x);i<=(y);i++) 40 41 using namespace std; 42 43 inline int read() 44 45 { 46 47 int x=0,f=1;char ch=getchar(); 48 49 while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} 50 51 while(ch>='0'&&ch<='9'){x=10*x+ch-'0';ch=getchar();} 52 53 return x*f; 54 55 } 56 int n,tot,head[maxn],dep[maxn]; 57 ll s[maxn],sum[maxn]; 58 struct edge{int go,next;}e[2*maxn]; 59 inline void insert(int x,int y){e[++tot].go=y;e[tot].next=head[x];head[x]=tot;} 60 inline void dfs(int x) 61 { 62 s[x]=1; 63 for(int i=head[x],y;i;i=e[i].next) 64 if(!dep[y=e[i].go]) 65 { 66 sum[y]=dep[y]=dep[x]+1; 67 dfs(y); 68 s[x]+=s[y]; 69 sum[x]+=sum[y]; 70 } 71 } 72 inline void dfs2(int x) 73 { 74 for(int i=head[x],y;i;i=e[i].next) 75 if(dep[y=e[i].go]>dep[x]) 76 { 77 sum[y]=sum[x]-s[y]+n-s[y]; 78 dfs2(y); 79 } 80 } 81 82 int main() 83 84 { 85 86 freopen("input.txt","r",stdin); 87 88 freopen("output.txt","w",stdout); 89 90 n=read(); 91 for1(i,n-1){int x=read(),y=read();insert(x,y);insert(y,x);}; 92 dep[1]=sum[1]=1; 93 dfs(1); 94 dfs2(1); 95 int ans=0; 96 for1(i,n)if(sum[i]>sum[ans])ans=i; 97 printf("%d ",ans); 98 99 return 0; 100 101 }