题意
写的很明白了,不需要解释。
( exttt{Data Range:}1leq nleq 234567)
题解
国 际 计 数 水 平
首先考虑一开始只有一个黑点的情况怎么做。
我们钦定黑点为根,设 (f_i) 表示以 (i) 为子树,并且 (i) 这个节点为黑色的答案,那么我们得到:
接下来我们阐述这个式子到底是在干什么。
首先我们枚举可能的排列,由于现在只有 (u) 是黑色的,第一轮肯定是先把 (u) 变成红色。也就是说,(u) 只能放在第一个,而剩下的可以任意。(注意这里是可能而不是合法)
接下来考虑排除掉所有不合法的排列,枚举每一棵子树,考虑某棵子树所有节点之间的偏序关系。
对于一个树 (T) 对应的可能的序列 (P) 来说,我们按顺序从 (P) 中选出一些元素构造 (Q),使得 (Q) 中的每一个元素都是 (T) 中某个子树 (S) 中存在的节点的编号,并且 (S) 中每个节点的编号都要在 (Q) 中出现。
这里,(T) 的根为 (u),(S) 的根为 (v) 且 (fa_v=u)。
然后我们发现,如果序列 (Q) 是不合法的,那么序列 (P) 也是不合法的。
对所有子树都这么做就得到同时满足所有约束的序列个数,也就是上面的式子。
接下来我们考虑直接计算根节点的答案。
然后我们对于所有根节点的儿子 (v) 代入进来,再将所有孙子代入进来……相当于我们枚举所有节点计算,于是有
简单化开
注意分母,也就是说我们对于每个节点 (u) 都枚举 (u) 的所有儿子,相当于枚举所有非根节点(因为每个非根节点总会是某个节点的儿子),于是有
现在来拆分子。单独把 (u=rt) 的拿出来,剩下的合并并约分,得到
然后分子分母同乘一个 (sz_{rt}) 得到答案为
接下来考虑两个黑点咋做,建立一个虚拟节点 (s),向 (a,b) 连无向边,初始的时候只有 (s) 是黑点。
第一次只能变 (s),然后就变成 (s) 是红点,(a,b) 是黑点。容易证明这样对答案没有任何影响。
于是树就变成了一颗基环树,考虑怎么求答案。
我们枚举环上染红的最后一个点的位置,但是这是不可行的。
那么我们就用类似于 NOIp2018 Day 2 T1 的方法,考虑枚举一条在环上的边并删除他来计算答案。
但是这样做可能会用重复的答案,所以我们要考虑如何来排除这些答案。
注意到对于一种染色方案来说,考虑这个方案在环上最后一个被染成红色的点。我们发现这个点可以有两种方案被染黑,也就是说每个方案都会算 (2) 此,于是最终答案要乘上 (frac{1}{2})。
对于不是环上的点和虚拟节点 (s) 我们都能用 ( exttt{dfs}) 求出子树大小,设他们的和为 (z)。
我们现在考虑这些在环上的点,对于某个在环上的点 (u) 来说,设 (a_u) 表示它的确定子树大小(也就是所有不在环上的儿子的大小之和加一),(b_u) 表示某种断环方式下的 (sz_u),那么通过人类智慧我们发现
将环上的 (k+1) 个点编号,(s) 点为 (0),其余的依次为 (1sim k),假设断的边是 (r o r+1),那么
我们对 (a) 做前缀和变成 (c),那么可以发现
其中 (c_0=0)。
然后我们可以用快速插值的套路去优化这个东西,复杂度 (O(nlog^2n))。
代码
#include<cstdio>
#include<cstring>
#include<cctype>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#define clr(f,n) memset(f,0,sizeof(int)*(n))
#define cpy(f,g,n) memcpy(f,g,sizeof(int)*(n))
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
using namespace std;
typedef int ll;
typedef long long int li;
typedef unsigned long long int ull;
const ll MAXN=524291,MOD=998244353,G=3,INVG=332748118;
struct Edge{
ll to,prev;
};
Edge ed[MAXN];
ll n,x,y,z,fct,from,to,tot,tp,cur,res;
ll rev[MAXN],omgs[MAXN],invo[MAXN];
ll last[MAXN],sz[MAXN],fa[MAXN],st[MAXN],p[MAXN],c[MAXN];
ll pr[MAXN],px[MAXN];
inline ll read()
{
register ll num=0,neg=1;
register char ch=getchar();
while(!isdigit(ch)&&ch!='-')
{
ch=getchar();
}
if(ch=='-')
{
neg=-1;
ch=getchar();
}
while(isdigit(ch))
{
num=(num<<3)+(num<<1)+(ch-'0');
ch=getchar();
}
return num*neg;
}
inline void addEdge(ll from,ll to)
{
ed[++tot].prev=last[from];
ed[tot].to=to;
last[from]=tot;
}
inline void dfs(ll node,ll f)
{
sz[node]=1,fa[node]=f;
for(register int i=last[node];i;i=ed[i].prev)
{
if(ed[i].to!=f)
{
dfs(ed[i].to,node),sz[node]+=sz[ed[i].to];
}
}
}
inline ll qpow(ll base,ll exponent)
{
ll res=1;
while(exponent)
{
if(exponent&1)
{
res=(li)res*base%MOD;
}
base=(li)base*base%MOD,exponent>>=1;
}
return res;
}
inline void setupOmg(ll cnt)
{
ll limit=log2(cnt)-1,omg;
omg=qpow(G,(MOD-1)>>(limit+1)),omgs[cnt>>1]=1;
for(register int i=(cnt>>1|1);i!=cnt;i++)
{
omgs[i]=(li)omgs[i-1]*omg%MOD;
}
for(register int i=(cnt>>1)-1;i;i--)
{
omgs[i]=omgs[i<<1];
}
}
inline void NTT(ll *cp,ll cnt,ll inv)
{
static ull tcp[MAXN];
register ll cur=0,x,shift=log2(cnt)-__builtin_ctz(cnt);
if(inv==-1)
{
reverse(cp+1,cp+cnt);
}
for(register int i=0;i<cnt;i++)
{
tcp[rev[i]>>shift]=cp[i];
}
for(register int i=2;i<=cnt;i<<=1)
{
cur=i>>1;
for(register int j=0;j<cnt;j+=i)
{
for(register int k=0;k<cur;k++)
{
x=tcp[j|k|cur]*omgs[k|cur]%MOD;
tcp[j|k|cur]=tcp[j|k]+MOD-x,tcp[j|k]+=x;
}
}
}
for(register int i=0;i<cnt;i++)
{
cp[i]=tcp[i]%MOD;
}
if(inv==1)
{
return;
}
x=MOD-(MOD-1)/cnt;
for(register int i=0;i<cnt;i++)
{
cp[i]=(li)cp[i]*x%MOD;
}
}
inline void conv(ll fd,ll *f,ll *g,ll *res)
{
static ll tmpf[MAXN],tmpg[MAXN];
ll cnt=1,limit=-1;
while(cnt<(fd<<1))
{
cnt<<=1,limit++;
}
for(register int i=0;i<cnt;i++)
{
tmpf[i]=i<fd?f[i]:0,tmpg[i]=i<fd?g[i]:0;
rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
}
NTT(tmpf,cnt,1),NTT(tmpg,cnt,1);
for(register int i=0;i<cnt;i++)
{
tmpf[i]=(li)tmpf[i]*tmpg[i]%MOD;
}
NTT(tmpf,cnt,-1),cpy(res,tmpf,fd);
}
inline void inv(ll fd,ll *f,ll *res)
{
static ll tmp[MAXN];
if(fd==1)
{
res[0]=qpow(f[0],MOD-2);
return;
}
inv((fd+1)>>1,f,res);
ll cnt=1,limit=-1;
while(cnt<(fd<<1))
{
cnt<<=1,limit++;
}
for(register int i=0;i<cnt;i++)
{
tmp[i]=i<fd?f[i]:0;
rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
}
NTT(tmp,cnt,1),NTT(res,cnt,1);
for(register int i=0;i<cnt;i++)
{
res[i]=(li)(2-(li)tmp[i]*res[i]%MOD+MOD)%MOD*res[i]%MOD;
}
NTT(res,cnt,-1),clr(res+fd,cnt-fd-1);
}
inline void mod(ll fd,ll gd,ll *f,ll *g,ll *r)
{
static ll tmpf[MAXN],tmpg[MAXN],tinv[MAXN],q[MAXN];
if(fd<gd)
{
for(register int i=0;i<gd-1;i++)
{
r[i]=f[i];
}
return;
}
for(register int i=0;i<fd;i++)
{
tmpf[i]=f[fd-1-i];
}
for(register int i=0;i<gd;i++)
{
tmpg[i]=g[gd-1-i];
}
inv(fd-gd+2,tmpg,tinv);
ll cnt=1,limit=-1;
while(cnt<(fd<<1))
{
cnt<<=1,limit++;
}
for(register int i=0;i<cnt;i++)
{
rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
}
NTT(tmpf,cnt,1),NTT(tinv,cnt,1);
for(register int i=0;i<cnt;i++)
{
q[i]=1ll*tmpf[i]*tinv[i]%MOD;
}
NTT(q,cnt,-1),reverse(q,q+fd-gd+1);
for(register int i=0;i<cnt;i++)
{
tmpf[i]=tinv[i]=tmpg[i]=0;
q[i]=i<fd-gd+1?q[i]:0,g[i]=i<gd?g[i]:0;
}
cnt>>=1,limit--;
for(register int i=0;i<cnt;i++)
{
rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
}
NTT(q,cnt,1),NTT(g,cnt,1);
for(register int i=0;i<cnt;i++)
{
tmpf[i]=1ll*q[i]*g[i]%MOD;
}
NTT(g,cnt,-1),NTT(tmpf,cnt,-1);
for(register int i=0;i<gd-1;i++)
{
r[i]=(f[i]-tmpf[i]+MOD)%MOD;
}
for(register int i=0;i<cnt;i++)
{
q[i]=tmpf[i]=0;
}
}
vector<ll> tmpf2[MAXN<<2];
void dnc(ll *pts,ll l,ll r,ll node)
{
static ll tmp[MAXN],tmp2[MAXN];
if(l==r)
{
tmpf2[node].push_back((MOD-pts[l])%MOD),tmpf2[node].push_back(1);
return;
}
ll mid=(l+r)>>1,ls=node<<1,rs=ls|1;
dnc(pts,l,mid,ls),dnc(pts,mid+1,r,rs);
ll d=tmpf2[ls].size(),d2=tmpf2[rs].size();
copy(tmpf2[ls].begin(),tmpf2[ls].end(),tmp);
copy(tmpf2[rs].begin(),tmpf2[rs].end(),tmp2);
ll cnt=1,limit=-1;
while(cnt<(d+d2))
{
cnt<<=1,limit++;
}
for(register int i=0;i<cnt;i++)
{
rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
}
NTT(tmp,cnt,1),NTT(tmp2,cnt,1);
for(register int i=0;i<cnt;i++)
{
tmp[i]=(li)tmp[i]*tmp2[i]%MOD;
}
NTT(tmp,cnt,-1),tmpf2[node].resize(d+d2-1);
copy(tmp,tmp+d+d2-1,tmpf2[node].begin()),clr(tmp,cnt),clr(tmp2,cnt);
}
ll tmpf3[19][MAXN];
void dnc2(ll fd,ll *f,ll depth,ll l,ll r,ll node,ll *res)
{
static ll tmp[MAXN],pw[17];
if(r-l<=1024)
{
for(register int i=l;i<=r;i++)
{
ll x=c[i],cur=f[fd-1],v1,v2,v3,v4;
pw[0]=1;
for(register int j=1;j<=16;j++)
{
pw[j]=1ll*pw[j-1]*x%MOD;
}
for(register int j=fd-2;j-15>=0;j-=16)
{
v1=(1ll*cur*pw[16]+1ll*f[j]*pw[15]+
1ll*f[j-1]*pw[14]+1ll*f[j-2]*pw[13])%MOD;
v2=(1ll*f[j-3]*pw[12]+1ll*f[j-4]*pw[11]+
1ll*f[j-5]*pw[10]+1ll*f[j-6]*pw[9])%MOD;
v3=(1ll*f[j-7]*pw[8]+1ll*f[j-8]*pw[7]+
1ll*f[j-9]*pw[6]+1ll*f[j-10]*pw[5])%MOD;
v4=(1ll*f[j-11]*pw[4]+1ll*f[j-12]*pw[3]+
1ll*f[j-13]*pw[2]+1ll*f[j-14]*pw[1])%MOD;
cur=(0ll+v1+v2+v3+v4+f[j-15])%MOD;
}
for(register int j=((fd-1)&15)-1;~j;j--)
{
cur=(1ll*cur*x+f[j])%MOD;
}
res[i]=cur;
}
return;
}
ll sz=tmpf2[node].size()-1;
for(register int i=0;i<sz+1;i++)
{
tmp[i]=tmpf2[node][i];
}
clr(tmpf3[depth],sz+10),mod(fd,sz+1,f,tmp,tmpf3[depth]);
ll mid=(l+r)>>1;
dnc2(sz,tmpf3[depth],depth+1,l,mid,node<<1,res);
dnc2(sz,tmpf3[depth],depth+1,mid+1,r,(node<<1)|1,res);
for(register int i=0;i<sz;i++)
{
tmpf3[depth][i]=0;
}
}
inline void eval(ll fd,ll pcnt,ll *f,ll *pts,ll *res)
{
dnc(pts,0,pcnt-1,1),dnc2(fd,f,0,0,pcnt-1,1,res);
}
int main()
{
setupOmg(524288),n=read(),x=read(),y=read();
for(register int i=0;i<n-1;i++)
{
from=read(),to=read();
addEdge(from,to),addEdge(to,from);
}
dfs(x,0);
while(y!=x)
{
st[++tp]=y,p[y]=1,y=fa[y];
}
st[++tp]=y,p[y]=1,y=fa[y],z=n+1,fct=1;
for(register int i=1;i<=tp;i++)
{
c[i]=1;
for(register int j=last[st[i]];j;j=ed[j].prev)
{
if(!p[ed[j].to])
{
c[i]+=sz[ed[j].to];
}
}
}
for(register int i=1;i<=n;i++)
{
fct=(li)fct*i%MOD;
if(!p[i])
{
z=(li)z*sz[i]%MOD;
}
}
fct=(li)fct*(n+1)%MOD;
for(register int i=1;i<=tp;i++)
{
c[i]=(c[i-1]+c[i])%MOD;
}
dnc(c,1,tp+1,1);
for(register int i=1;i<=tp+1;i++)
{
pr[i-1]=(li)i*tmpf2[1][i]%MOD;
}
memset(tmpf2,0,sizeof(tmpf2)),eval(n+1,tp+1,pr,c,px);
for(register int i=0;i<=tp;i++)
{
cur=(li)fct*qpow((li)px[i]*z%MOD,MOD-2)%MOD;
res=(res+(((tp-i)&1)?MOD-cur:cur))%MOD;
}
printf("%d
",(li)res*499122177%MOD);
}