题意
首先一波莫比乌斯反演可得(懒得写latex了):
设$\operatorname{sum}(x)=\sum\limits_{i=1}^{x}i$
$ans=\sum\limits_{T=1}^{n}\operatorname{sum}\left(\left\lfloor\frac{n}{T}\right\rfloor\right)^2T^2\sum\limits_{d|T}d\,\mu\left(\frac{T}{d}\right)$
有$id*\mu=\varphi$
证明:
由$id=\varphi*1,\ 1*\mu=\epsilon$可得:
$id*\mu=\varphi*1*\mu=\varphi*\epsilon=\varphi$
知道这个后:
$ans=\sum\limits_{T=1}^{n}\operatorname{sum}\left(\left\lfloor\frac{n}{T}\right\rfloor\right)^2T^2\varphi(T)$
显然前面可以除法分块,考虑怎么求$T^2\varphi(T)$的前缀和
设$f(i)=i^2\varphi(i)$,发现这是个积性函数,考虑杜教筛:
设$S(n)=\sum\limits_{i=1}^{n}f(i)$
先上杜教筛套路式子:
$g(1)S(n)=\sum\limits_{i=1}^{n}(f*g)(i)-\sum\limits_{i=2}^{n}g(i)S\left(\left\lfloor\frac{n}{i}\right\rfloor\right)$
考虑找合适的$g$:
$(f*g)(n)=\sum\limits_{d|n}g(d)f\left(\frac{n}{d}\right)$
$=\sum\limits_{d|n}g(d)\cdot\frac{n^2}{d^2}\varphi\left(\frac{n}{d}\right)$
由于$\sum\limits_{d|n}\varphi(d)=n$,所以令$g(x)=x^2$
$(f*g)(n)=\sum\limits_{d|n}d^2\cdot\frac{n^2}{d^2}\varphi\left(\frac{n}{d}\right)$
$=\sum\limits_{d|n}n^2\varphi\left(\frac{n}{d}\right)$
$=n^2\sum\limits_{d|n}\varphi\left(\frac{n}{d}\right)$
$=n^3$
因此当$g(x)=x^2$时,$(f*g)(x)=x^3$。
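于是先杜教筛再除法分块即可,其中用到$\sum\limits_{i=1}^{n}i^3=\left(\frac{n(n+1)}{2}\right)^2$和$\sum\limits_{i=1}^{n}i^2=\frac{n(n+1)(2n+1)}{6}$。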
code:
#include<bits/stdc++.h>
using namespace std;
// 64-bit integer shorthand used throughout the file.
typedef long long ll;
// Capacity of the sieve tables below (indices 0..8000000 plus slack).
const int maxn=8*1e6+10;
// n, mod: inputs read in main (mod is read first).
// inv2, inv6: set in main to power(2,mod-2,mod) and power(6,mod-2,mod).
// ans: running answer accumulated in main, kept reduced modulo mod.
ll n,mod,inv2,inv6,ans;
// phi[i]: Euler's totient of i reduced modulo mod (filled by pre_work).
// sum[i]: prefix sum of phi[k]*k*k modulo mod for k=1..i (filled by pre_work).
ll phi[maxn],sum[maxn];
// vis[i]: true once i is known not to be prime (1 and composites).
bool vis[maxn];
// Primes discovered by the sieve, in increasing order.
vector<int>prime;
// Memo table for getsum on arguments beyond the precomputed range.
unordered_map<ll,ll>mp;
// Returns x squared, reduced modulo the global `mod`.
// The product is formed before reduction, so x should already be a residue.
inline ll sqr(ll x)
{
    const ll squared=x*x;
    return squared%mod;
}
// Evaluates x*(x+1)*(2x+1)*inv6 modulo the global `mod`, i.e. the closed
// form of 1^2+2^2+...+x^2 when inv6 is the modular inverse of 6.
// x is reduced first so every intermediate product stays small.
inline ll calc1(ll x)
{
    const ll r=x%mod;
    ll acc=r*(r+1)%mod;
    acc=acc*(2*r+1)%mod;
    acc=acc*inv6%mod;
    return acc;
}
// Evaluates x*(x+1)*inv2 modulo the global `mod`, i.e. the closed form of
// 1+2+...+x when inv2 is the modular inverse of 2.
inline ll calc2(ll x)
{
    const ll r=x%mod;
    const ll pairProduct=r*(r+1)%mod;
    return pairProduct*inv2%mod;
}
// Modular exponentiation by repeated squaring: returns x^k modulo `mod`
// as a value in [0, mod).
//
// x   - base; any 64-bit value (it is reduced into [0, mod) up front so the
//       first multiplication cannot overflow for large or negative bases).
// k   - exponent, expected to be non-negative; a negative k is treated like 0
//       rather than looping forever on the arithmetic right shift.
// mod - modulus, must be positive (this parameter shadows the global `mod`).
//
// `long long` is the same type the file aliases as `ll`.
inline long long power(long long x,long long k,long long mod)
{
    x%=mod;
    if(x<0)x+=mod;
    // 1%mod instead of 1 so that the result is 0 when mod==1.
    long long res=1%mod;
    while(k>0)
    {
        if(k&1)res=res*x%mod;
        x=x*x%mod;
        k>>=1;
    }
    return res;
}
inline void pre_work(int n)
{
vis[1]=1;phi[1]=1;
for(int i=2;i<=n;i++)
{
if(!vis[i])prime.push_back(i),phi[i]=(i-1)%mod;
for(unsigned int j=0;j<prime.size()&&i*prime[j]<=n;j++)
{
vis[i*prime[j]]=1;
if(i%prime[j]==0)
{
phi[i*prime[j]]=1ll*phi[i]*prime[j]%mod;
break;
}
phi[i*prime[j]]=1ll*phi[i]*phi[prime[j]]%mod;
}
}
for(int i=1;i<=n;i++)sum[i]=(sum[i-1]+1ll*phi[i]*i%mod*i%mod)%mod;
}
// Returns S(x) = sum of i*i*phi(i) for i=1..x, modulo the global `mod`.
// Arguments up to 8000000 are answered from the precomputed `sum` table;
// larger ones use the Dujiao-sieve recurrence with g(i)=i^2:
//   S(x) = sum_{i<=x} i^3 - sum_{i=2..x} i^2 * S(x/i),
// evaluated over blocks of equal quotient x/i and memoised in `mp`.
inline ll getsum(ll x)
{
    if(x<=8000000)return sum[x];
    const auto cached=mp.find(x);
    if(cached!=mp.end())return cached->second;
    // sum of cubes 1..x equals (x(x+1)/2)^2.
    ll acc=sqr(calc2(x));
    ll lo=2;
    while(lo<=x)
    {
        const ll quotient=x/lo;
        // Last index sharing the same quotient as lo.
        const ll hi=x/quotient;
        // Sum of squares over [lo,hi]; may be negative before the final fix-up.
        const ll weight=(calc1(hi)-calc1(lo-1))%mod;
        acc=(acc-weight*getsum(quotient)%mod)%mod;
        lo=hi+1;
    }
    const ll result=(acc+mod)%mod;
    mp[x]=result;
    return result;
}
// Reads the modulus and n (in that order) from stdin and prints
//   sum_{T=1..n} (1+2+...+floor(n/T))^2 * T^2 * phi(T)   (mod `mod`)
// using blocks of equal floor(n/T) and getsum for the prefix sums of
// T^2*phi(T).
// The inverses of 2 and 6 are obtained as a^(mod-2), which is only a valid
// inverse when mod is a prime larger than 3.
// Returns 1 without printing a result when the input is malformed or the
// modulus is smaller than 2.
int main()
{
    // Without this check a failed read leaves mod at 0 and every `% mod`
    // below would divide by zero; mod==1 would make the exponent mod-2 negative.
    if(scanf("%lld%lld",&mod,&n)!=2||mod<2)
    {
        fprintf(stderr,"invalid input: expected a modulus (>=2) followed by n\n");
        return 1;
    }
    inv2=power(2,mod-2,mod),inv6=power(6,mod-2,mod);
    pre_work(8000000);
    for(ll l=1,r;l<=n;l=r+1)
    {
        // r is the last T with the same floor(n/T) as l.
        r=n/(n/l);
        ans=(ans+sqr(calc2(n/l))*(getsum(r)-getsum(l-1))%mod+mod)%mod;
    }
    printf("%lld",(ans+mod)%mod);
    return 0;
}