题目
点这里看题目。
分析
首先,由于仅仅是 " 存在 " 这一条限制很容易导致计重,且总方案数就是 (10^n) 。我们就可以考虑求补集,也就是不存在的情况。
然后有一个很显然的 DP :
(f(i,S)):序列长度为 (i) ,此时序列的和 (le X + Y + Z) 的极长后缀的状态为 (S) 的方案数。
举一个关于极长后缀的例子:
X=2,Y=5,Z=4
... 1 3 ( 2 5 3 )
括号内的就是状态需要的极长后缀。
那么怎么存下来这个后缀呢?
暴力
写个搜索找一下方案数:
void DFS( const int su )
{
if( su > X + Y + Z ) return ;
++ cnt;
for( int i = 1 ; i <= 10 ; i ++ )
DFS( su + i );
}
得到结果:
input: X = 5, Y = 7, Z = 5
output: cnt = 130624
因而,实际上, (S) 的状态数不会太多。我们可以全部搜索出来之后,建立 (S) 和序号的映射关系。对于每个 (S) ,我们可以直接得到它加了一个字符之后的转移,再进行 DP 。
顺带一提, (S) 也可以直接被压缩成 long long
范围内的数,方便 map
使用。
时间复杂度最大是 (O(130642 imes 10 imes n)) ,有点卡,不过还好。
优雅解法
我们假想出一条长度为 (X+Y+Z) 的空格子列。那么一个数 (x) 就相当于涂掉了连续的 (x) 个格子。我们就只需要判断:第 (X) 个格子是不是某次填涂的末尾?第 (X + Y) 个格子是否是某次填涂的末尾?第 (X+Y+Z) 个格子是否是某次填涂的末尾?
进而可以想到一个状态压缩的方法:对于一个数 (x) ,我们就把 ([1,x)) 的格子留空,把第 (x) 个格子涂上颜色,标记这里是一次填涂的尾巴。这样我们就可以利用位运算快速查出当前状态是否合法。
煮几个栗子:
input: X = 3, Y = 2, Z = 4 => 001010001
原序列 转化为格子 转化后的序列
1 1 1 3 2 1 => 1 1 1 001 01 1 => 111001011
^ x ^ mismatched
1 2 2 3 1 => 1 01 01 001 1 => 101010011
^ ^ ^ matched
这样应该就可以理解了。 DP 的时间复杂度是 (O(10 imes n2^{X+Y+Z})) 。
这个方法理论复杂度要大一些,不过暴力被卡常了
本题有价值的地方在于:
- 补集转化
- 暴力搜索的应用,状态数小的时候甚至可以直接暴力搜出转移的情况(准确来说,直接搜出了 DP 的 DAG)。
- 前缀和与涂格子的转化,使得本题可以直接状态压缩。
代码
暴力
#include <map>
#include <cstdio>
#include <vector>
using namespace std;
typedef long long LL;
const int mod = 1e9 + 7;
const int MAXN = 45, MAXS = 2e6 + 5;
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
map<LL, int> state;
int f[MAXN][MAXS];
int trans[MAXS][11];
int seq[MAXN], siz;
int cnt = 0;
int N, X, Y, Z;
bool legal[MAXS];
void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }
void DFS( const int su, const LL H )
{
if( su > X + Y + Z ) return ;
state[H] = ++ cnt;
for( int i = 1 ; i <= 10 ; i ++ )
DFS( su + i, H * 10 + i );
}
bool chk()
{
int s[MAXN] = {};
for( int i = 1 ; i <= siz ; i ++ ) s[i] = s[i - 1] + seq[i];
bool flag = false;
for( int i = 1 ; i <= siz ; i ++ )
if( s[i] == X )
{ flag = true; break; }
if( ! flag ) return false;
flag = false;
for( int i = 1 ; i <= siz ; i ++ )
if( s[i] == X + Y )
{ flag = true; break; }
if( ! flag ) return false;
return s[siz] == X + Y + Z;
}
int push( const int c )
{
int su = 0, l;
for( l = siz ; l ; su += seq[l --] )
if( c + su + seq[l] > X + Y + Z )
break;
LL val = 0;
for( int j = l + 1 ; j <= siz ; j ++ )
val = val * 10 + seq[j];
if( c <= X + Y + Z ) val = val * 10 + c;
return state[val];
}
void init( const int su )
{
if( su > X + Y + Z ) return ;
legal[++ cnt] = chk();
for( int i = 1 ; i <= 10 ; i ++ ) trans[cnt][i] = push( i );
for( int c = 1 ; c <= 10 ; c ++ )
seq[++ siz] = c, init( su + c ), siz --;
}
int main()
{
read( N ), read( X ), read( Y ), read( Z );
DFS( 0, 0 );
cnt = 0; init( 0 );
f[0][1] = 1;
for( int i = 0 ; i < N ; i ++ )
for( int k = 1 ; k <= cnt ; k ++ )
if( f[i][k] && ! legal[k] )
for( int c = 1 ; c <= 10 ; c ++ )
if( trans[k][c] && ! legal[trans[k][c]] )
add( f[i + 1][trans[k][c]], f[i][k] );
int ans = 1;
for( int i = 1 ; i <= N ; i ++ )
ans = 10ll * ans % mod;
for( int i = 1 ; i <= cnt ; i ++ )
if( ! legal[i] )
add( ans, mod - f[N][i] );
write( ans ), putchar( '
' );
return 0;
}
优雅的解法
#include <cstdio>
const int mod = 1e9 + 7;
const int MAXN = 45, MAXS = ( 1 << 17 ) + 5;
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
int f[MAXN][MAXS];
int N, X, Y, Z, st, all;
int bit( const int i ) { return 1 << i - 1; }
bool chk( const int S ) { return ( S & st ) == st; }
void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }
int insert( const int S, const int v ) { return ( S << v | bit( v ) ) & all; }
int main()
{
read( N ), read( X ), read( Y ), read( Z );
int up = 1 << X + Y + Z;
all = ( bit( X + Y + Z ) << 1 ) - 1;
st = bit( Z ) | bit( Y + Z ) | bit( X + Y + Z );
f[0][0] = 1;
for( int i = 0 ; i < N ; i ++ )
for( int S = 0 ; S < up ; S ++ )
if( f[i][S] && ! chk( S ) )
for( int k = 1, T ; k <= 10 ; k ++ )
if( ! chk( T = insert( S, k ) ) )
add( f[i + 1][T], f[i][S] );
int ans = 1;
for( int i = 1 ; i <= N ; i ++ )
ans = 10ll * ans % mod;
for( int S = 0 ; S < up ; S ++ )
if( ! chk( S ) )
add( ans, mod - f[N][S] );
write( ans ), putchar( '
' );
return 0;
}