题目大意
有(n)盏灯,(m)个限制。每个限制((x,y))表示第(x)盏灯与第(y)盏灯之间必须且只能亮一盏。
记一种情况(x)亮着的灯的数量为(f_x),求(sum {(f_x)}^k)
(nleq 200000,kleq 100)
题解
我们先把整张图黑白染色。
如果不是二分图就无解。
我们发现两个不同的联通分量的灯的状态是没有关系的。
我们可以考虑DP:
(f_{i,j}=)前(i)个联通分量中亮的彩灯个数的(j)次方和
(a_{i,j}=)第(i)个联通分量中亮的彩灯个数的(j)次方和
根据二项式定理({(a+b)}^n=sum_{i=0}^ninom nia^ib^{n-i})有:
[f_{i,j}=sum_{k=0}^{j}inom{j}{k} a_{i-1,k} imes f_{i-1,j-k}
]
[=sum_{k=0}^{j}frac{j!}{k!(j-k)!}a_{i-1,k} imes f_{i-1,j-k}
]
[=j!sum_{k=0}^{j}frac{a_{i-1,k}}{k!} imes frac{f_{i-1,j-k}}{(j-k)!}
]
观察到(p=1004535809)是一个NTT模数(原根为(3)),所以可以用NTT加速。
没了
时间复杂度:(O(nklog k))
题解2
FFT很明显会TLE
有一个公式:
[m^n=sum_{k=0}^minom{m}{k}S(n,k)k!
]
后面两个东西很容易算,只用考虑第一个怎么求
这不就是个组合数吗
直接DP就行了。
时间复杂度:(O(nk))
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
ll p=1004535809;
ll fp(ll a,ll b)
{
ll s=1;
while(b)
{
if(b&1)
s=s*a%p;
a=a*a%p;
b>>=1;
}
return s;
}
namespace ntt
{
int N;
ll w1[500010];
ll w2[500010];
int rev[500010];
void get(int n)
{
N=1;
while(N<n)
N<<=1;
int i;
for(i=2;i<=N;i++)
{
w1[i]=fp(3,(p-1)/i);
w2[i]=fp(w1[i],p-2);
}
rev[0]=0;
for(i=1;i<N;i++)
rev[i]=(rev[i>>1]>>1)|(i&1?N>>1:0);
}
void ntt(ll *a,int t)
{
int i,j,k;
ll u,v,w,wn;
for(i=0;i<N;i++)
if(rev[i]<i)
swap(a[i],a[rev[i]]);
for(i=2;i<=N;i<<=1)
{
wn=t?w1[i]:w2[i];
for(j=0;j<N;j+=i)
{
w=1;
for(k=j;k<j+i/2;k++)
{
u=a[k];
v=a[k+i/2]*w%p;
a[k]=(u+v)%p;
a[k+i/2]=(u-v+p)%p;
w=w*wn%p;
}
}
}
if(!t)
{
ll inv=fp(N,p-2);
for(i=0;i<N;i++)
a[i]=a[i]*inv%p;
}
}
}
struct list
{
int v[500010];
int t[500010];
int h[200010];
int n;
list()
{
n=0;
memset(h,0,sizeof h);
}
void add(int x,int y)
{
n++;
v[n]=y;
t[n]=h[x];
h[x]=n;
}
};
list l;
int vis[200010];
void failed()
{
putchar('0');
exit(0);
}
int s1,s2;
void dfs(int x,int fa,int v)
{
vis[x]=v;
if(v)
s1++;
else
s2++;
int i;
for(i=l.h[x];i;i=l.t[i])
if(l.v[i]!=fa)
{
if(vis[l.v[i]]==-1)
dfs(l.v[i],x,v^1);
else if(vis[l.v[i]]==vis[x])
failed();
}
}
ll f[500010];
ll a[500010];
int n,m,k;
ll fac[500010];
int main()
{
freopen("bulb.in","r",stdin);
freopen("bulb.out","w",stdout);
scanf("%d%d%d",&n,&m,&k);
ntt::get(2*(k+1));
int i,j,x,y;
fac[0]=1;
for(i=1;i<=1000;i++)
fac[i]=fac[i-1]*i%p;
for(i=1;i<=m;i++)
{
scanf("%d%d",&x,&y);
l.add(x,y);
l.add(y,x);
}
memset(vis,-1,sizeof vis);
f[0]=1;
for(i=1;i<=n;i++)
if(vis[i]==-1)
{
s1=s2=0;
dfs(i,0,1);
for(j=0;j<=k;j++)
a[j]=(fp(s1,j)+fp(s2,j))%p;
for(j=k+1;j<ntt::N;j++)
a[j]=f[j]=0;
for(j=0;j<ntt::N;j++)
{
ll inv=fp(fac[j],p-2);
a[j]=a[j]*inv%p;
f[j]=f[j]*inv%p;
}
ntt::ntt(a,1);
ntt::ntt(f,1);
for(j=0;j<ntt::N;j++)
f[j]=f[j]*a[j]%p;
ntt::ntt(f,0);
for(j=0;j<=k;j++)
f[j]=f[j]*fac[j]%p;
// do something here ...
}
printf("%lld
",f[k]);
return 0;
}
代码2
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
ll p=1004535809;
ll fp(ll a,ll b)
{
ll s=1;
while(b)
{
if(b&1)
s=s*a%p;
a=a*a%p;
b>>=1;
}
return s;
}
struct list
{
int v[500010];
int t[500010];
int h[200010];
int n;
list()
{
n=0;
memset(h,0,sizeof h);
}
void add(int x,int y)
{
n++;
v[n]=y;
t[n]=h[x];
h[x]=n;
}
};
list l;
int vis[200010];
void failed()
{
putchar('0');
exit(0);
}
int s1,s2;
void dfs(int x,int fa,int v)
{
vis[x]=v;
if(v)
s1++;
else
s2++;
int i;
for(i=l.h[x];i;i=l.t[i])
if(l.v[i]!=fa)
{
if(vis[l.v[i]]==-1)
dfs(l.v[i],x,v^1);
else if(vis[l.v[i]]==vis[x])
failed();
}
}
int n,m,k;
ll fac[500010];
ll f[510];
ll f2[510];
ll f1[510];
void dp1()
{
int i;
for(i=k;i>=1;i--)
f1[i]=(f1[i]+f1[i-1])%p;
}
void dp2()
{
int i;
for(i=k;i>=1;i--)
f2[i]=(f2[i]+f2[i-1])%p;
}
ll s[510][510];
int main()
{
// freopen("bulb.in","r",stdin);
// freopen("bulb4.out","w",stdout);
scanf("%d%d%d",&n,&m,&k);
int i,j,x,y;
fac[0]=1;
for(i=1;i<=1000;i++)
fac[i]=fac[i-1]*i%p;
s[0][0]=1;
for(i=1;i<=100;i++)
{
s[i][0]=0;
for(j=1;j<=100;j++)
s[i][j]=(s[i-1][j-1]+s[i-1][j]*j%p)%p;
}
for(i=1;i<=m;i++)
{
scanf("%d%d",&x,&y);
l.add(x,y);
l.add(y,x);
}
memset(vis,-1,sizeof vis);
f[0]=1;
for(i=1;i<=n;i++)
if(vis[i]==-1)
{
s1=s2=0;
dfs(i,0,1);
for(j=0;j<=k;j++)
f1[j]=f2[j]=f[j];
for(j=1;j<=s1;j++)
dp1();
for(j=1;j<=s2;j++)
dp2();
for(j=0;j<=k;j++)
f[j]=(f1[j]+f2[j])%p;
}
ll ans=0;
for(i=0;i<=k;i++)
ans=(ans+f[i]*s[k][i]%p*fac[i]%p)%p;
printf("%lld
",ans);
return 0;
}