题目大意
给定一棵带边权树,定义一个点集合法当且仅当点集大小为$m$,且存在一个树上的任意节点满足点集中的所有点到这个点的距离均不超过$k$,求合法点集的数目,对$998244353$取模。
节点数$1 leq n leq 10^5$。
题解
考虑如何不重不漏地统计。
于是,对于每个点集,考虑在深度最小的、令这个点集合法的位置统计它。
此时,一组会被当前节点统计到的点集为所有至少包含了一个与当前节点的父亲的距离大于$k$的合法节点的点集。
那么,问题变成了对于每个节点,求出有多少个节点满足与当前节点的距离小于等于$k$,同时再求出有多少个节点在满足上面的条件的情况下,与当前节点的父亲的距离大于$k$。
前者考虑使用点分治配合数据结构完成,后者由于只存在于当前节点的子树中,可以考虑按$dfs$序遍历整棵树,并通过线段树数出。
最后通过一个节点统计到的所有点集数量减去不包含任何一个与当前节点父亲距离大于$k$的节点的点集数量即可。
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+9;
const int E=1e9+7;
const ll md=998244353;
inline int read()
{
int x=0;char ch=getchar();
while(ch<'0' || '9'<ch)ch=getchar();
while('0'<=ch && ch<='9')x=x*10+(ch^48),ch=getchar();
return x;
}
inline int chkmax(int &a,int b){if(a<b)a=b;}
inline ll qpow(ll a,ll b)
{
ll ret=1;
while(b)
{
if(b&1)ret=ret*a%md;
a=a*a%md;b>>=1;
}
return ret;
}
int n,m,k;
int to[N<<1],nxt[N<<1],w[N<<1],beg[N],tot;
int id[N],ed[N],seg[N],siz[N<<1],dfn;
ll dis[N],cnt[N],cntf[N];
ll fac[N],inv[N],ans;
bool ban[N<<1];
namespace segt
{
const int M=N*60;
int t[M],ls[M],rs[M],tot;
inline void clear(){tot=0;}
inline int newnode()
{
int ret=++tot;
t[ret]=ls[ret]=rs[ret]=0;
return ret;
}
inline void insert(int &x,int l,int r,int p,int v)
{
if(!x)x=newnode();t[x]+=v;
if(l==r)return;int mid=l+r>>1;
if(p<=mid)insert(ls[x],l,mid,p,v);
else insert(rs[x],mid+1,r,p,v);
}
inline int query(int &x,int l,int r,int dl,int dr)
{
if(l==dl && r==dr)return t[x];
if(!x)return 0;int mid=l+r>>1;
if(dr<=mid)return query(ls[x],l,mid,dl,dr);
if(mid<dl)return query(rs[x],mid+1,r,dl,dr);
return query(ls[x],l,mid,dl,mid)+query(rs[x],mid+1,r,mid+1,dr);
}
}
inline void add(int u,int v,int c)
{
to[++tot]=v;
nxt[tot]=beg[u];
w[tot]=c;
beg[u]=tot;
}
inline int dfs_rt(int u,int fa,int &rt,int tsiz)
{
int sz=1,mx=0;
for(int i=beg[u];i;i=nxt[i])
if(!ban[i] && to[i]!=fa)
{
chkmax(mx,siz[i]=dfs_rt(to[i],u,rt,tsiz));
sz+=siz[i];siz[i^1]=tsiz-siz[i];
}
chkmax(mx,tsiz-sz);
if((mx<<1)<=tsiz)rt=u;
return sz;
}
inline void dfs_seq(int u,int fa)
{
seg[id[u]=++dfn]=u;
for(int i=beg[u];i;i=nxt[i])
if(!ban[i] && to[i]!=fa)
{
dis[to[i]]=dis[u]+w[i];
dfs_seq(to[i],u);
}
ed[u]=dfn;
}
inline void insert(int u,int v)
{
for(int j=id[u];j<=ed[u];j++)
if(dis[seg[j]]<=k)
segt::insert(1,0,E,dis[seg[j]],-1);
else j=ed[seg[j]];
}
inline void query(int u)
{
for(int j=id[u];j<=ed[u];j++)
if(dis[seg[j]]<=k)
cnt[seg[j]]+=segt::query(1,0,E,0,k-dis[seg[j]]);
else j=ed[seg[j]];
}
inline void work(int u,int tsiz)
{
int rt;
dfs_rt(u,0,rt,tsiz);
dfs_seq(rt,0);
segt::clear();
insert(u,1);
cnt[u]=cnt[u]+segt::query(1,0,E,k);
for(int i=beg[u];i;i=nxt[i])
if(!ban[to[i]])
{
insert(to[i],-1);
query(to[i]);
insert(to[i],1);
}
for(int i=beg[u];i;i=nxt[i])
if(!ban[i])
{
ban[i]=ban[i^1]=1;
work(to[i],siz[i]);
}
}
inline void workf(int u,int fa)
{
if(~fa)
cntf[u]-=segt::query(dis[fa]+k+1,dis[u]+k);
segt::insert(1,1,n,dis[u],1);
for(int i=beg[u];i;i=nxt[i])
if(to[i]!=fa)
{
dis[to[i]]=dis[u]+w[i];
workf(to[i],u);
}
if(~fa)
cntf[u]+=segt::query(dis[fa]+k+1,dis[u]+k);
}
inline void init()
{
fac[0]=1;
for(ll i=1;i<N;i++)
fac[i]=fac[i-1]*i%md;
inv[N-1]=qpow(fac[N-1],md-2);
for(ll i=N-1;i>=1;i--)
inv[i-1]=inv[i]*i%md;
}
inline ll c(ll a,ll b)
{
return fac[a]*inv[b]%md*inv[a-b]%md;
}
int main()
{
freopen("party.in","r",stdin);
freopen("party.out","w",stdout);
n=read();m=read();k=read();
for(int i=1,u,v,c;i<n;i++)
{
u=read();v=read();c=read();
add(u,v,c);add(v,u,c);
}
work(1);
for(int i=1;i<=n;i++)
(ans+=c(cnt[i],m))%=md;
segt::clear();
dis[1]=0;
workf(1,-1);
for(int i=1;i<=n;i++)
(ans+=md-c((cnt[i]-cntf[i]+md)%md,m))%=md;
printf("%lld
",ans);
return 0;
}