扩展卢卡斯定理用于求如下式子(其中(p)不一定是质数):
我们将这个问题由总体到局部地分为三个层次解决。
层次一:原问题
首先对(p)进行质因数分解:
显然(p_i^{k_i})是两两互质的,所以如果分别求出(C_n^m mod p_i^{k_i}),就可以构造出若干个形如(C_n^m=a_i mod p_i^{k_i})的方程,然后用中国剩余定理即可求解。
层次二:组合数模质数幂
现在的问题就转化成了求如下式子(其中(p)是质数):
脑补一下组合数公式(C_n^m=frac{n!}{m! imes (n-m)!}),发现由于(m!)和((n-m)!)可能包含质因子(p),所以不能直接求他们对于(p^k)的逆元。此时我们可以将(n!)、(m!)、((n-m)!)中的质因子(p)全部提出来,最后再乘回去即可。即变为下式((k1)为(n!)中质因子(p)的次数,(k2)、(k3)同理):
(frac{m!}{p^{k2}})和(frac{(n-m)!}{p^{k3}})和(p^k)是互质的,可以直接求逆元。
层次三:阶乘除去质因子后模质数幂
现在看看如何计算形如下式的式子。
先考虑如何计算(n! mod p^k)
举个例子:(n=22),(p=3),(k=2)
把这个写出来:
(22!=1 imes 2 imes 3 imes 4 imes 5 imes 6 imes 7 imes 8 imes 9 imes 10 imes 11 imes 12 imes 13 imes 14 imes 15 imes 16 imes 17 imes 18 imes 19 imes 20 imes 21 imes 22)
把其中所有(p)(也就是(3))的倍数提取出来,得到:
(22!=3^7 imes (1 imes 2 imes 3 imes 4 imes 5 imes 6 imes 7) imes(1 imes 2 imes 4 imes 5 imes 7 imes 8 imes 10 imes 11 imes 13 imes 14 imes 16 imes 17 imes 19 imes 20 imes 22 ))
可以看出上式分为三个部分:第一个部分是(3)的幂,次数是小于等于(22)的(3)的倍数的个数,即(lfloorfrac{n}{p} floor)
第二个部分是一个阶乘(7!),即(lfloorfrac{n}{p} floor!),可以递归解决
第三个部分是(n!)中与(p)互质的部分的乘积,这一部分具有如下性质:
(1 imes 2 imes 4 imes 5 imes 7 imes 8equiv10 imes 11 imes 13 imes 14 imes 16 imes 17 mod p^k)
在模(3^2)的意义下(10)和(1)同余,(11)和(2)同余……写成下式就比较显然
((t)是任意正整数)
(prod_{i,(i,p)=1}^{p^k}i)一共循环了(lfloorfrac{n}{p^k} floor)次,暴力求出(prod_{i,(i,p)=1}^{p^k}i)然后用快速幂求它的(lfloorfrac{n}{p^k} floor)次幂。
最后还要乘上(19 imes 20 imes 22)(即(prod_{i,(i,p)=1}^{n mod p^k}i)),显然这一段的长度一定小于(p^k),暴力乘上去即可。
如上三部分的乘积就是(n!)。最终要求的是(frac{n!}{p^{a}} mod p^k),分母全部由上述第一部分和第二部分贡献(第三部分和(p)互质)。而递归计算第二部分的时候已经除去了第二部分中的因子(p),所以最终的答案就是上述第二部分递归返回的结果和第三部分的乘积(与第一部分无关)。
结合代码方便理解:
ll fac(const ll n, const ll p, const ll pk)
{
if (!n)
return 1;
ll ans = 1;
for (int i = 1; i < pk; i++)
if (i % p)
ans = ans * i % pk;
ans = power(ans, n / pk, pk);
for (int i = 1; i <= n % pk; i++)
if (i % p)
ans = ans * i % pk;
return ans * fac(n / p, p, pk) % pk;
}
层次二:组合数模质数幂
回到这个式子
可以很容易地把它转换成代码(注意i要开long long):
ll C(const ll n, const ll m, const ll p, const ll pk)
{
if (n < m)
return 0;
ll f1 = fac(n, p, pk), f2 = fac(m, p, pk), f3 = fac(n - m, p, pk), cnt = 0;
for (ll i = n; i; i /= p)
cnt += i / p;
for (ll i = m; i; i /= p)
cnt -= i / p;
for (ll i = n - m; i; i /= p)
cnt -= i / p;
return f1 * inv(f2, pk) % pk * inv(f3, pk) % pk * power(p, cnt, pk) % pk;
}
层次一:原问题
完整代码(题目:洛谷4720【模板】扩展卢卡斯):
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <climits>
#include <cmath>
using namespace std;
namespace zyt
{
const int N = 1e6;
typedef long long ll;
ll n, m, p;
inline ll power(ll a, ll b, const ll p = LLONG_MAX)
{
ll ans = 1;
while (b)
{
if (b & 1)
ans = ans * a % p;
a = a * a % p;
b >>= 1;
}
return ans;
}
ll fac(const ll n, const ll p, const ll pk)
{
if (!n)
return 1;
ll ans = 1;
for (int i = 1; i < pk; i++)
if (i % p)
ans = ans * i % pk;
ans = power(ans, n / pk, pk);
for (int i = 1; i <= n % pk; i++)
if (i % p)
ans = ans * i % pk;
return ans * fac(n / p, p, pk) % pk;
}
ll exgcd(const ll a, const ll b, ll &x, ll &y)
{
if (!b)
{
x = 1, y = 0;
return a;
}
ll xx, yy, g = exgcd(b, a % b, xx, yy);
x = yy;
y = xx - a / b * yy;
return g;
}
ll inv(const ll a, const ll p)
{
ll x, y;
exgcd(a, p, x, y);
return (x % p + p) % p;
}
ll C(const ll n, const ll m, const ll p, const ll pk)
{
if (n < m)
return 0;
ll f1 = fac(n, p, pk), f2 = fac(m, p, pk), f3 = fac(n - m, p, pk), cnt = 0;
for (ll i = n; i; i /= p)
cnt += i / p;
for (ll i = m; i; i /= p)
cnt -= i / p;
for (ll i = n - m; i; i /= p)
cnt -= i / p;
return f1 * inv(f2, pk) % pk * inv(f3, pk) % pk * power(p, cnt, pk) % pk;
}
ll a[N], c[N];
int cnt;
inline ll CRT()
{
ll M = 1, ans = 0;
for (int i = 0; i < cnt; i++)
M *= c[i];
for (int i = 0; i < cnt; i++)
ans = (ans + a[i] * (M / c[i]) % M * inv(M / c[i], c[i]) % M) % M;
return ans;
}
ll exlucas(const ll n, const ll m, ll p)
{
ll tmp = sqrt(p);
for (int i = 2; p > 1 && i <= tmp; i++)
{
ll tmp = 1;
while (p % i == 0)
p /= i, tmp *= i;
if (tmp > 1)
a[cnt] = C(n, m, i, tmp), c[cnt++] = tmp;
}
if (p > 1)
a[cnt] = C(n, m, p, p), c[cnt++] = p;
return CRT();
}
int work()
{
ios::sync_with_stdio(false);
cin >> n >> m >> p;
cout << exlucas(n, m, p);
return 0;
}
}
int main()
{
return zyt::work();
}