Solution
- 先考虑\(n_1=0\)的情况
- 那么只要考虑形如\(X_i>=A_i\)的限制
- 注意求的是正整数解的个数,即对于没有额外限制的变量(\(i>n_2\)),也要满足\(X_i>=1\)(相当于\(A_i=1\))
- \(\sum_{i=1}^{n}B_i=m\)的非负整数解的个数为\(C(m+n-1,m)\)
- 解释:序列共\(m+n-1\)个位置,选\(n-1\)个位置出来当隔板,把序列分为长度之和为\(m\)的\(n\)段(可能存在长度为\(0\)的段,即隔板相邻的情况)
- 现在为了满足这些限制,令\(B_i=X_i-A_i\),则\(B_i\)的非负整数解的个数就是原题的合法解的个数
- 那么\(m\)要减掉\(\sum_{i=1}^{n}A_i\)
- 考虑\(n_1>0\)的情况:前\(n_1\)个变量的限制是\(X_i<=A_i\),用总方案数\(-\)存在某个\(i(1<=i<=n_1)\)违反限制(即\(X_i>=A_i+1\))的方案数
- 即考虑容斥:不考虑前\(n_1\)个数的限制的方案数\(-\)前\(n_1\)个数至少有\(1\)个不满足条件的方案数\(+\)前\(n_1\)个数至少有\(2\)个不满足条件的方案数\(-\)……
- 发现\(n,m\)很大,但任意一组数据的\(p\)都可以拆成\(\prod_{i=1}^{k}p_i^{q_i}\),且\(p_i<=10007\),那么用扩展\(lucas\)对每个\(p_i^{q_i}\)分别求组合数取模,再用中国剩余定理合并即可
Code
#include <bits/stdc++.h>
using namespace std;
#define ll long long
/**
 * Reads the next non-negative decimal integer from stdin into `res`.
 *
 * Every non-digit character in front of the number is skipped, and the single
 * character that terminates the number is consumed as well.
 * If the input ends before any digit is found, `res` is set to 0 and the
 * function returns instead of looping forever.
 *
 * @param res  receives the parsed value (any integral type).
 */
template <class t>
inline void read(t & res)
{
    // getchar() returns int precisely so that EOF is distinguishable from data;
    // keeping it in an int also keeps the isdigit() argument in its valid range.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch))
        ch = getchar();

    res = 0;
    if (ch == EOF) return;  // input exhausted: nothing to parse

    do
    {
        res = res * 10 + (ch - '0');
        ch = getchar();
    } while (ch != EOF && isdigit(ch));
}
// Capacity shared by the per-factor and per-constraint arrays below.
const int o = 2000;
// p, tst : modulus and number of test cases, both read once in main().
// a[0]   : how many prime-power factors init() found in p; a[i] is the i-th prime power,
//          c[i] its prime and d[i] its exponent.
// b[i]   : residue modulo a[i] stored by cc() before the residues are combined.
// pk     : prime power currently in use (set by solve(), read by ksm() and fac()).
// now    : index of that prime power; selects the row of f that fac() reads.
// f[i][j]: product of the integers in 1..j that are not divisible by c[i], modulo a[i]
//          (filled by init()).
//          NOTE(review): init() writes columns 0..a[i], so 10010 columns only suffice while
//          every prime power of p is at most 10009 — confirm against the input constraints.
// n, n1, n2, m, h[] : per-test input read in main(); ans : running answer of the current test.
int a[o], b[o], pk, p, c[o], d[o], tst, n1, n2, n, m, ans, h[o], now, f[20][10010];
// vis[i]: whether dfs() has put constraint i (1..n1) into the subset pd() evaluates.
bool vis[o];
// tot: exponent of the current prime accumulated by fac(); solve() zeroes it before each call.
ll tot;
/**
 * Extended Euclidean algorithm.
 *
 * @param a, b  the two operands.
 * @param x, y  outputs: Bezout coefficients with a * x + b * y == gcd(a, b).
 * @return gcd(a, b).
 */
inline int exgcd(int a, int b, int &x, int &y)
{
    if (b == 0)
    {
        // gcd(a, 0) = a = a * 1 + 0 * 0
        x = 1;
        y = 0;
        return a;
    }

    // Solve the smaller instance b * sub_x + (a % b) * sub_y = g first ...
    int sub_x, sub_y;
    const int g = exgcd(b, a % b, sub_x, sub_y);

    // ... then substitute a % b = a - (a / b) * b to lift it back to (a, b).
    x = sub_y;
    y = sub_x - a / b * sub_y;
    return g;
}
/**
 * Binary exponentiation modulo the global `pk`.
 *
 * @param x  base.
 * @param y  exponent.
 * @return x^y mod pk (1 when y == 0).
 */
inline int ksm(int x, ll y)
{
    int acc = 1;
    for (; y; y >>= 1)
    {
        // Fold the current power of the base in whenever the low bit of y is set.
        if (y & 1)
            acc = (ll)acc * x % pk;
        // Square the base for the next bit.
        x = (ll)x * x % pk;
    }
    return acc;
}
/**
 * Computes n! with every factor p removed, modulo the global `pk`, and adds the
 * number of removed factors (the exponent of p in n!) to the global `tot`.
 *
 * Uses row `now` of the table `f`, where f[now][j] is the product of the
 * integers in 1..j that are not multiples of the prime, modulo pk.
 *
 * @param n  argument of the factorial.
 * @param p  the prime whose powers are stripped.
 * @param k  exponent of the prime power (carried along, not used here).
 * @return the p-free part of n! modulo pk.
 */
inline int fac(int n, int p, int k)
{
    ll res = 1;
    // Each round handles the non-multiples of p in 1..cur, then moves on to the
    // multiples, which after dividing by p form the factorial of cur / p.
    for (int cur = n; cur != 0 && cur != 1; cur /= p)
    {
        tot += cur / p;  // multiples of p in 1..cur

        // cur / pk complete blocks of length pk each contribute f[now][pk - 1] ...
        const int blocks = ksm(f[now][pk - 1], cur / pk);
        res = res * blocks % pk;
        // ... and the incomplete tail contributes f[now][cur % pk].
        res = res * f[now][cur % pk] % pk;
    }
    return res;
}
inline int solve(int n, int m, int id)
{
int p = c[id], k = d[id];
pk = a[id];
tot = 0;
int ra = fac(m, p, k); ll ta = tot;
tot = 0;
int rb = fac(n - m, p, k); ll tb = tot;
tot = 0;
int rc = fac(n, p, k); ll tc = tot;
ll t = tc - ta - tb;
if (t < 0) t = (t % k + k) % k;
int ia, ib, xxx;
exgcd(ra, pk, ia, xxx);
exgcd(rb, pk, ib, xxx);
if (ia < 0) ia += pk;
if (ib < 0) ib += pk;
return (ll)rc * ia % pk * ib % pk * ksm(p, t) % pk;
}
inline void init()
{
int i, s = sqrt(p), lp = p, j;
for (i = 2; i <= s; i++)
if (lp % i == 0)
{
int t = 0, r = 1;
while (lp % i == 0)
{
t++;
r *= i;
lp /= i;
}
a[++a[0]] = r;
c[a[0]] = i;
d[a[0]] = t;
}
if (lp != 1)
{
a[++a[0]] = lp;
c[a[0]] = lp;
d[a[0]] = 1;
}
for (i = 1; i <= a[0]; i++)
{
f[i][0] = 1;
for (j = 1; j <= a[i]; j++)
if (j % c[i]) f[i][j] = (ll)f[i][j - 1] * j % a[i];
else f[i][j] = f[i][j - 1];
}
}
/**
 * Computes the binomial coefficient C(n, m) modulo p.
 *
 * The value is computed modulo each prime power a[i] with solve() and the
 * residues are merged with the Chinese remainder theorem.
 * Side effects: writes `now` and b[1..a[0]] (plus whatever solve() touches).
 *
 * @param n, m  arguments of the binomial coefficient.
 * @param p     modulus whose factorisation is stored in a[].
 * @return C(n, m) mod p, or 0 when m < 0 or n < m.
 */
inline int cc(ll n, ll m, int p)
{
    if (m < 0 || n < m) return 0;

    int total = 0;
    for (int idx = 1; idx <= a[0]; ++idx)
    {
        now = idx;  // fac() reads row `now` of f
        b[idx] = solve(n, m, idx);

        // CRT term: residue * cofactor * (cofactor^-1 mod a[idx]).
        const int modulus = a[idx];
        const int cofactor = p / modulus;
        int inv, unused;
        exgcd(cofactor, modulus, inv, unused);

        ll term = (ll)cofactor * inv % p;
        term = term * b[idx] % p;
        // inv may be negative, hence the extra + p before reducing.
        total = (total + term + p) % p;
    }
    return total;
}
/**
 * Adds y to x and subtracts the global modulus `p` once if the sum reaches it.
 *
 * @param x  accumulator, updated in place.
 * @param y  addend.
 */
inline void add(int &x, int y)
{
    const int sum = x + y;
    x = (sum >= p) ? sum - p : sum;
}
inline void pd()
{
int i, tm = m, cnt = 0;
for (i = 1; i <= n1; i++)
if (vis[i])
{
cnt++;
tm -= h[i] + 1;
}
else tm--;
if (!cnt) return;
if (cnt & 1) add(ans, p - cc(tm + n - 1, tm, p));
else add(ans, cc(tm + n - 1, tm, p));
}
/**
 * Enumerates every subset of the constraints k..n1 through vis[] and calls pd()
 * once per complete assignment.
 *
 * @param k  index of the constraint to decide next (call with 1 to start).
 */
inline void dfs(int k)
{
    if (k != n1 + 1)
    {
        // Try "not in the subset" first, then "in the subset".
        for (int pick = 0; pick <= 1; ++pick)
        {
            vis[k] = (pick == 1);
            dfs(k + 1);
        }
        return;
    }

    // All n1 flags are fixed: evaluate this subset.
    pd();
}
int main()
{
int i;
read(tst); read(p);
init();
while (tst--)
{
read(n);
read(n1);
read(n2);
read(m);
int tmp = n1 + n2;
for (i = 1; i <= tmp; ++i) read(h[i]);
m -= n - n1 - n2;
for (i = n1 + 1; i <= n2 + n1; i++) m -= h[i];
int tm = m - n1;
ans = cc(tm + n - 1, tm, p);
dfs(1);
printf("%d\n", ans);
}
return 0;
}