题目
题目链接:https://codeforces.com/contest/1111/problem/E
给一棵 (n) 个结点的树,(q) 次询问,每次询问首先是三个数 (k,m,r),接下来跟着 (k) 个结点编号,请你将这 (k) 个结点分成不超过 (m) 组,使得在以 (r) 为根的情况下,组内的任意两个结点不存在祖先关系,求方案数对 (10^9+7) 取模。
(n,q,sum kleq 10^5),(mleq 300)。
思路
先不考虑复杂度,对于每一次询问把虚树建出来。那么虚树上任意一个点 (x),就不可以和 (x) 虚树上祖先放在同一组里。
把点按照 dfs 序排序,记 (cnt[x]) 表示虚树上 (x) 特殊点祖先数量,设 (f[i][j]) 表示前 (i) 个点分成 (j) 组的方案数。转移为
[f[i][j]=f[i-1][j-1]+f[i-1][j] imes (j-cnt[i])
]
前面就是 (i) 单独分为一组,后面是因为这 (cnt[i]) 个祖先肯定都属于不同组,(i) 能选的组数就是 (j-cnt[i])。
观察到虚树的作用仅仅是求目前询问的根到一个点的路径上有多少个特殊点。所以直接用树剖 + 树状数组代替就可以了。
而换根把点按照 dfs 序排序的方法有很多。最简单的是其实观察到不一定需要是 dfs 序,只需要保证点 (x) 所有祖先都在 (x) 的前面,所以直接按照 (cnt) 排就行。除此之外还可以离线下来,换根用线段树维护 dfs 序,而我的代码中是求出 LCA 之后再计算 dfs 序。比较麻烦。
时间复杂度 (O(sum k(m+log^2 n)))。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=100010,M=310,LG=18,MOD=1e9+7;
int n,Q,tot,a[N],ans[N],head[N],id[N],top[N],siz[N],fa[N],dep[N],son[N],dfn[N],f[N][M],pa[N][LG+1];
struct edge
{
int next,to;
}e[N*2];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
void dfs1(int x,int fat)
{
dep[x]=dep[fat]+1; fa[x]=fat; siz[x]=1;
pa[x][0]=fat;
for (int i=1;i<=LG;i++)
pa[x][i]=pa[pa[x][i-1]][i-1];
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fat)
{
dfs1(v,x); siz[x]+=siz[v];
if (siz[v]>siz[son[x]]) son[x]=v;
}
}
}
void dfs2(int x,int tp)
{
id[x]=++tot; top[x]=tp;
if (son[x]) dfs2(son[x],tp);
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa[x] && v!=son[x]) dfs2(v,v);
}
}
struct BIT
{
int c[N];
void add(int x,int v)
{
for (int i=x;i<=n;i+=i&-i)
c[i]+=v;
}
int query(int x)
{
int ans=0;
for (int i=x;i;i-=i&-i)
ans+=c[i];
return ans;
}
}bit;
int lca(int x,int y)
{
while (top[x]!=top[y])
{
if (dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if (dep[x]<dep[y]) swap(x,y);
return y;
}
int jump(int x,int y)
{
for (int i=LG;i>=0;i--)
if (dep[pa[x][i]]>dep[y]) x=pa[x][i];
return x;
}
int query(int x,int y)
{
int res=0;
while (top[x]!=top[y])
{
if (dep[top[x]]<dep[top[y]]) swap(x,y);
res+=bit.query(id[x])-bit.query(id[top[x]]-1);
x=fa[top[x]];
}
if (dep[x]<dep[y]) swap(x,y);
return res+bit.query(id[x])-bit.query(id[y]-1);
}
bool cmp(int x,int y)
{
return dfn[x]<dfn[y];
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&Q);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
tot=0;
dfs1(1,0); dfs2(1,1);
while (Q--)
{
int k,m,rt;
scanf("%d%d%d",&k,&m,&rt);
for (int i=1;i<=k;i++)
{
scanf("%d",&a[i]);
int p=lca(a[i],rt),q=jump(rt,p);
dfn[a[i]]=siz[q]+id[a[i]]-id[p]+1;
bit.add(id[a[i]],1);
}
sort(a+1,a+1+k,cmp);
f[0][0]=1;
for (int i=1;i<=k;i++)
{
int cnt=query(a[i],rt)-1;
for (int j=1;j<=m;j++)
{
f[i][j]=f[i-1][j-1];
if (j>cnt) f[i][j]=(f[i][j]+1LL*f[i-1][j]*(j-cnt))%MOD;
}
}
int ans=0;
for (int i=1;i<=m;i++)
ans=(ans+f[k][i])%MOD;
cout<<ans<<"
";
for (int i=1;i<=k;i++)
bit.add(id[a[i]],-1);
}
return 0;
}