容易发现这些 vip 用户并没什么用,所以考虑枚举手持50元与100元的人共有多少个。设手持50元的人 (a) 个,手持100元的人 (a - k) 个,那么一共是 (2*a - k) 个人,最后手上会剩余 (k) 张50元钞票。用卡特兰数计算得到在这种情况下的方案数就是:
((inom{2 * a - k}{a} - inom{2 * a - k}{a + 1}) * inom{n}{2 * a - k})
其中 (l <= k <= r, 1 <= 2 * a - k <= n)。
把组合数分开计算,得到答案其实是两部分的和:
((inom{2 * a - k}{a} * inom{n}{2 * a - k})
与 (-inom{2 * a - k}{a + 1} * inom{n}{ 2 * a - k})
注意到这两个式子均可以化简(inom{n}{m} * inom{m}{r} = inom{n}{r} * inom{n - m}{m - r}),
得到:
(ans = inom{n}{a}inom{n - a}{a - k} - inom{n}{a + 1} * inom{n - a - 1}{a - k - 1})
分开计算两个部分的和,发现上部分是在计算 (sum inom{n}{a}inom{n - a}{a - k} (ain [1, n], k in [l, r]))
而下部分则是在计算 (sum inom{n}{a}inom{n - a}{a - k - 2} (ain [2, n + 1], k in [l, r]))
把中间同样的部分约去即可得到一个 (O(n)) 的式子,每一步需求解一个组合数。
然后问题是怎样求出 (n) 个对合数取模的组合数?常见思路利用扩展卢卡斯已经不可行,但鉴于此题 (n) 比较小,我们可以暴力拆分组合数中的阶乘数。对于与模数互质的部分利用欧拉定理求出逆元,照常处理;不互质的则计算上下约去了多少个质因子后暴力累乘贡献。
#include <bits/stdc++.h> using namespace std; #define maxn 200000 #define CNST 30 #define int long long int n, P, L, R, ans, cnt, fac[maxn], finv[maxn]; int tot, a[maxn], num[maxn][CNST]; int read() { int x = 0, k = 1; char c; c = getchar(); while(c < '0' || c > '9') { if(c == '-') k = -1; c = getchar(); } while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar(); return x * k; } void Up(int &x, int y) { x = (x + y) % P; if(x < 0) x += P; } int Qpow(int x, int timer, int P) { int base = 1; for(; timer; timer >>= 1, x = x * x % P) if(timer & 1) base = base * x % P; return base; } void Pre(int n) { int x = P, phi = P; fac[0] = finv[0] = 1; for(int i = 2; i * i <= x; i ++) if(x % i) continue; else { phi = phi / i * (i - 1); a[++ cnt] = i; while(!(x % i)) x /= i; } if(x > 1) phi = phi / x * (x - 1), a[++ cnt] = x; for(int i = 1; i <= n; i ++) { int t = i; for(int j = 1; j <= cnt; j ++) { num[i][j] = num[i - 1][j]; while(!(t % a[j])) t /= a[j], num[i][j] ++; } fac[i] = fac[i - 1] * t % P; finv[i] = Qpow(fac[i], phi - 1, P); } } int Get_C(int n, int m) { int ret = fac[n] * finv[m] % P * finv[n - m] % P, x, y; if(n < 0 || m < 0 || n < m) return 0; for(int i = 1; i <= cnt; i ++) { if(a[i] > n) break; int cnt = num[n][i] - num[m][i] - num[n - m][i]; ret = ret * Qpow(a[i], cnt, P) % P; } return ret; } signed main() { n = read(), P = read(), L = read(), R = read(); Pre(n + 5); for(int i = 0; i <= n; i ++) { if(L > i || R < 0) continue; int l = max(L, 0LL), r = min(i, R), t = Get_C(n, i); Up(ans, P - t * Get_C(n - i + 1, i - r - 1) % P); Up(ans, t * Get_C(n - i + 1, i - l) % P); } printf("%I64d ", ans); return 0; }