题面
给定 (B) 个蓝色的球和 (R) 个红色的球以及一个绿色的球,同颜色的球不可区分。对于一种球的排列方式,记 (l_B) 是绿球左边的蓝球个数,(r_B) 是绿右边的蓝球个数,(l_R) 是球左边的红球个数,(r_R) 是球右边边的红球个数,则该排列的权值是最大的正整数 (x) 满足 (l_B imes x le l_R),(r_B imes xle r_R)
数据范围:(1 le B le 10^6),(1 le R le 10^{18})。
题解
法 1
考虑枚举绿球右边的红球和蓝球个数:
考虑右边那堆东西的组合意义:一条路径的权值是他经过的点中满足 (y = Ax + P) 的点数,要求所有从 ((0, 0)) 出发,到达 ((B, R)) 的路径的权值和。
首先我们先考虑一个前置的问题:
求从 ((0, 0)) 到 ((W, AW + P)) 的路径条数 (设为 (f(W, A, P))) 满足这条路径不穿过 (y = Ax+P)。
考虑一条从 ((0, 0)) 到 ((W, AW + P)) 的路径,如果不穿过 (y = Ax+P),我们就枚举他经过这条直线的第一个位置:
考虑一条从 ((0, 0)) 到 ((W - 1, AW + P + 1)) 的路径,他必然穿过 (y = Ax+P):
我们发现 ((1) - A (2)) 可得 (f(W,A,P) = inom{(A+1)W+P}{W} - A inom{(A+1)W+P}{W-1})
回到现在要解决的问题:一条路径的权值是他经过的点中满足 (y = Ax + P) 的点数,要求所有从 ((0, 0)) 出发,到达 ((W, H)) 的路径的权值和。保证 (AW+P le H) 记为 (g(W, H, A, P))。
枚举路径上的点:(sumlimits_{i = 0}^{W} inom{(A + 1) i+P}{i} inom{W + H - (A+1) i - P}{W - i})。
考虑其组合意义,就是先从 ((0, 0)) 走到 ((i, Ai+P)),再在不经过 (y = Ai+P) 的情况下走到 ((W, H))。
可以把他想象成枚举最后一次 碰到 (y = Ai+P) 的位置,最终要到达 ((W, H + 1)) (因为碰到之后必然会向上走,然后不能 穿过 (y = Ai + P + 1),因此要到达的点是 ((W, H + 1)))。
其实我们算的就是从 ((0, 0)) 到 ((H + 1, W)) ,因此 (g(W, H, A, P) - Ag(W - 1, H+1, A, P) = inom{H+W+1}{W})。结果和 (P) 无关!
因此 (g(W,H,A,P) = sumlimits_{i = 0}^{W} inom{H+W+1}{i} A^{W-i})
接下来就很好做了:要算的是 (sumlimits_{A = 1}^{frac{R}{B}} (R-AB+1) sumlimits_{i = 0}^{B} inom{B+R+1}{i} A^{B-i})。
交换一下求和顺序就是 (sumlimits_{i = 0}^{B} inom{H+R+1}{i} ( (R+1) sumlimits_{A = 1}^{frac{R}{B}} A^{B-i} - B sumlimits_{A = 1}^{frac{R}{B}} A^{B-i+1} ))。可以伯努利数解决。
法 2
前置知识:广义二项级数
从这里开始推:
考虑如何计算后面的东西。
于是变成了和 法1 完全一样的形式了。
代码
#include<bits/stdc++.h>
#define L(i, j, k) for(int i = j, i##E = k; i <= i##E; i++)
#define R(i, j, k) for(int i = j, i##E = k; i >= i##E; i--)
#define ll long long
#define pii pair<int, int>
#define db double
#define x first
#define y second
#define ull unsigned long long
#define sz(a) ((int) (a).size())
#define vi vector<int>
using namespace std;
const int mod = 998244353, G = 3, iG = (mod + 1) / G, N = 2.1e6 + 7, inv2 = (mod + 1) / 2;
#define add(a, b) (a + b >= mod ? a + b - mod : a + b)
#define dec(a, b) (a < b ? a - b + mod : a - b)
inline ull calc(const ull &x) {
return x - (__uint128_t(x) * 9920937979283557439ull >> 93) * 998244353;
}
int qpow(int x, int y = mod - 2) {
int res = 1;
for(; y; x = (ll) x * x % mod, y >>= 1) if(y & 1) res = (ll) res * x % mod;
return res;
}
int n, m, fac[N], ifac[N], inv[N];
void init(int x) {
fac[0] = ifac[0] = inv[1] = 1;
L(i, 2, x) inv[i] = (ll) inv[mod % i] * (mod - mod / i) % mod;
L(i, 1, x) fac[i] = (ll) fac[i - 1] * i % mod, ifac[i] = (ll) ifac[i - 1] * inv[i] % mod;
}
int rt[N], Lim;
void Pinit(int x) {
for(Lim = 1; Lim <= x; Lim <<= 1) ;
int sG = qpow(G, (mod - 1) / Lim); rt[0] = 1;
L(i, 1, Lim) rt[i] = (ll) rt[i - 1] * sG % mod;
}
int C(int x, int y) {
return y < 0 || x < y ? 0 : (ll) fac[x] * ifac[y] % mod * ifac[x - y] % mod;
}
int rev[N];
void initrev(int n) {
L(i, 0, n - 1) rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) * (n >> 1)));
}
struct poly {
vector<int> a;
int size() { return sz(a); }
int & operator [] (int x) { return a[x]; }
int v(int x) { return x < 0 || x >= sz(a) ? 0 : a[x]; }
void clear() { vector<int> ().swap(a); }
void rs(int x = 0) { a.resize(x); }
poly (int n = 0) { rs(n); }
poly (vector<int> o) { a = o; }
poly (const poly &o) { a = o.a; }
poly Rs(int x = 0) { vi res = a; res.resize(x); return res; }
void ntt(int op, int t = true) {
int n = sz(a);
if(t) initrev(n);
L(i, 0, n - 1) if(rev[i] < i) swap(a[rev[i]], a[i]);
for(int i = 2; i <= n; i <<= 1)
for(int j = 0, l = (i >> 1), ch = Lim / i; j < n; j += i)
for(int k = j, now = 0; k < j + l; k++) {
int pa = a[k], pb = calc((ull) a[k + l] * (op == 1 ? rt[now] : rt[Lim - now]));
a[k] = add(pa, pb), a[k + l] = dec(pa, pb), now += ch;
}
if(op != 1) for(int i = 0, iv = qpow(n); i < n; i++) a[i] = (ll) a[i] * iv % mod;
}
friend poly operator * (poly aa, poly bb) {
if(!sz(aa) || !sz(bb)) return {};
int lim, all = sz(aa) + sz(bb) - 1;
for(lim = 1; lim < all; lim <<= 1);
initrev(lim), aa.rs(lim), bb.rs(lim), aa.ntt(1, false), bb.ntt(1, false);
L(i, 0, lim - 1) aa[i] = (ll) aa[i] * bb[i] % mod;
aa.ntt(-1, false), aa.a.resize(all);
return aa;
}
friend poly operator * (poly aa, int bb) {
poly res(sz(aa));
L(i, 0, sz(aa) - 1) res[i] = (ll) aa[i] * bb % mod;
return res;
}
friend poly operator + (poly aa, poly bb) {
vector<int> res(max(sz(aa), sz(bb)));
L(i, 0, sz(res) - 1) res[i] = add(aa.v(i), bb.v(i));
return poly(res);
}
friend poly operator - (poly aa, poly bb) {
vector<int> res(max(sz(aa), sz(bb)));
L(i, 0, sz(res) - 1) res[i] = dec(aa.v(i), bb.v(i));
return poly(res);
}
poly & operator += (poly o) {
rs(max(sz(a), sz(o)));
L(i, 0, sz(a) - 1) (a[i] += o.v(i)) %= mod;
return (*this);
}
poly & operator -= (poly o) {
rs(max(sz(a), sz(o)));
L(i, 0, sz(a) - 1) (a[i] += mod - o.v(i)) %= mod;
return (*this);
}
poly & operator *= (poly o) {
return (*this) = (*this) * o;
}
poly Inv() {
poly res, f, g;
res.rs(1), res[0] = qpow(a[0]);
for(int m = 1, pn; m < sz(a); m <<= 1) {
pn = m << 1, f = res, g.rs(pn), f.rs(pn), initrev(pn);
for(int i = 0; i < pn; i++) g[i] = (*this).v(i);
f.ntt(1, false), g.ntt(1, false);
for(int i = 0; i < pn; i++) g[i] = (ll) f[i] * g[i] % mod;
g.ntt(-1, false);
for(int i = 0; i < m; i++) g[i] = 0;
g.ntt(1, false);
for(int i = 0; i < pn; i++) g[i] = (ll) f[i] * g[i] % mod;
g.ntt(-1, false), res.rs(pn);
for(int i = m; i < min(pn, sz(a)); i++) res[i] = (mod - g[i]) % mod;
}
return res;
}
poly Integ() {
if(!sz(a)) return poly();
poly res(sz(a) + 1);
L(i, 1, sz(a)) res[i] = (ll) a[i - 1] * inv[i] % mod;
return res;
}
poly Deriv() {
if(!sz(a)) return poly();
poly res(sz(a) - 1);
L(i, 1, sz(a) - 1) res[i - 1] = (ll) a[i] * i % mod;
return res;
}
poly Ln() {
poly g = ((*this).Inv() * (*this).Deriv()).Integ();
return g.rs(sz(a)), g;
}
poly Exp() {
poly res(1), f;
res[0] = 1;
for(int m = 1, pn; m < sz(a); m <<= 1) {
pn = min(m << 1, sz(a)), f.rs(pn), res.rs(pn);
for(int i = 0; i < pn; i++) f[i] = (*this).v(i);
f -= res.Ln(), (f[0] += 1) %= mod, res *= f, res.rs(pn);
}
return res.rs(sz(a)), res;
}
poly pow(int x) {
poly res = (*this).Ln();
L(i, 0, sz(res) - 1) res[i] = (ll) res[i] * x % mod;
res = res.Exp();
return res;
}
poly sqrt() {
poly res(1), f;
res[0] = 1;
for(int m = 1, pn; m < sz(a); m <<= 1) {
pn = min(m << 1, sz(a)), f.rs(pn);
for(int i = 0; i < pn; i++) f[i] = (*this).v(i);
f += res * res, f.rs(pn), res.rs(pn), res = f * res.Inv(), res.rs(pn);
for(int i = 0; i < pn; i++) res[i] = (ll) res[i] * inv2 % mod;
}
return res;
}
void Rev() {
reverse(a.begin(), a.end());
}
} ;
poly Mul(poly aa, poly bb, int all = 0) {
if(!sz(aa) || !sz(bb)) return {};
if(!all) all = sz(aa) + sz(bb) - 1;
int lim; for(lim = 1; lim < all; lim <<= 1);
initrev(lim), aa.rs(lim), bb.rs(lim), aa.ntt(1, 0), bb.ntt(1, 0);
L(i, 0, lim - 1) aa[i] = calc((ull) aa[i] * bb[i]);
aa.ntt(-1, 0), aa.a.resize(all);
return aa;
}
int B, ns, now = 1;
ll R;
int main() {
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
cin >> R >> B, init(B + 2), Pinit(B * 2 + 4);
if(R < B) {
cout << "0
";
return 0;
}
poly a(B + 2), b(B + 2);
L(i, 0, B + 1) a[i] = ifac[i + 1];
a = a.Inv(), now = 1;
L(i, 0, B + 1) now = (R / B + 1) % mod * now % mod, b[i] = (ll) now * ifac[i + 1] % mod;
a *= b;
L(i, 0, B + 1) a[i] = (ll) a[i] * fac[i] % mod;
(a[0] += mod - 1) %= mod;
now = 1;
L(i, 0, B)
(ns += ((R + 1) % mod * a[B - i] % mod + mod - (ll) B * a[B - i + 1]% mod) % mod * now % mod)
%= mod, now = (ll) (B + R + 1 - i) % mod * now % mod * inv[i + 1] % mod;
cout << ns << "
";
return 0;
}