Solution
op=0
令两棵树的边集分别为 (S_1,S_2),(T=S_1and S_2),于是原问题等价于形成一个边集为 (T) 的新图。新图上有 (n-|T|) 个连通块,于是答案就是 (y^{n-|T|})。
当 (op=0) 时,直接 (mathcal O(nlog n)) (map)/双指针找出 (|T|) 即可。
namespace sub0{
pair<int,int> a[N],b[N];
inline void main(){
for(int i=1,u,v;i<n;++i){
scanf("%d%d",&u,&v);if(u>v) swap(u,v);
a[i]=make_pair(u,v);
}
for(int i=1,u,v;i<n;++i){
scanf("%d%d",&u,&v);if(u>v) swap(u,v);
b[i]=make_pair(u,v);
}
sort(a+1,a+n);sort(b+1,b+n);
int ans=0;
for(int i=1,j=1;i<n;++i){
while(j<n&&b[j]<a[i]) ++j;
if(b[j]==a[i]) ++ans;
}
printf("%d
",ksm(y,n-ans));
exit(0);
}
}
op=1
此时
后面这个形式不是我们所喜欢的,我们希望转化为形如 (Tsubseteq S_1and S_2) 的形式,这可以利用容斥原理来实现:
证明考虑计算一个集合 (P) 被计算的次数为 (sum_{i=0}^{|S|-|P|}dbinom{|S|-|P|}{i}(-1)^{i}=(1-1)^{|S|-|P|}=[|S|=|P|]),于是得证。
回到原式子:
其中 (g(S)) 表示包含边集 (S) 的树的数量。
可以发现其实我们一直只关心 (|T|-|P|) ,因此可以化为:
考虑 (g(T)) 如何计算,设 (T) 中的边使 (n) 个点形成了 (k=n-|T|) 个连通块,第 (i) 个的大小为 (a_i)。于是这个问题等价于在 (k) 个点之间连边形成生成树,第 (i,j) 个点之间有 (a_i imes a_j) 条重边。利用 (matrix-tree) 定理可以得到,(g(T)=n^{k-2}prod_{i=1}^{k}a_i)。
于是回到原式子,有
令 (w=dfrac{ny}{1-y}),考虑 (a_iw) 的组合意义,这等架于要求在每个连通块中选择一个点,每选择一个点造成 (w) 的贡献。对此考虑树形 (DP),设 (f_{i,0/1}) 表示仅考虑了 (i) 所在的子树,(i) 所在连通块是否选择了点时的答案,于是可以 (mathcal O(n)) 完成转移。
namespace sub1{
vector<int> to[N];
int W,f[N][2];
inline void dfs(int u,int fa){
f[u][0]=1;f[u][1]=W;
for(int v:to[u]){
if(v==fa) continue;
dfs(v,u);
int f0=f[u][0],f1=f[u][1];
f[u][0]=(1ll*f0*f[v][0]+1ll*f0*f[v][1])%mod;
f[u][1]=(1ll*f0*f[v][1]+1ll*f1*f[v][0]+1ll*f1*f[v][1])%mod;
}
}
inline void main(){
if(y==1){printf("%d
",ksm(n,n-2));exit(0);}
for(int i=1,u,v;i<n;++i){
scanf("%d%d",&u,&v);
to[u].push_back(v);to[v].push_back(u);
}
W=1ll*n*y%mod*ksm(dec(1,y),mod-2)%mod;
dfs(1,0);int p=1ll*ksm(dec(1,y),n)*ksm(n,mod-3)%mod;
printf("%d
",1ll*f[1][1]*p%mod);
exit(0);
}
}
op=2
继续考虑 (op=1) 时的柿子,从枚举 (T) 为 (S_1) 的子集,要求 (T) 为 (S_2) 的子集,改为直接枚举 (T),要求 (T) 为 (S_1,S_2) 的子集,于是有:
进行与 (op=1) 一模一样的操作后有:
此处 (T) 是枚举了所有边集,因此这就相当与考虑了所有划分连通块的方式,原问题等价于将 (n) 个点划分为若干个连通块,一个大小为 (x) 的连通块会产生 (w=x^2dfrac{n^2y}{1-y}) 的贡献,而连通块内部又有 (x^{x-2}) 种方法形成一棵生成树,于是单个连通块的贡献为 (x^xdfrac{n^2y}{1-y})。
原问题相当于将 (n) 个点无序划分为若干个连通块,这与城市规划一题的形式一样,因此我们可以同样得到结论,原问题答案的 (EGF) 就是连通块贡献的 (EGF) 的 (exp)。
于是使用一遍多项式 (exp) 即可。
Code
#include<bits/stdc++.h>
using namespace std;
const int N=(1<<18)+20;
const int mod=998244353;
int n,y,op;
inline void inc(int &x,int y){x=(x+y>=mod)?x+y-mod:x+y;}
inline int dec(int x,int y){return (x-y<0)?x-y+mod:x-y;}
inline int ksm(int x,int y){
int ret=1;
for(;y;y>>=1,x=1ll*x*x%mod) if(y&1) ret=1ll*ret*x%mod;
return ret;
}
namespace sub0{
pair<int,int> a[N],b[N];
inline void main(){
for(int i=1,u,v;i<n;++i){
scanf("%d%d",&u,&v);if(u>v) swap(u,v);
a[i]=make_pair(u,v);
}
for(int i=1,u,v;i<n;++i){
scanf("%d%d",&u,&v);if(u>v) swap(u,v);
b[i]=make_pair(u,v);
}
sort(a+1,a+n);sort(b+1,b+n);
int ans=0;
for(int i=1,j=1;i<n;++i){
while(j<n&&b[j]<a[i]) ++j;
if(b[j]==a[i]) ++ans;
}
printf("%d
",ksm(y,n-ans));
exit(0);
}
}
namespace sub1{
vector<int> to[N];
int W,f[N][2];
inline void dfs(int u,int fa){
f[u][0]=1;f[u][1]=W;
for(int v:to[u]){
if(v==fa) continue;
dfs(v,u);
int f0=f[u][0],f1=f[u][1];
f[u][0]=(1ll*f0*f[v][0]+1ll*f0*f[v][1])%mod;
f[u][1]=(1ll*f0*f[v][1]+1ll*f1*f[v][0]+1ll*f1*f[v][1])%mod;
}
}
inline void main(){
if(y==1){printf("%d
",ksm(n,n-2));exit(0);}
for(int i=1,u,v;i<n;++i){
scanf("%d%d",&u,&v);
to[u].push_back(v);to[v].push_back(u);
}
W=1ll*n*y%mod*ksm(dec(1,y),mod-2)%mod;
dfs(1,0);int p=1ll*ksm(dec(1,y),n)*ksm(n,mod-3)%mod;
printf("%d
",1ll*f[1][1]*p%mod);
exit(0);
}
}
namespace sub2{
typedef vector<int> vec;
typedef unsigned long long ull;
int iv[N],tp,fac[N],jc[N];
inline void init_inv(int n){
if(!tp){iv[0]=iv[1]=fac[0]=fac[1]=jc[0]=jc[1]=1;tp=2;}
for(;tp<=n;++tp){
iv[tp]=1ll*(mod-mod/tp)*iv[mod%tp]%mod;
fac[tp]=1ll*fac[tp-1]*tp%mod;
jc[tp]=1ll*jc[tp-1]*iv[tp]%mod;
}
}
struct poly{
vec v;
inline poly(int w=0):v(1){v[0]=w;}
inline poly(const vec&w):v(w){}
inline int operator [](int x)const{return x>=v.size()?0:v[x];}
inline int& operator [](int x){if(x>=v.size()) v.resize(x+1);return v[x];}
inline int size(){return v.size();}
inline void resize(int x){v.resize(x);}
inline poly slice(int len)const{
if(len<=v.size()) return vec(v.begin(),v.begin()+len);
vec ret(v);ret.resize(len);
return ret;
}
inline poly operator *(const int &x)const{
poly ret(v);
for(int i=0;i<v.size();++i) ret[i]=1ll*ret[i]*x%mod;
return ret;
}
};
int Wn[N<<1],lg[N],r[N],tot;
inline void init_poly(int n){
int p=1;while(p<=n)p<<=1;
for(int i=2;i<=p;++i) lg[i]=lg[i>>1]+1;
for(int i=1;i<p;i<<=1){
int wn=ksm(3,(mod-1)/(i<<1));
Wn[++tot]=1;
for(int j=1;j<i;++j) ++tot,Wn[tot]=1ll*Wn[tot-1]*wn%mod;
}
}
inline void init_pos(int lim){
int len=lg[lim]-1;
for(int i=0;i<lim;++i) r[i]=(r[i>>1]>>1)|((i&1)<<len);
}
ull fr[N];
const ull Mod=998244353;
inline void NTT(int *f,int lim,int tp){
for(int i=0;i<lim;++i) fr[i]=f[r[i]];
for(int mid=1;mid<lim;mid<<=1){
for(int len=mid<<1,l=0;l+len-1<lim;l+=len){
for(int k=l;k<l+mid;++k){
ull w1=fr[k],w2=fr[k+mid]*Wn[mid+k-l]%Mod;
fr[k]=w1+w2;fr[k+mid]=w1+Mod-w2;
}
}
}
for(int i=0;i<lim;++i) f[i]=fr[i]%Mod;
if(!tp){
reverse(f+1,f+lim);
int iv=ksm(lim,mod-2);
for(int i=0;i<lim;++i) f[i]=1ll*f[i]*iv%mod;
}
}
inline poly to_poly(int *a,int n){
poly ret;
ret.resize(n);
memcpy(ret.v.data(),a,n<<2);
return ret;
}
namespace Exp{
const int logB=4;
const int B=16;
int f[N],ret[N],H[N];
poly g[4][B];
inline void exp(int lim,int l,int r,int dep){
if(r-l<=128){
for(int i=l;i<r;++i){
ret[i]=(!i)?1:1ll*ret[i]*iv[i]%mod;
for(int j=i+1;j<r;++j)
inc(ret[j],1ll*ret[i]*f[j-i]%mod);
}
return ;
}
int k=(r-l)/B;
int len=1<<lim-logB+1;
vector<unsigned long long> bl[B];
for(int i=0;i<B;++i) bl[i].resize(k<<1);
for(int i=0;i<B;++i){
if(i>0){
init_pos(len);
for(int j=0;j<(k<<1);++j) H[j]=bl[i][j]%mod;
NTT(H,len,0);
for(int j=0;j<k;++j)
inc(ret[l+i*k+j],H[j+k]);
}
exp(lim-logB,l+i*k,l+(i+1)*k,dep+1);
if(i<B-1){
memcpy(H,ret+l+i*k,sizeof(int)*(k));
memset(H+k,0,sizeof(int)*(k));
init_pos(len);NTT(H,len,1);
for(int j=i+1;j<B;++j)
for(int t=0;t<(k<<1);++t)
bl[j][t]+=1ll*H[t]*g[dep][j-i-1][t];
}
}
}
inline poly getexp(poly F,int n){
memcpy(f,F.v.data(),sizeof(int)*(n));
int mx=lg[n]+1;init_inv(1<<mx);
for(int i=0;i<n;++i) f[i]=1ll*f[i]*i%mod;
memset(ret,0,sizeof(int)*(1<<mx));
for(int lim=mx,dep=0;lim>=8;lim-=logB,dep++){
int len=1<<(lim-logB+1);
init_pos(len);
for(int i=0;i<B-1;++i){
g[dep][i].resize(len);
memcpy(g[dep][i].v.data(),f+(len>>1)*i,sizeof(int)*(len));
NTT(g[dep][i].v.data(),len,1);
}
}
exp(mx,0,1<<mx,0);
return to_poly(ret,n);
}
}
inline void main(){
if(y==1){printf("%d
",ksm(n,2*n-4));exit(0);}
init_poly((n+1)<<1);init_inv(n);
poly f;
int p=1ll*n*n%mod*y%mod*ksm(dec(1,y),mod-2)%mod;
for(int i=0;i<=n;++i) f[i]=1ll*ksm(i,i)*p%mod*jc[i]%mod;
f=Exp::getexp(f,n+1);
printf("%d
",1ll*f[n]*fac[n]%mod*ksm(dec(1,y),n)%mod*ksm(n,mod-5)%mod);
}
}
int main(){
scanf("%d%d%d",&n,&y,&op);
if(!op) sub0::main();
else if(op==1) sub1::main();
else sub2::main();
return 0;
}