ytq鸽鸽出的题真是毒瘤
原题传送门
题目大意:
有一棵有(n)个点的树,求有多少方案选(k)个联通块使得存在一个中心点(p),所有(k)个联通块中所有点到(p)的距离都(leq L)
我们先从部分分考虑:
算法1
(2^{nk})枚举联通快判断是否可行
期望得分:8pts
代码就不给了
引理1
对于任意一个联通块,可能成为它的中心点的点组成的集合也是一个联通块。正确性显然
- 推论
这里作为(k)个联通块中心点的点组成的集合是一个联通块
算法2
枚举推论中中心点组成的集合,判断其他是否满足
期望得分:16pts
代码也不给了
引理2
注意到树上联通块中,点集为(S),边集为(E),有(|S|-|E|=1)。正确性显然
这样我们就珂以对每个点、边单独考虑,通过容斥算出答案。设(f(x))表示(x)控制的联通块个数,(g(e))表示(e)两端点都控制的联通块个数,易知答案为
算法3
我们考虑dp,设(f_{i,j})表示(i)点子树中距离(i)不超过(j)且包含(i)的个数,(g_{i,j})表示(i)点子树外距离(i)不超过(j)且包含(i)的个数,转移如下:
对于每个点(i),对答案的贡献为:
对于每条边(e)(设(v)为深度更深的点),对答案的贡献为:
时间复杂度:(O(nL))
期望得分:36pts
算法4:
对于(L=n)的测试点,第二维的限制没用了,去掉即可
时间复杂度:(O(n))
期望得分:结合算法3可得48pts
算法5:
对于链的情况,珂以手推一下dp的贡献,发现就是距离,算一下即可
时间复杂度:(O(n))
期望得分:结合算法4可得52pts
这是暴力52pts的做法,代码如下,subtask1是算法3,subtask2是算法4,subtask3是算法5
#include <bits/stdc++.h>
#define mod 998244353
#define N 100005
#define getchar nc
using namespace std;
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int read()
{
register int x=0,f=1;register char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x*f;
}
inline void write(register int x)
{
if(!x)putchar('0');if(x<0)x=-x,putchar('-');
static int sta[20];register int tot=0;
while(x)sta[tot++]=x%10,x/=10;
while(tot)putchar(sta[--tot]+48);
}
inline int power(register int a,register int b)
{
int res=1;
while(b)
{
if(b&1)
res=1ll*res*a%mod;
a=1ll*a*a%mod;
b>>=1;
}
return res;
}
struct egde{
int to,next;
}e[N<<2];
int head[N<<1],cnt=0;
inline void add(register int u,register int v)
{
e[++cnt]=(egde){v,head[u]};
head[u]=cnt;
}
int n,m,k,ans;
namespace subtask1{
vector<int> f[N],g[N];
inline void dfs1(register int x,register int fa)
{
for(register int i=head[x];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa)
continue;
dfs1(v,x);
for(register int j=1;j<=m;++j)
f[x][j]=1ll*f[x][j]*(f[v][j-1]+1)%mod;
}
}
inline void dfs2(register int x,register int fa)
{
static int son[N],pre[N],suf[N];
int cnt=0;
for(register int i=head[x];i;i=e[i].next)
if(e[i].to!=fa)
son[++cnt]=e[i].to;
for(register int i=pre[0]=suf[cnt+1]=1;i<=m;++i)
{
if(i>=2)
{
for(register int j=1;j<=cnt;++j)
pre[j]=1ll*pre[j-1]*(f[son[j]][i-2]+1)%mod;
for(register int j=cnt;j>=1;--j)
suf[j]=1ll*suf[j+1]*(f[son[j]][i-2]+1)%mod;
}
else
{
for(register int j=1;j<=cnt;++j)
pre[j]=suf[j]=1;
}
for(register int j=1;j<=cnt;++j)
g[son[j]][i]=(1ll*g[x][i-1]*pre[j-1]%mod*suf[j+1]+1)%mod;
}
for(register int i=head[x];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa)
continue;
ans=(ans+power(1ll*(g[v][m]-1)*f[v][m-1]%mod,k))%mod;
dfs2(v,x);
}
}
inline void solve()
{
for(register int i=1;i<=n;++i)
f[i].resize(m+1,1),g[i].resize(m+1,1);
dfs1(1,0);
dfs2(1,0);
ans=mod-ans;
for(register int i=1;i<=n;++i)
ans=(0ll+ans+power(1ll*f[i][m]*g[i][m]%mod,k))%mod;
write(ans);
}
}
namespace subtask2{
int a[200005],b[200005];
inline void dfs1(register int x,register int fa)
{
for(register int i=head[x];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa)
continue;
dfs1(v,x);
a[x]=1ll*a[x]*(a[v]+1)%mod;
}
}
inline void dfs2(register int x,register int fa)
{
static int son[200005],pre[200005],suf[200005];
int cnt=0;
for(register int i=head[x];i;i=e[i].next)
if(e[i].to!=fa)
son[++cnt]=e[i].to;
pre[0]=suf[cnt+1]=1;
for(register int j=1;j<=cnt;++j)
pre[j]=1ll*pre[j-1]*(a[son[j]]+1)%mod;
for(register int j=cnt;j>=1;--j)
suf[j]=1ll*suf[j+1]*(a[son[j]]+1)%mod;
for(register int j=1;j<=cnt;++j)
b[son[j]]=(1ll*b[x]*pre[j-1]%mod*suf[j+1]+1)%mod;
for(register int i=head[x];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa)
continue;
ans=(ans+power(1ll*(b[v]-1)*a[v]%mod,k))%mod;
dfs2(v,x);
}
}
inline void solve()
{
for(register int i=0;i<=n;++i)
a[i]=b[i]=1;
dfs1(1,0);
dfs2(1,0);
ans=mod-ans;
for(register int i=1;i<=n;++i)
ans=(ans+power(1ll*a[i]*b[i]%mod,k))%mod;
write(ans);
}
}
namespace subtask3{
inline void solve()
{
for(register int i=1;i<=n;++i)
{
int l=max(1,i-m),r=min(n,i+m);
ans=(ans+power(1ll*(i-l+1)*(r-i+1)%mod,k))%mod;
}
for(register int i=2;i<=n;++i)
{
int l=max(1,i-m),r=min(n,i+m-1);
ans=(ans+mod-power(1ll*(i-l+1-1)*(r-i+1)%mod,k))%mod;
}
write(ans);
}
}
int main()
{
n=read(),m=read(),k=read();
if(n<=200000)
{
for(register int i=1;i<n;++i)
{
int u=read(),v=read();
add(u,v),add(v,u);
}
}
if(m==n&&n<=200000)
subtask2::solve();
else if(n<=100000)
subtask1::solve();
else
subtask3::solve();
return 0;
}
算法6
发现(f,g)的转移珂以用长链剖分优化,套上一个可持久化数据结构,需要资瓷区间加、区间乘
时间复杂度:(O(nlog n))
期望得分:结合算法5可得84pts
算法7
发现是全局加,后缀乘,这样可以像SDOI2019Day1T1 快速查询
珂以用数组线性维护,为了去掉求逆元的(log),我们珂以用类似于求阶乘逆元的方法(O(n))预处理要用的逆元
至此,我们得到了一个(O(n))的做法
期望得分:100pts
#include <bits/stdc++.h>
#define mod 998244353
#define N 1000005
#define getchar nc
using namespace std;
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int read()
{
register int x=0,f=1;register char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x*f;
}
inline void write(register int x)
{
if(!x)putchar('0');if(x<0)x=-x,putchar('-');
static int sta[20];register int tot=0;
while(x)sta[tot++]=x%10,x/=10;
while(tot)putchar(sta[--tot]+48);
}
inline int power(register int a,register int b)
{
int res=1;
while(b)
{
if(b&1)
res=1ll*res*a%mod;
a=1ll*a*a%mod;
b>>=1;
}
return res;
}
struct edge{
int to,next;
}e[N<<1];
int head[N],cnt=0;
inline void add(register int u,register int v)
{
e[++cnt]=(edge){v,head[u]};
head[u]=cnt;
}
int n,m,k;
int dp[N],son[N],len[N],val[N],idx[N],tot,inv[N];
int arr[N<<3],*pos=arr,*f[N<<1],*g[N],ans[2][N],Ans;
inline void dfs1(register int x,register int fa)
{
dp[x]=1;
for(register int i=head[x];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa)
continue;
dfs1(v,x);
son[x]=len[v]>len[son[x]]?v:son[x];
dp[x]=1ll*dp[x]*dp[v]%mod;
}
if(++dp[x]<mod)
val[idx[x]=++tot]=dp[x];
len[x]=len[son[x]]+1;
}
inline void init_inv()
{
static int pre[N];
for(register int i=pre[0]=1;i<=tot;++i)
pre[i]=1ll*pre[i-1]*val[i]%mod;
int pre_inv=power(pre[tot],mod-2);
for(register int i=tot;i;--i)
inv[i]=1ll*pre[i-1]*pre_inv%mod,pre_inv=1ll*pre_inv*val[i]%mod;
}
struct node{
int a,b,inv,pos,num;
};
namespace F{
node tag[N<<1];
vector<pair<node,vector<pair<int,int> > > >backup[N];
inline int get(register int u,register int i)
{
int res=i<tag[u].pos?f[u][i]:tag[u].num;
return (1ll*res*tag[u].a+tag[u].b)%mod;
}
inline void put(register int u,register int i,register int v)
{
f[u][i]=1ll*(v+mod-tag[u].b)*tag[u].inv%mod;
}
inline void merge(register int u,register int v,register int l)
{
node tmp=tag[u];
vector<pair<int,int> >vec;
for(register int i=1;i<=l;++i)
{
vec.push_back(make_pair(i,f[u][i]));
int val=get(v,i-1);
if(i==tag[u].pos)
f[u][tag[u].pos++]=tag[u].num;
put(u,i,1ll*get(u,i)*val%mod);
}
if(l<m)
{
int val=get(v,l);
if(!val)
tag[u].pos=l+1,tag[u].num=mod-1ll*tag[u].b*tag[u].inv%mod;
else
{
int t=inv[idx[v]];
vec.push_back(make_pair(0,f[u][0]));
for(register int i=0;i<=l;++i)
put(u,i,1ll*get(u,i)*t%mod);
tag[u].a=1ll*tag[u].a*val%mod;
tag[u].b=1ll*tag[u].b*val%mod;
tag[u].inv=1ll*tag[u].inv*t%mod;
}
}
if(u<=n)
backup[u].push_back(make_pair(tmp,vec));
}
inline void dfs2(register int x,register int fa)
{
if(son[x])
f[son[x]]=f[x]+1,dfs2(son[x],x),tag[x]=tag[son[x]];
else
tag[x]=(node){1,1,1,n,0};
put(x,0,1);
for(register int i=head[x];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa||v==son[x])
continue;
f[v]=pos;
pos+=len[v];
dfs2(v,x);
merge(x,v,min(len[v]-1,m));
}
ans[0][x]=get(x,min(len[x]-1,m-1));
ans[1][x]=get(x,min(len[x]-1,m));
++tag[x].b;
}
inline void rollback(register int u)
{
tag[u]=backup[u].back().first;
for(register int i=0;i<(backup[u].back().second).size();++i)
f[u][(backup[u].back().second)[i].first]=(backup[u].back().second)[i].second;
backup[u].pop_back();
}
}
namespace G{
node tag[N];
inline int get(register int u,register int i)
{
int res=i<tag[u].pos?g[u][i]:tag[u].num;
return (1ll*res*tag[u].a+tag[u].b)%mod;
}
inline void put(register int u,register int i,register int v)
{
g[u][i]=1ll*(v+mod-tag[u].b)*tag[u].inv%mod;
}
inline void dfs3(register int x,register int fa)
{
if(len[x]-m-1>=0)
put(x,len[x]-m-1,1);
Ans=(Ans+power(1ll*ans[1][x]*get(x,len[x]-1)%mod,k))%mod;
if(fa)
Ans=(Ans-power(1ll*ans[0][x]*(get(x,len[x]-1)+mod-1)%mod,k)+mod)%mod;
if(!son[x])
return;
vector<int> ch;
int maxlen=0;
for(register int i=head[x];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa||v==son[x])
continue;
ch.push_back(v);
maxlen=max(maxlen,len[v]);
}
maxlen=min(maxlen,m);
reverse(ch.begin(),ch.end());
f[x+n]=pos;
pos+=maxlen+1;
F::tag[x+n]=(node){1,1,1,n,0};
F::put(x+n,0,1);
for(register int id=0;id<ch.size();++id)
{
int v=ch[id];
F::rollback(x);
g[v]=pos;
pos+=len[v];
for(register int i=max(len[v]-m-1,0);i<len[v];++i)
{
if(m-len[v]+i==-1)
g[v][i]=get(x,len[son[x]]-len[v]+i);
else
g[v][i]=1ll*get(x,len[son[x]]-len[v]+i)*F::get(x,min(len[x]-1,m-len[v]+i))%mod*F::get(x+n,min(maxlen,m-len[v]+i))%mod;
}
tag[v]=(node){1,1,1,n,0};
F::merge(x+n,v,min(len[v]-1,m));
dfs3(v,x);
}
int v=son[x];
g[v]=g[x];
tag[v]=tag[x];
for(register int i=max(len[v]-m,0);i<=len[v]+maxlen-m-1;++i)
{
if(tag[v].pos==i)
g[v][tag[v].pos++]=tag[v].num;
put(v,i,1ll*get(v,i)*F::get(x+n,m-len[v]+i)%mod);
}
if(maxlen<m)
{
int vv=1,tt=1;
for(register int i=0;i<ch.size();++i)
vv=1ll*vv*val[idx[ch[i]]]%mod,tt=1ll*tt*inv[idx[ch[i]]]%mod;
if(!vv)
tag[v].pos=len[v]+maxlen-m,tag[v].num=mod-1ll*tag[v].b*tag[v].inv%mod;
else
{
for(register int i=max(len[v]-m-1,0);i<=len[v]+maxlen-m-1;++i)
put(v,i,1ll*get(v,i)*tt%mod);
tag[v].a=1ll*tag[v].a*vv%mod;
tag[v].b=1ll*tag[v].b*vv%mod;
tag[v].inv=1ll*tag[v].inv*tt%mod;
}
}
++tag[v].b;
dfs3(v,x);
}
}
int main()
{
n=read(),m=read(),k=read();
for(register int i=1;i<n;++i)
{
int u=read(),v=read();
add(u,v),add(v,u);
}
dfs1(1,0);
init_inv();
f[1]=pos;
pos+=len[1];
F::dfs2(1,0);
g[1]=pos;
pos+=len[1];
G::tag[1]=(node){1,1,1,n,0};
G::dfs3(1,0);
write(Ans);
return 0;
}