Description
秀秀有一棵带n个顶点的树T,每个节点有一个点权ai。
有一天,她想拥有两棵树,于是她从T删去了一条边。
第二天,她认为三棵树或许会更好一些。因此,她又从她拥有的某一棵树中删去了一条边。
如此往复,每一天秀秀都会删去一条尚未被删去的边,直到她得到由n棵只有一个点的树构成的森林。
秀秀定义一条简单路径(节点不重复出现的路径)的权值为路径上所有点的权值之和,一棵树的直径为树上权值最大的简单路径。秀秀认为树最重要的特征就是它的直径。所以她想请你算出任一时刻她拥有的所有树的直径的乘积。因为这个数可能很大,你只需输出这个数对10^9+7取模之后的结果即可。
Input
从文件 forest.in 中读入数据。
输入的第一行包含一个整数n,表示树T顶点的数量。
下一行包含n个空格分隔的整数ai,表示顶点的权值。
之后的n-1行中,每一行包含两个用空格分隔的整数ui和vi,表示节点ui和vi之间连
有一条边,编号为i。
再之后n-1行中,每一行包含一个整数kj,表示在第j天里会被删除的边的编号。
Output
输出文件到forest.out 中。
共n行,在第i行,输出删除i-1条边之后,所有树直径的乘积对10^9+7取模的结果。
Sample Input
3
1 2 3
1 2
1 3
2
1
Sample Output
6
9
6
Hint
【样例解释】
初始时,树的直径为6(由节点2、1和3构成的路径)。在第一天之后,得到了两棵直径都为3的树。第二天之后,得到了三棵直径分别为 1,2,3的树,乘积为 6。
【数据规模与约定】
对于40%的数据:n≤100;
另有20%的数据:n≤1000;
另有20%的数据:n≤10000;
对于100%的数据:n≤100000,ai≤10000。
题解
- 倒序操作,发现两棵树合并时,树的直径只有可能由原来两棵树直径两边的点构成,枚举可能的直径点对
代码
#include <iostream>
#include <cstdio>
#include <cmath>
using namespace std;
const int maxn=100003;
const int mod=1e9+7;
int n,val[maxn],head[maxn],cnt,del[maxn],dep[maxn],sum[maxn],fa[maxn][20];
int p[maxn];
int getfa(int x){return x==p[x]?x:p[x]=getfa(p[x]);}
long long ans[maxn];
struct tftftf{int u,v,maxx;}d[maxn];
tftftf max(tftftf a,tftftf b){return a.maxx>b.maxx?a:b;}
struct node{int u,v;}b[maxn];
struct fdfdfd{int next,to;}e[maxn<<1];
void addedge(int x,int y){e[++cnt]=(fdfdfd){head[x],y}; head[x]=cnt;}
long long qpow(long long x,long long a)
{
long long ans=1;
while(a){
if(a&1) ans=ans*x%mod;
x=x*x%mod; a>>=1;
}
return ans;
}
void dfs1(int x,int pre)
{
fa[x][0]=pre; dep[x]=dep[pre]+1; sum[x]=sum[pre]+val[x];
for(int i=head[x];i;i=e[i].next) {
int v=e[i].to; if(v==pre) continue;
dfs1(v,x);
}
}
void st()
{
for(int i=1;i<=log2(n);++i)
for(int j=1;j<=n;++j) fa[j][i]=fa[fa[j][i-1]][i-1];
}
int getlca(int u,int v)
{
if(dep[u]<dep[v]) swap(u,v);
while(dep[u]>dep[v]) u=fa[u][(int)log2(dep[u]-dep[v])];
if(u==v) return u;
for(int i=log2(n);i>=0;--i)
if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
int js(int u,int v){int lca=getlca(u,v);return sum[u]+sum[v]-sum[lca]*2+val[lca];}
int main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
scanf("%d",&n); ans[n]=1;
for(int i=1;i<=n;++i) p[i]=i,scanf("%d",&val[i]),ans[n]=ans[n]*val[i]%mod,d[i]=(tftftf){i,i,val[i]};
for(int i=1,u,v;i<n;++i) scanf("%d%d",&b[i].u,&b[i].v),addedge(b[i].u,b[i].v),addedge(b[i].v,b[i].u);
for(int i=1;i<n;++i) scanf("%d",&del[i]);
dfs1(1,0); st();
for(int i=n-1;i;--i)
{
int ax=getfa(b[del[i]].u),bx=getfa(b[del[i]].v);
ans[i]=ans[i+1]*qpow(d[ax].maxx,mod-2)%mod*qpow(d[bx].maxx,mod-2)%mod;
int s1=d[ax].u,t1=d[ax].v,s2=d[bx].u,t2=d[bx].v;
p[ax]=bx;
d[bx]=max(d[bx],d[ax]);
d[bx]=max(d[bx],(tftftf){s1,s2,js(s1,s2)});
d[bx]=max(d[bx],(tftftf){s1,t2,js(s1,t2)});
d[bx]=max(d[bx],(tftftf){t1,s2,js(t1,s2)});
d[bx]=max(d[bx],(tftftf){t1,t2,js(t1,t2)});
ans[i]=ans[i]*d[bx].maxx%mod;
}
for(int i=1;i<=n;++i) cout<<ans[i]<<'
';
return 0;
}