Description:
https://gmoj.net/senior/#main/show/6639
题解:
考虑 \(n!\) 怎么算，经典做法：
设 \(v=\sqrt n\)，当 \(n \bmod v \neq 0\) 时就令 \(n \gets n-1\)（单独累加被去掉的那一项的贡献），直到 \(v \mid n\)。
此时每一块可以看做 \(\prod_{i=1}^{v}(x+i)\)，现在求 \(x=(0..n/v-1)\cdot v\) 处的点值：把这个多项式分治展开后多点求值即可。
时间复杂度：\(O(\sqrt n\log^2 n)\)。
这个题要求 \(\sum_{i=1}^n i^{-k}\)，可以看做 \(\frac{\sum_{i=1}^n \prod_{j\neq i} j^k}{(n!)^k}\)，做法差不多。
当然这样会 TLE，这题卡得比较紧。
考虑zzq博客上的倍增点值做法:
假设有多项式 \(A_n(x)=\prod_{i=1}^n (x+i)\)。
我们用 \(n+1\) 个点值来表示这个多项式，这题取点 \((0..n)\cdot v\) 即可。
假设已知 \(A_n((0..n)\cdot v)\)，
思考如何求出 \(A_n((0..2n)\cdot v)\)。
发现可以用拉格朗日插值：设要求 \(xv+c\) 处的点值，其中 \(c=(n+1)v,\ x\in[0,n]\)，
\(A(xv+c)=\sum_{i=0}^n A(iv)\prod_{j\neq i}\frac{xv+c-jv}{iv-jv}\)
因为 \(\forall j\in[0,n],\ xv+c-jv\neq 0\)，
所以可以写成：
\(=\left(\prod_{j=0}^n (xv+c-jv)\right)\cdot\left(\sum_{i=0}^n \frac{(-1)^{n-i}}{i!\,(n-i)!}\cdot v^{-n}\cdot\frac{1}{xv+c-iv}\right)\)
前面可以写成区间积形式，处理前缀积和前缀逆元积即可。
后面的部分是关于下标的卷积，可以用 FFT/NTT 计算（模数任意时用 MTT）。
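知道了上面的变换，我们做两次就可以由 \(A_n((0..n)v)\) 推到 \(A_{2n}((0..2n)v)\)：第一次取 \(c=(n+1)v\)，把点值扩展到 \((0..2n)v\)；第二次取 \(c=n\)，得到 \(A_n(x+n)\) 在这些点处的值，再利用 \(A_{2n}(x)=A_n(x)\,A_n(x+n)\) 逐点相乘。
这样就可以倍增求点值了（二进制位为 1 时再单独乘上一个因子），不难发现复杂度是 \(O(n\log n)\) 的。
这题需要同时维护两个多项式（\(\prod_i (x+i)^k\) 以及分子 \(\sum_i \prod_{j\neq i}(x+j)^k\)），有点细节，还要用 MTT（任意模数 FFT）。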
Code:
#include<bits/stdc++.h>
#define fo(i, x, y) for(int i = x, _b = y; i <= _b; i ++)
#define ff(i, x, y) for(int i = x, _b = y; i < _b; i ++)
#define fd(i, x, y) for(int i = x, _b = y; i >= _b; i --)
#define ll long long
#define pp printf
#define hh pp("\n")
using namespace std;
// n: upper limit of the sum, k: exponent, mo: modulus (all read in main).
ll n, k, mo;
// Fast modular exponentiation: returns x^y mod mo for y >= 0.
// x is reduced with C++ `%`, so for negative x (and mo > 1) the result can be
// a negative residue in (-mo, mo). Callers use ksm(x, mo - 2) as a modular
// inverse, which presumes mo is prime.
ll ksm(ll x, ll y) {
    x %= mo;
    ll s = 1;
    for(; y; y /= 2, x = x * x % mo)
        if(y & 1) s = s * x % mo;
    return s;
}
#define db double
#define V vector<ll>
#define si size()
#define re resize
// Arbitrary-modulus polynomial multiplication ("MTT"): every coefficient is
// split into two 15-bit halves, the partial products come from
// double-precision FFTs, and the pieces are recombined modulo `mo`.
// NOTE(review): the 15-bit split assumes 0 < mo < 2^30 — confirm against the
// problem constraints.
namespace mtt {
const db pi = acos(-1);
// Minimal complex number (x = real part, y = imaginary part).
struct P {
    db x, y;
    P(db _x = 0, db _y = 0) { x = _x, y = _y;}
    P operator + (P b) { return P(x + b.x, y + b.y);}
    P operator - (P b){ return P(x - b.x, y - b.y);}
    P operator * (P b) { return P(x * b.x - y * b.y, x * b.y + y * b.x);}
};
// Largest transform length the static buffers below can hold.
const int nm = 1 << 19;
P w[nm]; int r[nm];
P c0[nm], c1[nm], c2[nm], c3[nm];
// Fills the twiddle table: for every power of two i < nm and j in [0, i),
// w[i + j] = (cos(pi*j/i), sin(pi*j/i)). Must run once before any multiply.
void build() {
    for(int i = 1; i < nm; i *= 2) ff(j, 0, i)
        w[i + j] = P(cos(pi * j / i), sin(pi * j / i));
}
// In-place iterative FFT of length n (a power of two): bit-reversal
// permutation followed by butterflies using the table from build().
void dft(P *a, int n) {
    ff(i, 0, n) {
        r[i] = r[i / 2] / 2 + (i & 1) * (n / 2);
        if(i < r[i]) swap(a[i], a[r[i]]);
    } P b;
    for(int i = 1; i < n; i *= 2) for(int j = 0; j < n; j += 2 * i)
        ff(k, 0, i) b = a[i + j + k] * w[i + k], a[i + j + k] = a[j + k] - b, a[j + k] = a[j + k] + b;
}
// Turns the output of dft() into the inverse transform: reverse entries
// 1..n-1 and divide every entry by n.
void rev(P *a, int n) {
    reverse(a + 1, a + n);
    ff(i, 0, n) a[i].x /= n, a[i].y /= n;
}
P conj(P a) { return P(a.x, -a.y);}
// a <- a * b (cyclic convolution of length n) modulo mo, using 4 complex
// FFTs. Expects 0 <= a[i], b[i] < 2^30; b is not modified.
void fft(ll *a, ll *b, int n) {
    #define qz(x) ((ll) round(x))
    // Pack low 15 bits into the real part and high bits into the imaginary part.
    ff(i, 0, n) c0[i] = P(a[i] & 32767, a[i] >> 15), c1[i] = P(b[i] & 32767, b[i] >> 15);
    dft(c0, n); dft(c1, n);
    ff(i, 0, n) {
        P k, d0, d1, d2, d3;
        int j = (n - i) & (n - 1);
        // Separate the spectra of the low/high halves via conjugate symmetry.
        k = conj(c0[j]);
        d0 = (k + c0[i]) * P(0.5, 0);
        d1 = (k - c0[i]) * P(0, 0.5);
        k = conj(c1[j]);
        d2 = (k + c1[i]) * P(0.5, 0);
        d3 = (k - c1[i]) * P(0, 0.5);
        // c2 = lo*lo + i*hi*hi, c3 = lo*hi + hi*lo.
        c2[i] = d0 * d2 + d1 * d3 * P(0, 1);
        c3[i] = d0 * d3 + d1 * d2;
    }
    dft(c2, n); dft(c3, n); rev(c2, n); rev(c3, n);
    ff(i, 0, n) {
        // All parts are non-negative here, so the shifts are well defined.
        a[i] = qz(c2[i].x) + (qz(c2[i].y) % mo << 30) + (qz(c3[i].x) % mo << 15);
        a[i] %= mo;
    }
}
ll a[nm], b[nm];
// Product of two coefficient vectors modulo mo; result coefficients lie in
// [0, mo). Inputs may contain negative residues (callers pass them), so they
// are normalized to [0, mo) first: fft()'s 15-bit split and the `<< 30` /
// `<< 15` recombination need non-negative values, and left-shifting a
// negative number is undefined behavior before C++20.
V operator * (V p, V q) {
    if(p.empty() || q.empty()) return V();
    int n0 = p.si + q.si - 1, n = 1;
    while(n < n0) n *= 2;
    assert(n <= nm);  // the static buffers hold at most nm entries
    ff(i, 0, n) a[i] = b[i] = 0;
    ff(i, 0, p.si) a[i] = (p[i] % mo + mo) % mo;
    ff(i, 0, q.si) b[i] = (q[i] % mo + mo) % mo;
    fft(a, b, n);
    p.re(n0);
    ff(i, 0, n0) p[i] = a[i];
    return p;
}
}
using mtt :: operator *;
int v;
namespace sub1 {
const int N = 1e6 + 5;
ll fac[N], nf[N], f[N], vf[N];
// Lagrange shift of sampled values.
// Input:  a[i] = P(i*v) for i = 0..n, where deg P <= n (n = a.size() - 1).
// Output: w[x] = P(x*v + c) for x = 0..n.
// Requires every x*v + c - j*v (x, j in [0, n]) to be nonzero mod mo.
// Uses P(xv+c) = prod_j (xv+c-jv) * sum_i a[i] * (-1)^{n-i} / (i!(n-i)! v^n)
//                * 1/(xv+c-iv); the sum is a convolution.
V func(V a, int c) {
    int n = a.si - 1;
    // term(t) = (t - n) * v + c reduced mod mo; for t = x - j + n this is
    // x*v + c - j*v. Evaluated in 64 bits because (t - n) * v can exceed int,
    // and reduced so the running products below cannot overflow.
    auto term = [&](int t) { return ((ll) (t - n) * v + c) % mo; };
    fac[0] = 1; fo(i, 1, n) fac[i] = fac[i - 1] * i % mo;
    nf[n] = ksm(fac[n], mo - 2); fd(i, n, 1) nf[i - 1] = nf[i] * i % mo;
    // f[t] = prod_{s=0}^{t} term(s), vf[t] = f[t]^{-1}.
    f[0] = term(0);
    fo(i, 1, 2 * n) f[i] = f[i - 1] * term(i) % mo;
    vf[2 * n] = ksm(f[2 * n], mo - 2);
    fd(i, 2 * n, 1) vf[i - 1] = vf[i] * term(i) % mo;
    ll inv_v = ksm(ksm(v, mo - 2), n);  // v^{-n}
    V p(n + 1), q(2 * n + 1);
    fo(i, 0, n) {
        p[i] = nf[i] * nf[n - i] % mo * ((n - i) % 2 ? -1 : 1) * inv_v % mo * a[i] % mo;
    }
    // 1 / term(i) = f[i-1] * vf[i]: no per-element exponentiation needed.
    fo(i, 0, 2 * n) {
        q[i] = i > 0 ? f[i - 1] * vf[i] % mo : vf[0];
    }
    p = p * q;
    V w(n + 1);
    fo(i, 0, n) {
        // prod_{t=i}^{i+n} term(t) = f[i+n] / f[i-1].
        w[i] = f[i + n];
        if(i > 0) w[i] = w[i] * vf[i - 1] % mo;
        w[i] = w[i] * p[i + n] % mo;
    }
    return w;
}
}
using sub1 :: func;
#define pvv pair<V, V>
#define fs first
#define se second
// One doubling step. With n = k*d, takes (A_d, B_d) sampled at x = 0, v, ..., n*v,
// where A_d(x) = prod_{i=1}^{d} (x+i)^k and B_d(x) = A_d(x) * sum_{i=1}^{d} (x+i)^{-k},
// and returns (A_{2d}, B_{2d}) sampled at x = 0, v, ..., 2n*v.
pvv ch2(pvv a) {
    int n = a.fs.si - 1;
    // Extends n+1 samples at (0..n)*v to 2n+1 samples at (0..2n)*v.
    auto widen = [&](V vals) {
        V tail = func(vals, (n + 1) * v);
        vals.re(2 * n + 1);
        fo(t, 0, n - 1) vals[n + 1 + t] = tail[t];
        return vals;
    };
    V lowProd = widen(a.fs);
    V highProd = func(lowProd, n / k);  // A_d(x + d)
    V lowSum = widen(a.se);
    V highSum = func(lowSum, n / k);    // B_d(x + d)
    V prod(2 * n + 1), sum(2 * n + 1);
    fo(t, 0, 2 * n) {
        // B_{2d} = B_d(x) A_d(x+d) + B_d(x+d) A_d(x);  A_{2d} = A_d(x) A_d(x+d).
        sum[t] = (lowSum[t] * highProd[t] + highSum[t] * lowProd[t]) % mo;
        prod[t] = lowProd[t] * highProd[t] % mo;
    }
    return pvv(prod, sum);
}
// Appends one factor. With n = k*d, maps (A_d, B_d) sampled at (0..n)*v to
// (A_{d+1}, B_{d+1}) sampled at (0..n+k)*v, using
//   A_{d+1}(x) = A_d(x) * (x+d+1)^k,  B_{d+1}(x) = B_d(x) * (x+d+1)^k + A_d(x).
// The k new sample points are evaluated directly from their definition.
// Point coordinates are formed in 64 bits: (n + j) * v is about k*v*v, which
// can exceed INT_MAX for large inputs.
pvv jia1(pvv a) {
    int n = a.fs.si - 1;
    int m = n / k;  // current d
    V b = a.fs, c = a.se;
    b.re(n + 1 + k);
    c.re(n + 1 + k);
    // Existing points: compute the new factor once and apply it to both.
    fo(i, 0, n) {
        ll w = ksm((ll) i * v + m + 1, k);
        c[i] = (c[i] * w % mo + b[i]) % mo;  // b[i] still holds A_d(i*v)
        b[i] = b[i] * w % mo;
    }
    // New points x = (n+j)*v: s1 = prod of all factors, s2 = sum over i of the
    // product with factor i omitted, built in one pass.
    fo(j, 1, k) {
        ll s1 = 1, s2 = 0;
        ll x = (ll) (n + j) * v;
        fo(i, 1, m + 1) {
            ll w = ksm(x + i, k);
            s2 = (s2 * w + s1) % mo;
            s1 = s1 * w % mo;
        }
        b[n + j] = s1;
        c[n + j] = s2;
    }
    return pvv(b, c);
}
// Returns (A_n, B_n) sampled at x = 0, v, ..., k*n*v, where
// A_n(x) = prod_{i=1}^{n} (x+i)^k and B_n(x) = A_n(x) * sum_{i=1}^{n} (x+i)^{-k}.
// Walks the binary digits of n from the most significant one: start from
// n = 1, then for each following digit double with ch2 and, when the digit
// is 1, append one factor with jia1. Expects n >= 1.
pvv solve(int n) {
    V base(k + 1), ones(k + 1);
    fo(i, 0, k) {
        base[i] = ksm(i * v + 1, k);
        ones[i] = 1;
    }
    pvv cur(base, ones);
    int top = 1;
    while(n / top >= 2) top *= 2;
    for(int bit = top / 2; bit > 0; bit /= 2) {
        cur = ch2(cur);
        if(n & bit) cur = jia1(cur);
    }
    return cur;
}
const int N = 1e6 + 5;
// p: prefix products, q: suffix products of the per-block values (used by calc).
ll p[N], q[N];
// Returns sum_{t=1}^{n} t^{-k} mod mo (possibly as a negative residue) for
// the current global n, which is expected to be a positive multiple of v.
// Block i (0 <= i <= m) covers t in [i*v + 1, (i+1)*v]. solve(v) supplies,
// at x = i*v, A_i = prod_{j=1}^{v} (x+j)^k and B_i = A_i * sum_j (x+j)^{-k},
// so the answer is (sum_i B_i * prod_{l != i} A_l) / prod_l A_l.
ll calc() {
    pvv a = solve(v);
    int m = n / v - 1;
    p[0] = a.fs[0];
    fo(i, 1, m) p[i] = p[i - 1] * a.fs[i] % mo;
    q[m + 1] = 1;
    fd(i, m, 0) q[i] = q[i + 1] * a.fs[i] % mo;
    ll s = 0;
    fo(i, 0, m) {
        // Product of every block value except the i-th.
        ll others = (i > 0 ? p[i - 1] : 1) * q[i + 1] % mo;
        s = (s + others * a.se[i]) % mo;
    }
    // q[0] is the product of all block values: one inverse instead of m+1.
    return s * ksm(q[0], mo - 2) % mo;
}
int main() {
freopen("minusk.in", "r", stdin);
freopen("minusk.out", "w", stdout);
mtt :: build();
scanf("%lld %lld %lld", &n, &k, &mo);
if(n <= 1e6) {
ll ans = 0;
fo(i, 1, n) ans = (ans + ksm(ksm(i, k), mo - 2)) % mo;
pp("%lld
", ans);
return 0;
}
v = max(1, (int) ceil(sqrt((double) n / k)));
ll ans = 0;
while(n > 0 && n % v != 0) {
ans = (ans + ksm(ksm(n, k), mo - 2)) % mo;
n --;
}
ans = (ans + calc() + mo) % mo;
pp("%lld
", ans);
}