题目
题目链接:https://www.luogu.com.cn/problem/P4516
外星人又双叒叕要攻打地球了,外星母舰已经向地球航行!这一次,JYY
已经联系好了黄金舰队,打算联合所有 JSOIer
抵御外星人的进攻。
在黄金舰队就位之前,JYY
打算事先了解外星人的进攻计划。现在,携带了监听设备的特工已经秘密潜入了外星人的母舰,准备对外星人的通信实施监听。
外星人的母舰可以看成是一棵 (n) 个节点、 (n-1) 条边的无向树,树上的节点用 (1,2,cdots,n) 编号。JYY
的特工已经装备了隐形模块,可以在外星人母舰中不受限制地活动,可以神不知鬼不觉地在节点上安装监听设备。
如果在节点 (u) 上安装监听设备,则 JYY
能够监听与 (u) 直接相邻所有的节点的通信。换言之,如果在节点 (u) 安装监听设备,则对于树中每一条边 ((u,v)) ,节点 (v) 都会被监听。特别注意放置在节点 (u) 的监听设备并不监听 (u) 本身的通信,这是 JYY
特别为了防止外星人察觉部署的战术。
JYY
的特工一共携带了 (k) 个监听设备,现在 JYY
想知道,有多少种不同的放置监听设备的方法,能够使得母舰上所有节点的通信都被监听?为了避免浪费,每个节点至多只能安装一个监听设备,且监听设备必须被用完。
(nleq 10^5;kleq 100)。
思路
太强了这道题。(说不定是我菜了 /kk)
很容易想到一个无脑的 dp:设 (f[x][i][0/1][0/1]) 表示点 (x) 为根的子树内,选择了 (i) 个点,其中点 (x) 选不选,点 (x) 是否被覆盖的方案数。
直接暴力转移看起来是 (O(nk^2)) 的。发现是一个卷积的形式用 MTT 即可做到 O(nk log k)。
其实如果转移的时候第二维严格的只枚举到上界((min(k, ext{siz}(x)))),复杂度是 (O(nk)) 的。
证明如下,完全照抄 yyb 神仙:
- 合并的两个子树大小都 (geq k) 时,不难发现这样的合并次数是 (O(frac{n}{k})) 的,所以时间复杂度为 (O(frac{n}{k} imes k^2)=O(nk))。
- 合并的两个子树大小都 (<k) 且子树大小之和 (<k) 时,两棵子树都只枚举到子树大小,那么可以看作是枚举两边的点(复杂度一样),那么考虑树上的任意一个点,它作为其中一棵子树内的节点被枚举到时,另一棵子树大小之和是 (O(k)) 的。因为如果所有另一棵子树大小之和超过 (k),就不满足合并的两棵子树大小之和 (<k)。复杂度为 (O(nk))。
- 合并的两个子树大小中至少一棵 (<k) 且两棵大小只和 (geq k) 时,依然可以看作枚举了每一个点,且每一个点最多只会在一个祖先的合并处发生这样的情况,也就是每一个点最多只会处理这种情况一次,而每一次一个点的复杂度为 (O(k)),所以总复杂度依然是 (O(nk))。
综上,时间复杂度是 (O(nk)) 的。
代码
#include <bits/stdc++.h>
#define reg register
using namespace std;
const int N=100010,M=105,MOD=1000000007;
int n,m,tot,head[N],siz[N],f[N][M][2][2],g[M][2][2];
struct edge
{
int next,to;
}e[N*2];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
void dfs(int x,int fa)
{
f[x][0][0][0]=f[x][1][1][0]=siz[x]=1;
for (reg int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa)
{
dfs(v,x);
memcpy(g,f[x],sizeof(g));
memset(f[x],0,sizeof(f[x]));
for (reg int j=0;j<=min(siz[x],m);j++)
for (reg int k=0;k<=min(siz[v],m-j);k++)
for (reg int l=0;l<=15;l++)
{
int a=l&1,b=(l>>1)&1,c=(l>>2)&1,d=(l>>3)&1;
if (a|d) f[x][j+k][a][b|c]=(f[x][j+k][a][b|c]+1LL*g[j][a][b]*f[v][k][c][d])%MOD;
}
siz[x]+=siz[v];
}
}
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&m);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
dfs(1,0);
cout<<(f[1][m][0][1]+f[1][m][1][1])%MOD;
return 0;
}