Description
求∑∑((n mod i)*(m mod j))其中1<=i<=n,1<=j<=m,i≠j。
Input
第一行两个数n,m。
Output
一个整数表示答案mod 19940417的值
Sample Input
3 4
Sample Output
1
样例说明
答案为(3 mod 1)*(4 mod 2)+(3 mod 1) * (4 mod 3)+(3 mod 1) * (4 mod 4) + (3 mod 2) * (4 mod 1) + (3 mod 2) * (4 mod 3) + (3 mod 2) * (4 mod 4) + (3 mod 3) * (4 mod 1) + (3 mod 3) * (4 mod 2) + (3 mod 3) * (4 mod 4) = 1
数据规模和约定
对于100%的数据n,m<=10^9。
Solution
题目就是求
[∑_{i=1}^n∑_{j=1}^m[i≠j](nspace modspace i)(mspace modspace j)
]
先讨论不考虑i≠j的限制条件的情况
[large
egin{align*}
&sum_{i=1}^nsum_{j=1}^m(nspace modspace i)(mspace modspace j)\
&=sumsum{(n-frac{n}{i}*i)(m-frac{m}{j}*j)}\
&=sum_{i=1}^{n}sum_{j=1}^{m}{nm-frac{n}{i}*i*m-n*frac{m}{j}*j+i*j*frac{n}{i}*frac{m}{j}}\
&=n^2m^2-nm^2sum_{i=1}^{n}{frac{n}{i}*i}-n^2*msum_{j=1}^m{frac{m}{j}*j}+nmsum_{i=1}^{n}{i*frac{n}{i}*}sum_{j=1}^{m}{j*frac{m}{j}}
end{align*}
]
这是一种方法
然而还有更简便的方法
[large
sum{nspace modspace i}*sum{mspace modspace j}
]
直接用余数之和那题的方法求这个就好(不知道余数之和那题怎么写的戳这里)
就不用上面一大堆码起来也麻烦的式子了
对于i==j的情况
[large
egin{align*}
&sum_{i=1}^{k=min(n,m)}{(n-frac{n}{i}*i)(m-frac{m}{i}*i)}[i==j]\
&=sum_{i=1}^{k}{nm-m*frac{n}{i}*i-n*frac{m}{i}*i+i^2*frac{n}{i}*frac{m}{i}}\
&=knm-kmsum_{i=1}^{k}{frac{n}{i}*i}-knsum_{i=1}^{k}{frac{m}{i}*i}+ksum_{i=1}^{k}{i^2}sum_{i=1}^{k}{frac{n}{i}}sum_{i=1}^{k}{frac{m}{i}}
end{align*}
]
利用数论分块(O(sqrt{n}))求出上面两式,将两式相减即可
P.S:(sum_{i=1}^n{i^2}=frac{n*(n+1)*(2n+1)}{6})
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define N 2010
#define mod 19940417
const ll m6 = 3323403;
ll n, m;
ll ans = 0;
ll sum(ll l, ll r) {
return (r - l + 1) * (l + r) / 2 % mod;
}
ll calc(ll k) {
ll ans = k * k % mod;
for(int l = 1, r; l <= k; l = r + 1) {
r = k / (k / l);
ans = ((ans - sum(l, r) * (k / l) % mod) % mod + mod) % mod;
}
return ans;
}
ll cal(ll x) {
return x * (x + 1) % mod * (2 * x + 1) % mod * m6 % mod;
}
ll sum2(ll l, ll r) {
return (cal(r) - cal(l - 1) + mod) % mod;
}
int main() {
scanf("%lld%lld", &n, &m);
if(n > m) swap(n, m);
ans = calc(n) * calc(m) % mod;
ans = ((ans - n * n % mod * m % mod) % mod + mod) % mod;
for(int l = 1, r; l <= n; l = r + 1) {
r = min(n / (n / l), m / (m / l));
ans = (ans + sum(l, r) * ((n/l)*m % mod + (m/l)*n % mod) % mod % mod);
ans = (ans - sum2(l, r) * (n/l) % mod * (m/l) % mod + mod) % mod;
}
printf("%lld
", (ans % mod + mod) % mod);
return 0;
}