题目链接:G. Mercenaries
题目大意:
- 有(n)名佣兵,你可以雇佣第(i)名当且仅当你所雇佣的佣兵总人数在([l_i,r_i])中。
- 有(m)组限制,每组限制给出(a_i,b_i)表示(a_i,b_i)两名佣兵不可同时雇佣。
- 求雇佣至少(1)名佣兵的方案数。
题解:官方题解太神了没看懂(其实就是菜qwq)。
下面讲一下我口胡的思路。
如果(m=0),那么我们可以枚举最后选多少个佣兵,然后一个简单的组合数解决问题。
那么我们发现(m)非常小,这提示我们可以把所有和限制搭上边的佣兵单独拎出来算,现在我们需要知道的就是有限制的佣兵中钦定有些佣兵不可以选,另一些选出若干个的方案数,这很明显应当用状压 DP 解决,然而一个尴尬的事情发生了,(mleq 20)意味着最多会有(40)个有限制的佣兵,(2^{40})妥妥地爆掉了。
似乎陷入了绝境……
换一种思维方式,考虑到如果一个佣兵有限制,但是能够限制他的佣兵在我们枚举的选取的总人数中一个都不符合要求,那么这个佣兵可以暂时地归类到普通佣兵里,这启示我们可以不压缩佣兵,而选择压缩佣兵之间的限制,这个显然是可以做的。
那么,这一题就做完了,最后形式化地描述一下算法的思路流程,先预处理出(2^m)个限制的选取集合中选取佣兵的个数为(0)~(m)次的方案数,然后枚举我们选多少个佣兵,将所有符合条件且有限制的佣兵统计出来,并且统计出它们之间的连边,如果一个佣兵没有佣兵与他相连,就将其加到普通佣兵的个数中去,时间复杂度(O(2^mcdot m^2+nm))。
代码:(写得貌似很丑)
#include <cstdio>
void read(int &a){
a=0;
char c=getchar();
while(c<'0'||c>'9'){
c=getchar();
}
while(c>='0'&&c<='9'){
a=(a<<1)+(a<<3)+(c^48);
c=getchar();
}
}
int quick_power(int a,int b,int Mod){
int ans=1;
while(b){
if(b&1){
ans=1ll*ans*a%Mod;
}
b>>=1;
a=1ll*a*a%Mod;
}
return ans;
}
const int Maxn=300000;
const int Maxm=20;
const int Mod=998244353;
int frac[Maxn+5],inv_f[Maxn+5];
void init(){
frac[0]=1;
for(int i=1;i<=Maxn;i++){
frac[i]=1ll*frac[i-1]*i%Mod;
}
inv_f[Maxn]=quick_power(frac[Maxn],Mod-2,Mod);
for(int i=Maxn-1;i>=0;i--){
inv_f[i]=1ll*inv_f[i+1]*(i+1)%Mod;
}
}
int C(int n,int m){
if(n<m){
return 0;
}
if(m<0){
return 0;
}
return 1ll*frac[n]*inv_f[m]%Mod*inv_f[n-m]%Mod;
}
int l[Maxn+5],r[Maxn+5];
bool vis[Maxn+5];
int n,m;
int a[Maxm+5],b[Maxm+5];
int f[1<<Maxm|5][Maxm+5];
int lis[Maxm<<1|5],len;
bool in[Maxn+5],out[Maxn+5];
int deg[Maxn+5];
int head[Maxn+5],arrive[Maxn+5],nxt[Maxn+5],val[Maxn+5],tot;
int sum[Maxn+5];
void add_edge(int from,int to,int w){
arrive[++tot]=to;
nxt[tot]=head[from];
val[tot]=w;
head[from]=tot;
}
bool build_graph(int mask){
len=0;
for(int i=1;i<=m;i++){
if((mask>>(i-1))&1){
if(!in[a[i]]){
in[a[i]]=1;
lis[++len]=a[i];
}
if(!in[b[i]]){
in[b[i]]=1;
lis[++len]=b[i];
}
}
}
tot=0;
for(int i=1;i<=len;i++){
head[lis[i]]=0;
deg[lis[i]]=0;
}
for(int i=1;i<=m;i++){
if(in[a[i]]&&in[b[i]]&&(mask>>(i-1)&1)==0){
return 0;
}
}
for(int i=1;i<=m;i++){
if((mask>>(i-1))&1){
deg[a[i]]++;
deg[b[i]]++;
add_edge(a[i],b[i],i);
add_edge(b[i],a[i],i);
}
}
return 1;
}
void connect_dfs(int u,int &mask){
for(int i=head[u];i;i=nxt[i]){
if(((1<<(val[i]-1))&mask)>0){
continue;
}
mask|=(1<<(val[i]-1));
int v=arrive[i];
connect_dfs(v,mask);
}
}
int calc(int mask){
for(int i=1;i<=m;i++){
if((mask>>(i-1))&1){
out[a[i]]=out[b[i]]=1;
}
}
int ans=0;
for(int i=1;i<=len;i++){
if(in[lis[i]]&&!out[lis[i]]){
ans++;
}
}
for(int i=1;i<=m;i++){
if((mask>>(i-1))&1){
out[a[i]]=out[b[i]]=0;
}
}
return ans;
}
bool check(int x,int len){
return l[x]<=len&&len<=r[x];
}
int count(int x){
int ans=0;
while(x){
ans++;
x-=(x&(-x));
}
return ans;
}
int main(){
init();
read(n),read(m);
for(int i=1;i<=n;i++){
read(l[i]),read(r[i]);
}
for(int i=1;i<=m;i++){
read(a[i]),read(b[i]);
vis[a[i]]=vis[b[i]]=1;
}
for(int mask=0;mask<(1<<m);mask++){
f[mask][0]=1;
if(mask==0){
continue;
}
if(!build_graph(mask)){
for(int j=1;j<=len;j++){
in[lis[j]]=0;
}
f[mask][0]=0;
continue;
}
int nmask=0;
connect_dfs(lis[1],nmask);
if(nmask!=mask){
for(int j=0;j<=m;j++){
if(f[nmask][j]==0){
continue;
}
for(int k=0;k+j<=m;k++){
if(j==0&&k==0){
continue;
}
f[mask][j+k]=(f[mask][j+k]+1ll*f[nmask][j]*f[mask^nmask][k])%Mod;
}
}
for(int j=1;j<=len;j++){
in[lis[j]]=0;
}
continue;
}
nmask=mask;
int u=lis[1];
int num=0;
for(int j=head[u];j;j=nxt[j]){
nmask&=(~(1<<(val[j]-1)));
}
num=calc(nmask)-1;
for(int j=1;j<=m;j++){
for(int k=0;k<=j;k++){
f[mask][j]=(f[mask][j]+1ll*f[nmask][k]*C(num,j-k))%Mod;
}
}
num=0;
for(int j=head[u];j;j=nxt[j]){
int v=arrive[j];
for(int k=head[v];k;k=nxt[k]){
nmask&=(~(1<<(val[k]-1)));
if(arrive[k]==u){
continue;
}
}
}
num=calc(nmask)-1-deg[lis[1]];
for(int j=1;j<=m;j++){
for(int k=0;k<j;k++){
f[mask][j]=(f[mask][j]+1ll*f[nmask][k]*C(num,j-1-k))%Mod;
}
}
for(int j=1;j<=len;j++){
in[lis[j]]=0;
}
}
for(int i=1;i<=n;i++){
if(!vis[i]){
sum[l[i]]++;
sum[r[i]+1]--;
}
}
for(int i=1;i<=n;i++){
sum[i]+=sum[i-1];
}
int ans=0;
for(int i=1;i<=n;i++){
int mask=0;
for(int j=1;j<=m;j++){
if(check(a[j],i)&&check(b[j],i)){
in[a[j]]=in[b[j]]=1;
mask|=(1<<(j-1));
}
}
int num=0;
for(int j=1;j<=m;j++){
if(check(a[j],i)&&!in[a[j]]){
num++;
in[a[j]]=1;
}
if(check(b[j],i)&&!in[b[j]]){
num++;
in[b[j]]=1;
}
}
for(int j=1;j<=m;j++){
in[a[j]]=in[b[j]]=0;
}
int tmp=sum[i];
tmp+=num;
for(int j=0;j<=m;j++){
ans=(ans+1ll*C(tmp,i-j)*f[mask][j])%Mod;
}
}
printf("%d
",ans);
return 0;
}