题目
点这里看题目。
题目大意:
有(n)个水果,第(i)个水果有甜度值(v_i)。不甜的水果的甜度值是(-1)。现在将它们连成一棵树。水果(x)在树上是“真甜”的,当且仅当:
即存在另一个甜的水果与它有边连接。
求真甜的水果的甜度值之和不超过(maxV)的方案数。
数据范围:(1le nle 40, -1le v_ile 2.5 imes 10^7, 1le maxVle 10^9)
分析
不难想到把这个问题拆分成两个部分:
1.从甜的水果中选出(k)个作为真甜的,求出甜度和不超过(maxV)的方案数(cnt_k),无序。
2.求出树上有(k)个真甜的水果的方案数(tree_k),有序。
答案显然是(sum_{i=0}^ncnt_i imes tree_i)。
Part1 -> meet-in-middle
这是一个经典的可以使用 meet-in-middle 解决的问题,时间复杂度似乎是(O(n2^{frac n 2}))。
大概就是,把甜的水果分成两个集合(A)和(B),存下来(A)的子集和(B)的子集,然后按照大小分类,并且各自排序。
然后枚举(A)的子集大小(x)和(B)的子集大小(y),用指针扫一遍计算就好了。
Part2-> matrix-tree & 容斥
发现(tree_k)本身很难计算,我们考虑枚举子集进行容斥。
设(T_k)为至多有(k)个真甜的水果的方案数,那么有:
因此我们只需要计算出(T_k)就可以容斥得到(tree_k)。
设(S)为甜的水果的总数。我们下面称“甜但不是真甜”的水果为“半甜的”。
考虑如下建图:
假设(1sim k)的水果是真甜的,(k+1sim S)的水果是半甜的,(S+1sim n)的水果是不甜的。
那么,我们允许真甜的水果和真甜或不甜的水果连边;半甜的水果只能和不甜的水果连边;不甜的水果可以和真甜或不甜的水果连边。
可以发现,这样一构建,(k+1sim S)的水果一定是半甜的,(1sim k)的水果可能是真甜的,也就是至多(k)个真甜的水果了。
该部分时间复杂度(O(n^4))。
代码
#include <cstdio>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
const int mod = 1e9 + 7;
const int MAXN = 55;
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
vector<int> SA[MAXN], SB[MAXN];
int C[MAXN][MAXN];
int G[MAXN][MAXN], D[MAXN][MAXN], K[MAXN][MAXN];
int cnt[MAXN], ptr[MAXN], T[MAXN];
int v[MAXN];
int N, MXV, vcnt;
int qkpow( int base, int indx )
{
int ret = 1;
while( indx )
{
if( indx & 1 ) ret = 1ll * ret * base % mod;
base = 1ll * base * base % mod, indx >>= 1;
}
return ret;
}
int inv( const int a ) { return qkpow( a, mod - 2 ); }
int det( int T[][MAXN], const int n )
{
int ans = 1, tmp, inver, indx;
for( int i = 1 ; i <= n ; i ++ )
{
indx = -1;
for( int j = i ; j <= n ; j ++ )
if( T[j][i] )
{ indx = j; break; }
if( indx == -1 ) return 0;
if( indx ^ i ) ans = mod - ans;
std :: swap( T[i], T[indx] );
inver = inv( T[i][i] );
for( int j = i + 1 ; j <= n ; j ++ )
if( T[j][i] )
{
tmp = 1ll * T[j][i] * inver % mod;
for( int k = i ; k <= n ; k ++ )
T[j][k] = ( T[j][k] - 1ll * T[i][k] * tmp % mod + mod ) % mod;
}
ans = 1ll * ans * T[i][i] % mod;
}
return ans;
}
int MatrixTree()
{
for( int i = 1 ; i <= N ; i ++ )
for( int j = 1 ; j <= N ; j ++ )
K[i][j] = 0, D[i][j] = 0;
for( int i = 1 ; i <= N ; i ++ )
for( int j = 1 ; j <= N ; j ++ )
D[i][i] += G[i][j];
for( int i = 1 ; i <= N ; i ++ )
for( int j = 1 ; j <= N ; j ++ )
K[i][j] = ( D[i][j] - G[i][j] + mod ) % mod;
return det( K, N - 1 );
}
void initCnt()
{
int half = vcnt + 1 >> 1;
int siz1 = half, siz2 = vcnt - half;
int bit, val;
for( int S = 0 ; S < 1 << siz1 ; S ++ )
{
bit = val = 0;
for( int i = 1 ; i <= siz1 ; i ++ )
if( ( S >> i - 1 ) & 1 )
bit ++, val += v[i];
if( val <= MXV ) SA[bit].push_back( val );
}
for( int S = 0 ; S < 1 << siz2 ; S ++ )
{
bit = val = 0;
for( int i = 1 ; i <= siz2 ; i ++ )
if( ( S >> i - 1 ) & 1 )
bit ++, val += v[siz1 + i];
if( val <= MXV ) SB[bit].push_back( val );
}
for( int i = 0 ; i <= siz1 ; i ++ ) sort( SA[i].begin(), SA[i].end() );
for( int i = 0 ; i <= siz2 ; i ++ ) sort( SB[i].begin(), SB[i].end() );
for( int i = 0 ; i <= siz1 ; i ++ )
{
for( int j = 0 ; j <= N - i && j <= siz2 ; j ++ )
ptr[j] = SB[j].size() - 1;
for( int t = 0 ; t < SA[i].size() ; t ++ )
for( int j = 0 ; j <= N - i && j <= siz2 ; j ++ )
{
while( ~ ptr[j] && SB[j][ptr[j]] + SA[i][t] > MXV ) ptr[j] --;
cnt[i + j] = ( cnt[i + j] + ptr[j] + 1 ) % mod;
}
}
}
void initT()
{
for( int k = 0 ; k <= vcnt ; k ++ )
{
for( int i = 0 ; i <= N ; i ++ )
for( int j = 0 ; j <= N ; j ++ )
G[i][j] = 0;
for( int i = 1 ; i <= k ; i ++ )
{
for( int j = i + 1 ; j <= k ; j ++ )
G[i][j] ++, G[j][i] ++;
for( int j = vcnt + 1 ; j <= N ; j ++ )
G[i][j] ++, G[j][i] ++;
}
for( int i = k + 1 ; i <= vcnt ; i ++ )
for( int j = vcnt + 1 ; j <= N ; j ++ )
G[i][j] ++, G[j][i] ++;
for( int i = vcnt + 1 ; i <= N ; i ++ )
for( int j = i + 1 ; j <= N ; j ++ )
G[i][j] ++, G[j][i] ++;
T[k] = MatrixTree();
}
for( int i = 0 ; i <= N ; i ++ )
{
C[i][0] = C[i][i] = 1;
for( int j = 1 ; j < i ; j ++ )
C[i][j] = ( C[i - 1][j] + C[i - 1][j - 1] ) % mod;
}
for( int i = 0 ; i <= N ; i ++ )
for( int j = 0 ; j < i ; j ++ )
T[i] = ( T[i] - 1ll * T[j] * C[i][j] % mod + mod ) % mod;
}
class SweetFruits
{
public :
int countTrees( vector<int> sweetness, int mxS )
{
N = sweetness.size(), MXV = mxS;
for( int i = 0 ; i < N ; i ++ ) v[i + 1] = sweetness[i];
sort( v + 1, v + 1 + N );
reverse( v + 1, v + 1 + N );
for( vcnt = 1 ; vcnt <= N && ~ v[vcnt] ; vcnt ++ );
vcnt --; for( int i = vcnt + 1 ; i <= N ; i ++ ) v[i] = 0;
initCnt();
initT();
int ans = 0;
for( int i = 0 ; i <= N ; i ++ ) ans = ( ans + 1ll * T[i] * cnt[i] % mod ) % mod;
return ans;
}
};