【BZOJ4803】逆欧拉函数
题面
题解
题目是给定你(varphi(n))要求前(k)小的(n)。
设(n=prod_{i=1}^k{p_i}^{c_i})
则(varphi(n)=prod_{i=1}^k{p_i}^{c_i-1}(p_i-1))
然后我们猜一下这个(n)不是很多,事实上(n)不超过(50w)个。
考虑暴力(dfs)出所有的(n):
首先筛出(sqrt{varphi(n)})内的素数
对于当前(dfs)的值(phi)
看(phi)中的约数有没有(筛出的素数-1)
若有,假设该素数为(p)
去除(phi)中的所有(p),之后再将(dfs)的(n)累乘上(p)
在每一次递归开头用(miller)_(Rabin)判断(phi+1)是否为素数,如果是,则直接加进答案就行了
想一想,为什么?
代码
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <ctime>
using namespace std;
typedef long long ll;
const int MAX_N = 1e7 + 5;
const int T = 10;
bool is_prime[MAX_N];
int prime[MAX_N], num, K;
ll N = 1e7, ans[MAX_N], cnt_ans;
void sieve() {
for (int i = 1; i <= N; i++) is_prime[i] = 1;
is_prime[1] = 0;
for (int i = 2; i <= N; i++) {
if (is_prime[i]) prime[++num] = i;
for (int j = 1; prime[j] * i <= N && j <= num; j++) {
is_prime[i * prime[j]] = 0;
if (!(i % prime[j])) break;
}
}
}
ll fmul(ll x, ll y, ll Mod) {
ll res = 0;
while (y) {
if (y & 1ll) res = (res + x) % Mod;
y >>= 1ll;
x = (x + x) % Mod;
}
return res;
}
ll fpow(ll x, ll y, ll Mod) {
ll res = 1;
while (y) {
if (y & 1ll) res = fmul(res, x, Mod);
y >>= 1ll;
x = fmul(x, x, Mod);
}
return res;
}
bool Test(ll a, ll n) {
ll r = 0, t = n - 1, m;
while ((t & 1ll) == 0) ++r, t >>= 1ll;
m = (n - 1) / (1ll << r);
for (int i = 0; i < r; i++) if (fpow(a, (1ll << i) * m, n) == n - 1) return 1;
if (fpow(a, m, n) == 1) return 1;
return 0;
}
bool Miller_Rabin(ll n) {
if (n == 2ll) return 1;
if (n < 2ll || ((n & 1ll) == 0)) return 0;
for (int i = 1; i <= T; i++) {
ll a = rand() % (n - 2) + 2;
if (fpow(a, n - 1, n) != 1) return 0;
if (!Test(a, n)) return 0;
}
return 1;
}
void solve(ll phi, ll n, int lst) {
if (phi + 1 > prime[num] && Miller_Rabin(phi + 1))
ans[++cnt_ans] = n * (phi + 1);
for (int i = lst; i; i--) {
if (!(phi % (prime[i] - 1))) {
ll t1 = phi / (prime[i] - 1), t2 = n, t3 = 1ll;
while (!(t1 % t3)) {
t2 *= prime[i];
solve(t1 / t3, t2, i - 1);
t3 *= prime[i];
}
}
}
if (phi == 1ll) ans[++cnt_ans] = n;
}
int main () {
srand(time(NULL));
sieve();
cin >> N >> K;
solve(N, 1ll, num);
sort(&ans[1], &ans[cnt_ans + 1]);
for (int i = 1; i < K; i++) printf("%lld ", ans[i]);
printf("%lld
", ans[K]);
return 0;
}