题意
不定方程解的计数问题
\[\sum_{x_1+x_2+\cdots+x_m = N,\ x_i \in \mathbb{N}} \prod_{i=1}^{m} x_i^{K_i}
\]
- \(\sum K_i \le 10^5,\ m \le 10^5\)
- \(N \le 10^7\)
(记 \(0^0 = 1\))
思路
首先这个式子的组合意义就是把 \(N\) 个球分成 \(m\) 组,然后在每组中有序地选出 \(K_i\) 个元素(可以重复)。
如果 \(K_i = 1\),有个明显的组合意义就是选出一个代表元,然后两边各有一些元素,一个未知数变成两个未知数,总和减少了 \(1\)。
当然 \(K_i > 1\) 也同理,可以看成选了 \(K_i\) 个代表元,但是要考虑代表元重复以及各个代表元间的顺序的问题。
记 \(\sum K_i = S\),枚举代表元的位置数 \(i \in [m, S]\),显然不同组(不在一个未知数内)的代表元相互独立,方案数相乘,可以用若干个 OGF 表示。
把 \(m\) 个代表元合并成 \(k\) 个(考虑顺序)的方案数用第二类 Stirling 数表示:
\[\begin{Bmatrix} m \\ k \end{Bmatrix} \cdot k!
\]
第二类 Stirling 数的一行可以用容斥转化成卷积,见 Luogu 模板 第二类斯特林数·行,这里不讲了。
把这些 OGF 乘起来,就得到了 \(S\) 个代表元合并的方案数,剩下的用组合数就很容易数了。若有 \(i\) 个位置上有代表元,则方案数为 \(m + i\) 个非负未知数,和为 \(n - i\) 的不定方程解数计数。乘法的复杂度上界是 \(\mathcal O(n \log^2 n)\),具体这里使用类似石子合并的方法合并。
Code
多项式乘法使用 std::vector
存储,常数巨大,内存巨大
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
// Redirects stdin to "<s>.in" and stdout to "<s>.out". `s` must be a string
// literal, because the extension is appended by literal concatenation.
#define File(s) freopen(s".in", "r", stdin), freopen(s".out", "w", stdout)
// Shorthand for long long.
typedef long long ll;
// Buffered integer/character I/O built directly on fread/fwrite.
namespace io {
// Capacity of the input buffer and of the output buffer.
const int SIZE = (1 << 21) + 1;
// Input buffer; [iS, iT) is the part that has been read from stdin but not consumed yet.
char ibuf[SIZE], *iS, *iT;
// Output buffer; oS is the write cursor and reaching oT triggers a flush.
char obuf[SIZE], *oS = obuf, *oT = oS + SIZE - 1;
// Scratch state shared by gi() (c, f) and print() (qu, qr).
char c, qu[55];
int f, qr;
// Next input byte; refills ibuf from stdin when it is exhausted and yields EOF when nothing more can be read.
#define gc() (iS == iT ? (iT = (iS = ibuf) + fread (ibuf, 1, SIZE, stdin), (iS == iT ? EOF : *iS ++)) : *iS ++)

// Function wrapper around gc().
char getc () {
	return gc();
}

// Writes the pending output bytes to stdout and rewinds the write cursor.
inline void flush () {
	fwrite(obuf, 1, oS - obuf, stdout);
	oS = obuf;
}

// Appends one character to the output buffer, flushing when the cursor reaches oT.
inline void putc (char x) {
	*oS = x;
	++oS;
	if (oS == oT) flush();
}

// Reads a decimal integer into x: skips everything up to the first digit,
// and a '-' seen while skipping makes the result negative.
template <class I> inline void gi (I &x) {
	f = 1;
	c = gc();
	while (c < '0' || c > '9') {
		if (c == '-') f = -1;
		c = gc();
	}
	x = 0;
	while (c >= '0' && c <= '9') {
		x = x * 10 + (c & 15);
		c = gc();
	}
	x *= f;
}

// Appends the decimal representation of x to the output buffer.
template <class I> inline void print (I x) {
	if (x == 0) putc('0');
	if (x < 0) {
		putc('-');
		x = -x;
	}
	// Digits come out least-significant first, so stack them in qu and pop in reverse.
	for (; x != 0; x /= 10) qu[++qr] = x % 10 + '0';
	while (qr != 0) putc(qu[qr--]);
}

// Flushes whatever is still buffered when the program's static objects are destroyed.
struct Flusher_ {
	~Flusher_() { flush(); }
} io_flusher_;
}
using io::gi;
using io::getc;
using io::putc;
using io::print;
// Replaces x by y unless x is already strictly greater than y.
template<class T> void upmax(T &x, T y){
	if(!(x > y)) x = y;
}
// Replaces x by y unless x is already strictly smaller than y.
template<class T> void upmin(T &x, T y){
	if(!(x < y)) x = y;
}
// p is the modulus for all arithmetic below; G is the base from which the NTT roots of unity are derived.
const int p = 998244353, G = 3;

// (x + y) mod p for x, y already reduced into [0, p).
inline int add(int x, int y){
	const int s = x + y;
	return s >= p ? s - p : s;
}

// (x - y) mod p for x, y already reduced into [0, p).
inline int sub(int x, int y){
	const int d = x - y;
	return d < 0 ? d + p : d;
}

// (x * y) mod p; the product is formed in 64 bits before reduction.
inline int mul(int x, int y){
	return 1LL * x * y % p;
}

// x^y mod p by binary exponentiation (y is expected to be non-negative).
inline int power(int x, int y){
	int res = 1;
	while(y){
		if(y & 1) res = mul(res, x);
		y >>= 1;
		x = mul(x, x);
	}
	return res;
}

// Modular inverse of x, computed as x^(p-2).
inline int inv(int x){
	return power(x, p - 2);
}
// Array-size constants. Len is the size of the NTT root tables and therefore the
// largest transform length the polynomial code can handle.
const int N = 10000005, M = 100005, Len = 262144;
// fac[i] = i! mod p and ifac[i] = (i!)^{-1} mod p, filled by preC().
int fac[N + M * 5], ifac[N + M * 5];

// Fills fac[0..n] and ifac[0..n]. Only one modular inversion is performed:
// ifac[n] is inverted directly and the rest follow from ifac[i-1] = ifac[i] * i.
void preC(int n){
	fac[0] = 1;
	for(int i=0; i<n; i++) fac[i+1] = mul(fac[i], i+1);
	ifac[n] = inv(fac[n]);
	for(int i=n; i>0; i--) ifac[i-1] = mul(ifac[i], i);
}
// Binomial coefficient C(n, m) mod p from the precomputed factorial tables.
// Returns 0 when m lies outside [0, n] (which also covers n < 0). Without this
// check, ifac[n - m] or ifac[m] would be read at a negative index, which happens
// when the caller asks for the number of solutions of an equation with a negative sum.
// Requires preC() to have been called with an argument >= n.
inline int C(int n, int m){
	if(m < 0 || m > n) return 0;
	return mul(fac[n], mul(ifac[m], ifac[n - m]));
}
// Number of ordered selections of m out of n items, n! / (n - m)! mod p.
// Returns 0 when m lies outside [0, n] instead of reading ifac[n - m] out of bounds.
// Requires preC() to have been called with an argument >= n.
inline int P(int n, int m){
	if(m < 0 || m > n) return 0;
	return mul(fac[n], ifac[n - m]);
}
// Number of solutions of x_1 + ... + x_m = S in non-negative integers, mod p
// (stars and bars: C(S + m - 1, m - 1)).
// Degenerate inputs are answered directly so that no negative argument reaches
// the factorial tables: with no unknowns only S == 0 has a (single, empty)
// solution, and a negative sum has none.
inline int equation(int m, int S){
	if(m <= 0) return (m == 0 && S == 0) ? 1 : 0;
	if(S < 0) return 0;
	return C(S - 1 + m, m - 1);
}
// Exponents K[1..m], read by main() (index 0 is unused).
int K[M];
// m: number of unknowns / exponents; n: required sum of the unknowns. Both are read by main().
int m, n;
// Number-theoretic transform modulo p for lengths that are powers of two up to Len.
namespace polynomial{
	// w[i] = g^i and invw[i] = g^{-i}, where g = G^((p-1)/Len).
	int w[Len], invw[Len];

	// Fills the root tables once, during static initialisation.
	struct _polyInit{
		_polyInit(){
			const int root = power(G, (p - 1) / Len);
			const int iroot = inv(root);
			w[0] = invw[0] = 1;
			for(int i=1; i<Len; i++){
				w[i] = mul(w[i-1], root);
				invw[i] = mul(invw[i-1], iroot);
			}
		}
	}_init;

	// Length for which rev[] is currently valid (-1: none yet).
	int last = -1;
	// Bit-reversal permutation for the length stored in `last`.
	int rev[Len];

	// Builds rev[] for transform length n (a power of two); a repeated call
	// with the same length is a no-op.
	void pre(int n){
		if(last == n) return;
		last = n;
		// topBit ends up as log2(n) - 1, the position of the highest index bit.
		int topBit = -1;
		for(int rest = n; rest != 1; rest >>= 1) ++topBit;
		for(int i=1; i<n; i++)
			rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << topBit);
	}

	// In-place iterative transform of f (size must be a power of two <= Len).
	// Passing the table `w` gives the forward transform and `invw` the inverse
	// one; the 1/n scaling of the inverse is left to the caller.
	void NTT(vector<int> &f, int *w){
		const int n = f.size();
		pre(n);
		for(int i=1; i<n; i++)
			if(rev[i] < i) swap(f[i], f[rev[i]]);
		for(int half=1; half<n; half<<=1){
			// Distance in the root table between consecutive twiddles of this stage.
			const int stride = Len / (half << 1);
			for(int base=0; base<n; base+=(half<<1)){
				int idx = 0;
				for(int k=0; k<half; k++, idx+=stride){
					const int lo = f[base + k];
					const int hi = mul(f[base + half + k], w[idx]);
					f[base + k] = add(lo, hi);
					f[base + half + k] = sub(lo, hi);
				}
			}
		}
	}
}
// A polynomial as its coefficient list, lowest degree first.
typedef vector<int> poly;

// Multiplies a by b modulo p using the NTT; a receives the product
// (a.size() + b.size() - 1 coefficients).
//
// b is left untouched: the transform runs on a private copy. This also makes
// `a *= a` correct, since otherwise the same vector would be transformed twice.
// An empty operand yields an empty product instead of resizing to size_t(-1).
//
// Precondition: a.size() + b.size() must not exceed Len, because the root and
// bit-reversal tables only cover transforms up to that length.
void operator*=(poly &a, poly &b){
	using namespace polynomial;
	if(a.empty() || b.empty()){
		a.clear();
		return;
	}
	const int la = a.size(), lb = b.size();
	int len = 1;
	while(len < la + lb) len <<= 1;
	// Copy before touching a, so that the aliasing case (&a == &b) sees the original data.
	poly tb(b);
	a.resize(len); tb.resize(len);
	NTT(a, w); NTT(tb, w);
	for(int i=0; i<len; i++) a[i] = mul(a[i], tb[i]);
	NTT(a, invw);
	// The inverse transform is unscaled; divide by len here.
	const int invlen = inv(len);
	for(int i=0; i<len; i++) a[i] = mul(a[i], invlen);
	a.resize(la + lb - 1);
}
// F[i] is the polynomial built for exponent K[i]. Only indices 1..m are ever
// used and K itself has M slots, so M entries are enough; sizing this array by
// N would statically allocate about ten million vector objects.
poly F[M];

// Orders indices into F so that the priority queue yields the shortest polynomial first.
struct LenCmp{
	bool operator() (int a, int b) const {return F[a].size() > F[b].size();}
};

// Indices of the polynomials that still have to be multiplied together.
priority_queue<int, vector<int>, LenCmp> q;
// Computes row n of the Stirling numbers of the second kind into a[1..n] via
// the inclusion-exclusion convolution
//   S(n, k) = sum_{j} (j^n / j!) * ((-1)^{k-j} / (k-j)!).
// a is resized to n + 1 entries. a[0] is not assigned here, so it keeps
// whatever value a already held there (0 when a starts out empty) and is
// carried through the convolution unchanged.
// Requires the inverse factorial table to be filled for the entries that matter.
void getStirling(poly &a, int n){
	a.resize(n + 1);
	// sgn[i] = (-1)^i / i!
	poly sgn(n + 1);
	sgn[0] = 1;
	for(int i=1; i<=n; i++){
		a[i] = mul(power(i, n), ifac[i]);
		sgn[i] = (i & 1) ? sub(0, ifac[i]) : ifac[i];
	}
	a *= sgn;
	// The product has 2n + 1 coefficients; only the first n + 1 form the row.
	a.resize(n + 1);
}
int solve(){
int kcnt = 0, ksum = 0;
for(int i=1; i<=m; i++){
kcnt += K[i] >= 1;
ksum += K[i];
if(K[i] == 0) continue;
getStirling(F[i], K[i]);
for(int j=1; j<=K[i]; j++) F[i][j] = mul(F[i][j], fac[j]);
q.push(i);
}
if(q.size() == 0) return 1;
while(q.size() != 1){
int x = q.top(); q.pop();
int y = q.top(); q.pop();
F[x] *= F[y];
q.push(x);
}
vector<int> &f = F[q.top()]; q.pop();
int res = 0;
for(int i=kcnt; i<=ksum; i++)
res = add(res, mul(f[i], equation(m + i, n - i)));
return res;
}
int main(){
gi(m); gi(n);
preC(n + m * 3);
for(int i=1; i<=m; i++)
gi(K[i]);
printf("%d
", solve());
return 0;
}