【AtCoder - agc040_c】Neither AB nor BA 想法+组合数学
题意
求长度为n(n为偶数)的满足以下条件的字符串数量。
-
字符串中只含有ABC三种字母,且字符串的长度为偶数。
-
字符串可以按照以下规则删除成空串
- 相邻两个字母只要不是AB或者BA都可以删除
题解
因为容易发现奇数位上的A与偶数位上的B不能消去,奇数位上的B与偶数位上的A也是同理
因此做一个巧妙的转化,把所有奇数位上的A换成B,所有奇数位上的B换成A
例如,ABCBBABBAC转化BBCBAAABBC
这样转换之后就会发现,原来的规则是AB和BA不能消除,现在变成了AA不能消除,BB不能消除
这样就容易发现,想要消除一个A必须跟一个B或者一个C一起消除,B也是同理。
那么在这样的情况下:
-
如果A的数量超过字符串长度的一半,那么势必无法全部消除
-
如果B的数量超过字符串长度的一半,那么势必也无法全部消除
那么要求能全部删除的数量,只需要用总数 - 不能完全删除的数量即可得答案
不能完全删除的数量可以这样计算,根据上面的条件可知,只要让A或者B的数量大于字符串长度的一半即可
一个小Tip:可能有同学会觉得,A的数量大于一半就不能消除不是经过转化之后得出的结论吗?那我们要算不能完全删除的数量不是应该转化回去计算吗?(对!没错这个“有同学”就是我)
其实细品一下会发现,就按照转化之后的结论计算完全没问题,因为要构造不合法的情况,你就按照我们的结论,放上大于n/2个A,然后再把奇数位上的A和B按照之前转化过来的规则转化回去,就是本题真正要求的不合法的情况。正因为我们转化前后两种情况是等价的,所以构造出来的不合法的情况,经过转化之后也是等价的。
又因为转化回去构造的话考虑起来蛮复杂的,什么奇A偶B,奇B偶A,所以不如就在转化之后这里计算不合法的数量方便,所以干脆就直接按照转化之后的结论构造即可得出答案(因为我们清楚这两者是等价的啦)
那么只需要处理出A不合法的情况数,再 * 2就是A和B总共的不合法情况数量。
也就是枚举A的数量t从n/2+1到n,给A从n个位置中挑出t个位置
剩下的位置,每个位置都有两种选择B或者C,那么A不合法的情况数就是(C(n,t)*2^{n-t})
长度为n的字符串,每个位置有ABC三种选择,因此字符串总数为(3^n)
长度为n的合法的字符串数量 = 字符串总数 - 不合法字符串数量 = (3^n-2*C(n,t)*2^{n-t})
/****************************
* Author : W.A.R *
* Date : 2020-10-08-20:19 *
****************************/
/*
https://vjudge.net/problem/AtCoder-agc040_c
*/
#include<stdio.h>
#include<string.h>
#include<math.h>
#include<iostream>
#include<algorithm>
#include<queue>
#include<map>
#include<stack>
#include<string>
#include<set>
#define IOS ios::sync_with_stdio(false)
#define show(x) std:: cerr << #x << " = " << x << std::endl;
#define mem(a,x) memset(a,x,sizeof(a))
#define Rint register int
using namespace std;
typedef long long ll;
const int maxn=1e7+10;
const int maxm=2e6+10;
const ll mod=998244353;
ll fac[maxn],invFac[maxn],mi[maxn],n,sum;
ll qpow(ll a,ll n){
a%=mod;
ll ans=1;
while(n){
if(n%2)ans=ans*a%mod;
n/=2;
a=a*a%mod;
}
return ans;
}
ll C(ll n,ll m){
if(m>n||m<0)return 0;
return fac[n]*invFac[n-m]%mod*invFac[m]%mod;
}
void Init(){
fac[0]=1;
mi[0]=1;
for(int i=1;i<=n;i++)fac[i]=fac[i-1]*i%mod;
invFac[n]=qpow(fac[n],mod-2);
for(int i=n-1;i>=0;i--)invFac[i]=invFac[i+1]*(i+1)%mod;
for(int i=1;i<=n;i++)mi[i]=mi[i-1]*2%mod;
}
int main(){
scanf("%lld",&n);Init();
for(int i=n/2+1;i<=n;i++)sum=(sum+((C(n,i)*mi[n-i])%mod))%mod;
sum=(qpow(3,n)-sum*2%mod+mod)%mod;
printf("%lld
",sum);
return 0;
}
感受
就好巧妙的转化,是我不配了呜呜呜呜呜