https://www.luogu.com.cn/problem/P6156
qndjdt
经肉眼观察可得(f(n)=mu(n)^2)
于是大力推柿子
枚举(T=pd),然后((S(n)=sumlimits_{i=1}^nsumlimits_{j=1}^n(i+j)^K))
令(f(n)=sumlimits_{dmid n} mu(d)^2dmu(n/d))
于是
我们只要能快速求出(S),以及(T^Kf(T))的前缀和,就可以数论分块了。
先来看(S)。考虑递推,不难发现
预处理出(Sk(n)=sumlimits_{i=1}^n i^K),那么
(S(n+1)=S(n)+Sk(2n+2)+Sk(2n+1)-2Sk(n+1))
并不能(n log n)算(Sk),所以欧拉筛出(i^K),然后前缀和得到(Sk),再递推出(S),边界是(S(1)=2^K)。
然后是(f),由一坨积性函数卷起来,(f)也是一个积性函数,所以考虑筛法。
在欧拉筛的过程中,会给(i)配上一个质数(p),得到一个新的数(ip),如果(p otmid i),那么直接(f(ip)=f(p)f(i))。
同时对于质数(p),有(f(p)=mu(1)^2mu(p)+mu(p)^2p=p-1)
如果(p^2mid ip)且(p^3 otmid ip),有
(f(p^2)=mu(1)^2mu(p^2)+mu(p)^2pmu(p)+mu(p^2)p^2mu(1)=-p)
那么(f(ip)=f(p^2)f(i/p)=-pf(i/p))。
如果(p^3 mid ip),那么上面算出来(f(p^k))中的(mu),每一项至少有一个为(0)(含有平方因子),所以(f(p^k)=0 quad (k ge 3)),(f(ip)=0)
于是欧拉筛中分类讨论即可。
算完(f),将每个(f(i))乘上(i^K),做前缀和就好了。
(Code:)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e7 + 10, mod = 998244353;
inline int fpow(int x, int y) {
int ans = 1; for (; y; y>>=1, x = 1ll * x * x % mod)
if (y & 1) ans = 1ll * ans * x % mod;
return ans;
}
int ps[777777], pn, Sk[N], S[N], K, f[N];
bool vis[N];
void init() {
int n = 10000000; Sk[1] = 1, f[1] = 1, S[1] = fpow(2, K);
for (int i = 2; i <= n; i++) {
if (!vis[i]) {
ps[pn++] = i, Sk[i] = fpow(i, K);
f[i] = i - 1;
}
for (int j = 0; j < pn && i * ps[j] <= n; j++) {
vis[i * ps[j]] = 1;
Sk[i * ps[j]] = 1ll * Sk[i] * Sk[ps[j]] % mod;
if (i % ps[j] == 0) {
int tmp = i * ps[j];
if (i / ps[j] % ps[j] != 0) f[tmp] = 1ll * f[i/ps[j]] * (mod - ps[j]) % mod;
else f[tmp] = 0;
break;
}
else f[i * ps[j]] = 1ll * f[i] * (ps[j] - 1) % mod;
}
}
for (int i = 1; i <= n/2; i++) f[i] = 1ll * f[i] * Sk[i] % mod;
for (int i = 1; i <= n/2; i++) {
f[i] += f[i-1];
if (f[i] >= mod) f[i] -= mod;
}
for (int i = 1; i <= n; i++) {
Sk[i] += Sk[i-1];
if (Sk[i] >= mod) Sk[i] -= mod;
}
for (int i = 1; i < n/2; i++) {
S[i+1] = (0ll + S[i] + Sk[2*i+2] + Sk[2*i+1] - 2ll * Sk[i+1] % mod + mod) % mod;
}
}
int main() {
int n, ans = 0; ll tK; scanf("%d%lld", &n, &tK);
K = tK % (mod-1); init();
for (int l = 1, r = 0; l <= n; l = r+1) {
r = n / (n/l);
ans += 1ll * S[n/l] * (f[r] - f[l-1] + mod) % mod;
if (ans >= mod) ans -= mod;
}
printf("%d
", ans);
return 0;
}