(color{#0066ff}{ 题目描述 })
求
(C_n^m mod{p})
其中 (C) 为组合数。
(color{#0066ff}{输入格式})
一行三个整数 (n,m,p) ,含义由题所述。
(color{#0066ff}{输出格式})
一行一个整数,表示答案
(color{#0066ff}{输入样例})
5 3 3
666 233 123456
(color{#0066ff}{输出样例})
1
61728
(color{#0066ff}{数据范围与提示})
(1≤m≤n≤10^{18}),(2≤p≤1000000) ,不保证 (p) 是质数。
(color{#0066ff}{ 题解 })
简单来说,这个东西就是吧p分解成(prod p^c)
求出组合数mod(p^c)的值然后用CRT合并
这里要用到快速阶乘
比如19的阶乘mod(3^2)
(1*2*3*4*5*6*7*8*9*10*11*12*13*14*15*16*17*18*19)
把所有3的倍数都提出一个3
(3^6*6!*(1*2*4*5*7*8)*(10*11*13*14*16*17)*19)
发现分解成了一个快速幂,一个子问题阶乘,后面的还有循环节
递归处理即可
#include<bits/stdc++.h>
#define LL long long
LL in() {
char ch; LL x = 0, f = 1;
while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
return x * f;
}
LL ksm(LL x, LL y, LL mod) {
LL re = 1LL;
while(y) {
if(y & 1) re = re * x % mod;
x = x * x % mod;
y >>= 1;
}
return re;
}
LL exgcd(LL a, LL b, LL &x, LL &y) {
if(!b) return x = 1, y = 0, a;
LL r = exgcd(b, a % b, x, y);
LL t = x - a / b * y;
return x = y, y = t, r;
}
LL inv(LL x, LL mod) {
LL p, q;
exgcd(x, mod, p, q);
return ((p % mod) + mod) % mod;
}
LL n, m, p;
//乘了p/mod是为了保证其它式子不受影响,再乘当前的逆元保证的事当前式子不受影响
LL CRT(LL b, LL mod) { return b * inv(p / mod, mod) % p * (p / mod) % p; }
LL fac(LL x, LL r, LL rk) {
//0的阶乘为1
if(!x) return 1;
LL ans = 1;
//单个循环的ans
for(LL i = 1; i <= rk; i++) if(i % r) ans = ans * i % rk;
//有循环节,快速幂一下
ans = ksm(ans, x / rk, rk);
//最后剩余的可能不足一个循环节的部分
for(LL i = 1; i <= x % rk; i++) if(i % r) ans = ans * i % rk;
//子问题(快速幂写在了外面,方便)
return ans * fac(x / r, r, rk) % rk;
}
LL C(LL x, LL y, LL r, LL rk) {
LL X = fac(x, r, rk), Y = fac(y, r, rk), XY = fac(x - y, r, rk);
LL ans = 0;
for(LL i = x; i; i /= r) ans += i / r;
for(LL i = y; i; i /= r) ans -= i / r;
for(LL i = x - y; i; i /= r) ans -= i / r;
return X * inv(Y, rk) % rk * inv(XY, rk) % rk * ksm(r, ans, rk) % rk;
}
LL exlucas() {
LL res = p, ans = 0;
for(LL i = 2; i * i <= p; i++) {
if(res % i == 0) {
LL tot = 1;
while(res % i == 0) tot *= i, res /= i;
(ans += CRT(C(n, m, i, tot), tot)) %= p;
}
}
if(res > 1) (ans += CRT(C(n, m, res, res), res)) %= p;
return ans;
}
int main() {
n = in(), m = in(), p = in();
printf("%lld
", exlucas());
return 0;
}