题目
点这里看题目
分析
关于组合数,这里有一个基本等式:
[inom{n}{k} imes k=inom{n-1}{k-1} imes n
]
尝试推广一下:
[egin{aligned}
k^2 imesinom nk&=nk imes inom{n-1}{k-1}\
&=n imes inom{n-1}{k-1}+n(n-1) imes inom{n-2}{k-2}\
k^3 imesinom nk&=kleft(n imes inom{n-1}{k-1}+n(n-1) imes inom{n-2}{k-2}
ight)\
&=n imes inom{n-1}{k-1}+3n(n-1) imes inom{n-2}{k-2}+n(n-1)(n-2) imes inom{n-3}{k-3}\
& vdots\
k^m imesinom nk&=sum_{i=1}^m coe(m,i) imes n^{underline i} imes inom{n-i}{k-i}
end{aligned}
]
公式里面的(coe(m,i)),表示当(k)的次数为(m)、(n)的下降幂为(i)次的时候的系数。
发现存在递推式:
[coe(m,i)=coe(m-1,i-1)+i imes coe(m-1,i)
]
其中(coe(m-1,i-1))是从((k-i+1) imes n^{underline{i-1}} imes inom{n-i+1}{k-i+1})贡献的;(i imes coe(m-1,i))是(k imes n^{underline i} imes inom{n-i}{k-i})结合后留下来的。
然后发现(coe(m,i))实际上就是第二类斯特林数。改写式子得到:
[k^m imesinom nk=sum_{i=0}^m {m race i} imes n^{underline i} imes inom{n-i}{k-i}
]
为了保证美观以及式子的整齐,利用({nrace 0}=0(n>0))这个性质,我们把下标放到了(0)。
考察原题目:
[egin{aligned}
&sum_{k=0}^nf(k) imes x^k imes inom n k\
=&sum_{i=0}^m a_isum_{k=0}^nx^k imes k^i imes inom n k\
=&sum_{i=0}^m a_isum_{k=0}^nx^k imes sum_{j=1}^i {irace j} imes n^{underline j} imes inom{n-j}{k-j}\
=&sum_{i=0}^m a_isum_{j=0}^i{irace j} imes n^{underline j} imesleft(sum_{k=j}^nx^k imes inom{n-j}{k-j}
ight)\
=&sum_{i=0}^m a_isum_{j=0}^i{irace j} imes n^{underline j} imes x^j imes left(sum_{k=0}^{n-j}x^k imes inom{n-j}{k}
ight)\
=&sum_{i=0}^m a_isum_{j=0}^i{irace j} imes n^{underline j} imes x^j imes (1+x)^{n-j}
end{aligned}
]
现在就可以(O(m^2log_2 n))地计算了。如果追求更优的复杂度,把((1+x)^{n-j})预处理一下就好了。
代码
#include <cstdio>
const int MAXM = 1005;
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
int S[MAXM][MAXM];
int A[MAXM], xpw[MAXM];
int N, X, mod, M;
int qkpow( int base, int indx )
{
int ret = 1;
while( indx )
{
if( indx & 1 ) ret = 1ll * ret * base % mod;
base = 1ll * base * base % mod, indx >>= 1;
}
return ret;
}
void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }
int main()
{
read( N ), read( X ), read( mod ), read( M );
for( int i = 0 ; i <= M ; i ++ ) read( A[i] );
S[0][0] = 1;
for( int i = 1 ; i <= M ; i ++ )
for( int j = 1 ; j <= M ; j ++ )
S[i][j] = ( S[i - 1][j - 1] + 1ll * S[i - 1][j] * j % mod ) % mod;
for( int i = 0 ; i <= M ; i ++ ) xpw[i] = qkpow( ( X + 1 ) % mod, N - i );
int ans = 0, tmp, down, pw;
for( int i = 0 ; i <= M ; i ++ )
{
tmp = 0, pw = down = 1;
for( int j = 0 ; j <= i ; j ++ )
{
add( tmp, 1ll * S[i][j] * down % mod * pw % mod * xpw[j] % mod );
down = 1ll * down * ( N - j ) % mod, pw = 1ll * pw * X % mod;
}
add( ans, 1ll * A[i] * tmp % mod );
}
write( ans ), putchar( '
' );
return 0;
}