嘟嘟嘟
这题当时没想出来(因为本人实在不太擅长计数),然后又被luogu的第一篇题解吓怕了,就咕了一小段时间再写。
其实这题不是很难。
做法就是基础容斥+NTT。
首先出现(S)次的颜色最多有(N = min { frac{n}{S}, m })种。
我们令(dp[i])表示出现(S)次的颜色至少有(i)种的方案数,那么共有(C_{m} ^ {i})种颜色组合,这些颜色的位置共有(C_{n} ^ {iS})种选取方案,剩下的位置每一位都有(m)中颜色可选,然后再考虑这(C_{n} ^{iS})个位置中每一种颜色的分配方案,就有
[egin{align*}
dp[i]
&= C_{m} ^ {i} * C _{n} ^ {iS} * (C_{iS} ^ {S} * C_{iS - S} ^ {S} * C_{iS - 2S} ^ {S} * ldots * C_{S} ^ {S}) * m ^ {n - iS} \
&= C_{m} ^ {i} * C _{n} ^ {iS} * frac{(iS)!}{(S!) ^ i} * m ^ {n - iS} \
end{align*}
]
然后我们令(ans[i])表示出现(S)次的颜色恰好有(i)种的方案数,根据容斥,就有这么个式子:
[egin{align*}
ans[i] &= sum _ {j = i} ^ {N} (-1) ^ {j - i} C_{j} ^ {i} dp[j] \
ans[i] *i! &= sum _ {j = i} ^ {N} frac{(-1) ^ {j - i}}{(j - i)!} * dp[j] * j!
end{align*}
]
这个东西NTT可做。把dp数组反过来。就像[ZJOI2014]力这道题一样。
需要注意的是这样(ans[i])也是反过来的,所以乘上的是(inv[N - i])和(w[N - i])。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 2e7 + 5;
const ll mod = 1004535809;
const ll G = 3;
inline ll read()
{
ll ans = 0;
char ch = getchar(), last = ' ';
while(!isdigit(ch)) last = ch, ch = getchar();
while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
if(last == '-') ans = -ans;
return ans;
}
inline void write(ll x)
{
if(x < 0) x = -x, putchar('-');
if(x >= 10) write(x / 10);
putchar(x % 10 + '0');
}
int n, m, S, N, w[maxn];
In ll inc(ll a, ll b) {return a + b >= mod ? a + b - mod : a + b;}
In ll quickpow(ll a, ll b)
{
ll ret = 1;
for(; b; b >>= 1, a = a * a % mod)
if(b & 1) ret = ret * a % mod;
return ret;
}
ll fac[maxn], inv[maxn];
In void init()
{
int Max = max(n, m);
fac[0] = inv[0] = 1;
for(int i = 1; i <= Max; ++i) fac[i] = fac[i - 1] * i % mod;
inv[Max] = quickpow(fac[Max], mod - 2);
for(int i = Max - 1; i; --i) inv[i] = inv[i + 1] * (i + 1) % mod;
}
ll dp[maxn], b[maxn];
int len = 1, lim = 0, rev[maxn];
In void ntt(ll* a, int len, int flg)
{
for(int i = 0; i < len; ++i) if(i < rev[i]) swap(a[i], a[rev[i]]);
for(int i = 1; i < len; i <<= 1)
{
ll gn = quickpow(G, (mod - 1) / (i << 1));
for(int j = 0; j < len; j += (i << 1))
{
ll g = 1;
for(int k = 0; k < i; ++k, g = g * gn % mod)
{
ll tp1 = a[k + j] % mod, tp2 = g * a[k + j + i] % mod;
a[k + j] = inc(tp1, tp2), a[k + j + i] = inc(tp1, mod - tp2);
//a[k + j] = (tp1 + tp2) % mod, a[k + j + i] = (tp1 - tp2 + mod) % mod;
}
}
}
if(flg == 1) return;
ll inv = quickpow(len, mod - 2); reverse(a + 1, a + len);
for(int i = 0; i < len; ++i) a[i] = a[i] * inv % mod;
}
int Ans[maxn];
In void bf()
{
for(int i = 0; i <= N; ++i)
for(int j = 0; j <= i; ++j)
Ans[i] = inc(Ans[i], dp[j] * b[i - j] % mod);
ll ans = 0;
for(int i = 0; i <= N; ++i) ans = inc(ans, Ans[i] * inv[N - i] % mod * w[N - i] % mod);
printf("--->%lld
", ans);
}
int main()
{
n = read(), m = read(), S = read();
N = min(m, n / S);
for(int i = 0; i <= m; ++i) w[i] = read();
init();
for(int i = 0; i <= N; ++i)
dp[i] = fac[n] * fac[m] % mod * inv[i] % mod * inv[m - i] % mod * quickpow(inv[S], i) % mod * inv[n - i * S] % mod * quickpow(m - i, n - i * S) % mod * fac[i] % mod;
for(int i = 0; i <= N; ++i) b[i] = (i & 1) ? mod - inv[i] : inv[i];
reverse(dp, dp + N + 1);
//bf();
while(len <= (N << 1)) len <<= 1, ++lim;
for(int i = 0; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lim - 1));
ntt(dp, len, 1), ntt(b, len, 1);
for(int i = 0; i < len; ++i) dp[i] = dp[i] * b[i] % mod;
ntt(dp, len, -1);
ll ans = 0;
for(int i = 0; i <= N; ++i) ans = inc(ans, dp[i] * inv[N - i] % mod * w[N - i] % mod);
write(ans), enter;
return 0;
}