题目大意
task0:有两棵(n)(nleq10^5)个点的树(T1,T2),每个点的点权可以是一个在([1,y])里的数,如果两个点既在(T1)中有直接连边,又在(T2)中有直接连边,那么它们的点权必须相同。求有多少种分配点权的方案。
task1:有一棵(n)个点的树(T1),给定(y),求(T2)所有形态的task1之和(mod 998244353)
task2:给定(n,y),求(T1)所有形态的task2之和(mod 998244353)
题解
首先,当(y=1)时,所有点权都得是1,task0的答案就是1,task1的答案就是(n^{n-2}),task2的答案就是(n^{2*n-4})
task0
设森林(T3=T1igcap T2)
那么就会发现(T3)中同一个连通块内的所有点的点权必须相等,分配点权的方案数为(y^{n-|T3|})
task1
枚举(T2)的话,设(x=y^{-1}),答案=(sumlimits_{T2}y^{n-m}=sumlimits_{T2}y^n*x^{|T3|})
给定(n)的情况下,(y^n)不变,只考虑怎么算(x^{|T3|})
通过查看题解 根据二项式定理发现(y^n*x^{|T3|}=y^n*(x-1+1)^{|T3|}=y^n*sumlimits_{i=0}^{|T3|}{C_{|T3|}^{i}*(x-1)^i})
考虑枚举(T4subseteq T3),则每个(T4)对上式的贡献是(y^n*(x-1)^{|T4|})(相当于在(|T3|)条边中选了(|T4|)条边)
原式就变成了:答案=(sumlimits_{T2}sumlimits_{T4subseteq(T1igcap T2)}y^n*(x-1)^{|T4|})
枚举(T4)的(Sigma)拿到前面,则有(sumlimits_{T4subseteq T1}y^n*(x-1)^{|T4|}sumlimits_{E otsubset T4 且 |E|+|T4|=n}1)
相当于对于每个(T4),它的贡献就是把它这个森林连成树的方案数
枚举(T4)的连通块个数,则有(sumlimits_{m=1}^{n}sumlimits_{n-|T4|=m且T4subseteq T1}y^n*(x-1)^{n-m}sumlimits_{E otsubset T4 且 |E|+|T4|=n}1)
这样就能把和(x,y)有关的部分拿到前面了:(sumlimits_{m=1}^{n}y^n*(x-1)^{n-m}sumlimits_{n-|T4|=m且T4subseteq T1}sumlimits_{E otsubset T4 且 |E|+|T4|=n-1}1)
设(T4)的(m)个连通块里的点数分别为(a_1,a_2,...,a_m),枚举每个连通块的度数(d_1,d_2,...,d_m)
先把(T4)的(m)个连通块看成(m)个点,那么就相当于要求出(m)个点的有标号无根树有多少棵
这个可以用prufer序列相关的定理算出是(m^{m-2})棵
设prufer序列中(i)的出现次数是(d_i)
(m)个连通块与(m)个点的区别在于,(m)个点的话,点(i)要连(d_i+1)条边;而(m)个连通块的话,连向连通块(i)的(d_i+1)条边的每一条边都可以在(a_i)个点中任选一个连边,总方案数要再乘上个(prodlimits_{i=1}^{m}a_i^{d_i+1})
枚举prufer序列,假设prufer序列是(p_1,p_2,...,p_m),就会有(m)个连通块,第(i)个连通块(a_i)个点的有标号无根树有(sumlimits_{p_1,p_2,...,p_{m-2}, 1leq p_ileq m}prodlimits_{i=1}^{m}a_i^{d_i+1})棵
把指数(d_i+1)中“+1”的那个(a_i)拿到前面,该式=((prodlimits_{i=1}^{m}a_i)*sumlimits_{p_1,p_2,...,p_{m-2}, 1leq p_ileq m}prodlimits_{i=1}^{m}a_i^{d_i})
这样变化之后,枚举的prufer序列中每个出现的(p_i)都会对该式有(a_{p_i})的贡献
就会有该式=((prodlimits_{i=1}^{m}a_i)*sumlimits_{p_1,p_2,...,p_{m-2}, 1leq p_ileq m}prodlimits_{i=1}^{m-2}a_{p_i})
交换(sum)和(prod),得该式=((prodlimits_{i=1}^{m}a_i)*prodlimits_{i=1}^{m-2}sumlimits_{p_i=1}^{m}a_{p_i})
其中(sumlimits_{p_i=1}^{m}a_{p_i}=n)
那么该式=((prodlimits_{i=1}^{m}a_i)*n^{m-2})
就会有:答案=(sumlimits_{m=1}^{n}y^n*(x-1)^{n-m}sumlimits_{sum a_i=n}(prodlimits_{i=1}^{m}a_i)*n^{m-2})
将与(m)有关的部分都挪到后面,与(n)有关的挪到前面,得:答案=(y^n*(x-1)^n*n^{-2} sumlimits_{sum a_i=n}(prodlimits_{i=1}^{m}a_i)*n^{m}*(x-1)^{-m})
将(n^{m}*(x-1)^{-m})拿到(prod)中,得:答案=(y^n*(x-1)^n*n^{-2} prodlimits_{i=1}^{m}(a_i*(x-1)^{-1}*n))
其中(prodlimits_{i=1}^{m}a_i)的部分相当于将(n)个点划分成一些连通块,在每一个连通块中选一个关键点的方案数
这个可以dp,设(f(i,0/1))表示以(i)为根的子树中,(i)所在连通块有/没有选出关键点的方案数
那么(prodlimits_{i=1}^{m}(a_i*(x-1)^{-1}*n))可以看成在此基础上,每个连通块又乘了((x-1)^{-1}*n),这个也可以dp
答案=(y^n*(x-1)^n*n^{-2}*f(1,1))
task2
既要枚举(T2),又要枚举(T1)
可以把答案看成“(n)个有标号的点分(m)个无标号的组的方案数”、“(m)个连通块中每个连通块能构出多少棵树”、“确定块内的连边情况后,T1的块间连通方案数”、“确定块内的连边情况后,T2的块间连通方案数”四部分相乘
“(n)个有标号的点分(m)个无标号的组的方案数”=(sumlimits_{sum{a_i}=n}{frac{n!}{m!*prod_{i=1}^{m}{a_i!}}})
根据prufer序列,(x)个点的连通块构出(x^{x-2})棵树,考虑(m)个连通块中每个连通块都能构出(a_i^{a_i})棵树,“(m)个连通块中每个连通块能构出多少棵树”=(prodlimits_{i=1}^{m}{a_i^{a_i-2}})
而“确定块内的连边情况后,T1/T2的块间连通方案数”这部分之前已经算过了,就是((prodlimits_{i=1}^{m}a_i)*n^{m-2})
所以,答案=(y^n*(x-1)^{n-m}*sumlimits_{sum{a_i}=n}{frac{n!}{m!*prodlimits_{i=1}^{m}{a_i!}}}*(prodlimits_{i=1}^{m}{a_i^{a_i-2}})*((prodlimits_{i=1}^{m}a_i)*n^{m-2})^2)
把((n^{-2})^2)拿到前面,把((x-1)^{-m})拿到后面,将三个(prod)合并,就会有:答案=(y^n*(x-1)^n*n^{-4}*n!*sumlimits_{sum{a_i}=n}{frac{1}{m!}}prodlimits_{i=1}^{m}frac{a_i^{a_i}*n^2}{a_i!*(x-1)})
用([x^n])表示多项式中(n)次项的系数,就会有:答案=(y^n*(x-1)^n*n^{-4}*n!*[x^n]sumlimits_{m=1}^{n}{frac{1}{m!}}(frac{n^2}{x-1}sumlimits_{a>0}frac{a^a}{a!}x^a)^m)
根据泰勒展开发现(exp(x)=sumlimits_{i=0}^{infty}frac{x^i}{i!}),那么([x^n]sumlimits_{m=1}^{n}{frac{1}{m!}}(frac{n^2}{x-1}sumlimits_{a>0}frac{a^a}{a!}x^a)^m)这部分就可以看成(exp(frac{n^2}{x-1}sumlimits_{a>0}frac{a^a}{a!}x^a))
对这个多项式求exp后,将(n)次项系数乘以(y^n*(x-1)^n*n!*n^{-4})
代码
#include<algorithm>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<iomanip>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
#define rep(i,x,y) for(register int i=(x);i<=(y);++i)
#define dwn(i,x,y) for(register int i=(x);i>=(y);--i)
#define view(u,k) for(int k=fir[u];~k;k=nxt[k])
#define maxn 100010
#define maxm (maxn<<1)
#define maxlen (maxn<<3)
#define LL long long
#define mo(x) (x>=mod?x-mod:(x<0?x+mod:x))
using namespace std;
int read()
{
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)&&ch!='-')ch=getchar();
if(ch=='-')f=-1,ch=getchar();
while(isdigit(ch))x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
return x*f;
}
void write(int x)
{
if(x==0){putchar('0'),putchar('
');return;}
int f=0;char ch[20];
if(x<0)putchar('-'),x=-x;
while(x)ch[++f]=x%10+'0',x/=10;
while(f)putchar(ch[f--]);
putchar('
');
return;
}
const LL mod=998244353;
struct edge{int u,v;}e1[maxn],e2[maxn];
int f[maxn][2],fu[maxlen],ans[maxlen],fir[maxn],nxt[maxm],to[maxm],cnt,n,y,op,y2,ny2,t[2];
bool cmp(edge x,edge y){return x.u==y.u?x.v<y.v:x.u<y.u;}
int mul(int a,int b){int ans=1;while(b){if(b&1)ans=(LL)ans*(LL)a%mod;a=(LL)a*(LL)a%mod;b>>=1;}return ans;}
void ade(int u1,int v1){to[cnt]=v1,nxt[cnt]=fir[u1],fir[u1]=cnt++;}
void getf(int u,int fa)
{
f[u][0]=f[u][1]=y2;
view(u,k)if(to[k]!=fa)
{
getf(to[k],u);t[0]=t[1]=0;
rep(yesu,0,1)rep(yesv,0,1)
{
if(!(yesu&yesv))t[yesu|yesv]=mo(t[yesu|yesv]+(LL)f[u][yesu]*(LL)f[to[k]][yesv]%mod*ny2%mod);
if(yesv)t[yesu]=mo(t[yesu]+(LL)f[u][yesu]*(LL)f[to[k]][yesv]%mod);
}
f[u][0]=t[0],f[u][1]=t[1];
}
}
int g[maxlen],h[maxlen],nowlen,nown,r[maxlen],tmp[maxlen],tmp2[maxlen];
void dnt(int * u,int fh)
{
rep(i,0,nown-1)r[i]=(r[i>>1]>>1)|((i&1)<<(nowlen-1));
rep(i,0,nown-1)if(i<r[i])swap(u[i],u[r[i]]);
for(int i=1;i<nown;i<<=1)
{
int wn=mul(3,(mod-1)/(i<<1));
if(fh==-1)wn=mul(wn,mod-2);
for(int j=0;j<nown;j+=(i<<1))
{
int w=1;
rep(k,0,i-1){int a=u[j+k],b=(LL)w*(LL)u[i+j+k]%mod;u[j+k]=mo(a+b),u[i+j+k]=mo(a-b),w=(LL)w*(LL)wn%mod;}
}
}
if(fh==-1)
{
int inv=mul(nown,mod-2);
rep(i,0,nown-1)u[i]=(LL)u[i]*(LL)inv%mod;
}
return;
}
void getny(int * u,int * v,int nlen)
{
v[0]=mul(u[0],mod-2);
for(int len=0,tmpn=1;tmpn<nlen;len++,tmpn<<=1)
{
nown=tmpn<<1,nowlen=len+1;
rep(i,0,nown-1)tmp[i]=u[i];
nown<<=1,nowlen++;
rep(i,(tmpn<<1),nown-1)tmp[i]=0;
dnt(tmp,1),dnt(v,1);
rep(i,0,nown-1)v[i]=mo(2ll-(LL)tmp[i]*(LL)v[i]%mod)*(LL)v[i]%mod;
dnt(v,-1);
rep(i,(tmpn<<1),nown-1)v[i]=0;
}
rep(i,0,nown)tmp[i]=0;
rep(i,nlen,nown)v[i]=0;
return;
}
void getup(int * u,int * v,int nlen)
{
rep(i,1,nlen-1)v[i-1]=(LL)i*(LL)u[i]%mod;v[nlen-1]=0;
return;
}
void getdx(int * u,int * v,int nlen)
{
rep(i,1,nlen-1)v[i]=(LL)u[i-1]*(LL)mul(i,mod-2)%mod;v[0]=0;
return;
}
void getln(int * u,int * v,int nlen)
{
getup(u,h,nlen),getny(u,g,nlen);
for(nowlen=0,nown=1;nown<(nlen+nlen);nowlen++,nown<<=1);
dnt(h,1),dnt(g,1);
rep(i,0,nown-1)h[i]=(LL)h[i]*(LL)g[i]%mod;
rep(i,0,nown-1)g[i]=0;
dnt(h,-1);getdx(h,v,nlen);
rep(i,nlen,nown)v[i]=0;
return;
}
void getexp(int * u,int * v,int nlen)
{
rep(i,0,(nlen<<2))v[i]=0;
v[0]=1;
for(int len=0,tmpn=1;tmpn<nlen;len++,tmpn<<=1)
{
rep(i,0,(tmpn<<1))tmp2[i]=0;
getln(v,tmp2,(tmpn<<1));
nown=(tmpn<<2),nowlen=len+2;
rep(i,(tmpn<<1),nown)tmp2[i]=0;
rep(i,0,(tmpn<<1)-1)tmp[i]=u[i];
rep(i,(tmpn<<1),nown)tmp[i]=0;
dnt(tmp2,1),dnt(v,1),dnt(tmp,1);
rep(i,0,nown-1)v[i]=(LL)v[i]*(LL)mo(mo(1ll-tmp2[i])+tmp[i])%mod;
dnt(v,-1);
rep(i,(tmpn<<1),nown)v[i]=0;
}
rep(i,nlen,nown)v[i]=0;
return;
}
int main()
{
n=read(),y=read(),op=read();
if(!op)
{
rep(i,1,n-1){int x1=read(),y1=read();e1[i].u=min(x1,y1),e1[i].v=max(x1,y1);}
rep(i,1,n-1){int x1=read(),y1=read();e2[i].u=min(x1,y1),e2[i].v=max(x1,y1);}
sort(e1+1,e1+n,cmp),sort(e2+1,e2+n,cmp);
int j=1;
rep(i,1,n-1)
{
while(cmp(e1[j],e2[i])&&j<n-1)j++;
if(e1[j].u==e2[i].u&&e1[j].v==e2[i].v){cnt++;}
}
write(mul(y,n-cnt));
}
else if(op==1)
{
memset(fir,-1,sizeof(fir));
if(y==1){write(mul(n,n-2));return 0;} y2=(LL)mul((mul(y,mod-2)-1+mod)%mod,mod-2)*(LL)n%mod,ny2=mul(y2,mod-2);
rep(i,1,n-1){int x1=read(),y1=read();ade(x1,y1),ade(y1,x1);}
getf(1,0),write((LL)f[1][1]*(LL)mul((mul(y,mod-2)-1+mod)%mod,n)%mod*(LL)mul((LL)n*(LL)n%mod,mod-2)%mod*(LL)mul(y,n)%mod);
}
else
{
if(y==1){write(mul(n,n*2-4));return 0;}
int facn=1,rfac,x=mul(y,mod-2),a=(LL)n*(LL)n%mod*(LL)mul(mo(x-1),mod-2)%mod;
rep(i,2,n)facn=(LL)facn*(LL)i%mod;
rfac=mul(facn,mod-2);
dwn(i,n,1)fu[i]=(LL)rfac*(LL)mul(i,i)%mod*(LL)a%mod,rfac=(LL)rfac*(LL)i%mod;
getexp(fu,ans,n+1);
write((LL)mul(y,n)*(LL)mul((x-1+mod)%mod,n)%mod*(LL)mul((LL)n*(LL)n%mod*(LL)n%mod*(LL)n%mod,mod-2)%mod*(LL)facn%mod*(LL)ans[n]%mod);
}
return 0;
}