答案是所有的方案数-三点共线的方案数
(ans=总方案数-横着三点共线或竖着三点贡献的方案数-斜着三点共线的方案数)
只有斜着的要考虑,设(f(n,m))为(N*M)网格的斜着三点共线的方案数
枚举每个曼哈顿距离为((i,j),i,j>0)的点对
那么在点对连线中坐标为整数的点(不包括两个端点)的个数为(gcd(i,j)-1)
简单证明一下..
从点对的一点到另一点的曼哈顿距离是((i,j))
在该直线上的下一个整数点的位置距离起点显然是((frac{i}{gcd(i,j)},frac{j}{gcd(i,j)}))
那么该直线能分成(i/(frac{i}{gcd(i,j)})=gcd(i,j))段,即有(gcd(i+j)+1)的点数,减去端点数就是(gcd(i,j)-1)
那么(f(n,m)=2sumlimits^n_{i=1}sumlimits^m_{j=1}(n-i+1)(m-i+1)(gcd(i,j)-1))
(=2sumlimits^n_{i=1}sumlimits^m_{j=1}(n-i+1)(m-i+1)gcd(i,j)-2sumlimits^n_{i=1}sumlimits^m_{j=1}(n-i+1)(m-i+1))
计算(sumlimits^n_{i=1}sumlimits^m_{j=1}(n-i+1)(m-i+1)gcd(i,j))
枚举(gcd)的值,不妨令(n<m)
(=sumlimits^n_{d=1}d*sumlimits^{n/d}_{i=1}sumlimits^{m/d}_{j=1}(n-i*d+1)(m-j*d+1)[gcd(i,j)=1])
把后面的部分莫比乌斯反演下
(=sumlimits^n_{d=1}d*sumlimits^{n/d}_{i=1}sumlimits^{m/d}_{j=1}(n-i*d+1)(m-j*d+1)sumlimits_{t|gcd(i,j)}mu(t))
再把t提前
(=sumlimits^n_{d=1}sumlimits_{t=1}^{n/d}mu(t)*dsumlimits^{n/(dt)}_{i=1}sumlimits^{m/(dt)}_{j=1}(n-j*d*t+1)(m-i*d*t+1))
令(T=d*t)
(sumlimits_{T=1}^{n}dsumlimits_{d|T}mu(frac{T}{d})sumlimits^{n/T}_{i=1}sumlimits^{m/T}_{j=1}(n-i*T+1)(m-j*T+1))
可以想到(sum_{d|n}frac{mu(d)}{d}=frac{varphi(n)}{n})
那么(sumlimits_{T=1}^{n}sumlimits_{d|T}d*mu(frac{T}{d}))
(=sumlimits_{T=1}^{n}sumlimits_{d|T}frac{Tmu(d)}{d})
(sumlimits^n_{T=1}varphi(T))
得到(sumlimits^n_{T=1}varphi(T)sumlimits^{n/T}_{i=1}sumlimits^{m/T}_{j=1}(n-i*T+1)(m-j*T+1))
后面那个部分就是个等差数列求和后相乘,预处理欧拉函数后对于每个T就能O(1)求解
最后复杂度O(n),这题套路和能量采集基本相同,都是将莫比乌斯函数与欧拉函数相结合
代码
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/trie_policy.hpp>
#include <ext/rope>
// #include <ext/pb_ds/priority_queue.hpp>
using namespace __gnu_pbds;
using namespace std;
// freopen("k.in", "r", stdin);
// freopen("k.out", "w", stdout);
// clock_t c1 = clock();
// std::cerr << "Time:" << clock() - c1 <<"ms" << std::endl;
//#pragma comment(linker, "/STACK:1024000000,1024000000")
mt19937 rnd(time(NULL));
#define de(a) cout << #a << " = " << a << endl
#define rep(i, a, n) for (int i = a; i <= n; i++)
#define per(i, a, n) for (int i = n; i >= a; i--)
#define sqr(x) (x) * (x)
#define ls ((x) << 1)
#define rs ((x) << 1 | 1)
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef pair<double, double> PDD;
typedef pair<char, char> PCC;
typedef pair<ll, ll> PLL;
typedef vector<int> VI;
const double pi = acos(-1.0);
#define inf 0x3f3f3f3f
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll MAXN = 2e6 + 7;
const ll MAXM = 4e5 + 7;
const int MOD = 1e9 + 7;
const double eps = 1e-5;
int prime[MAXN], phi[MAXN], cnt;
bool flag[MAXN];
void Euler(int n)
{
cnt = 0;
memset(flag, 0, sizeof(flag));
phi[0] = 0;
phi[1] = 1;
for (int i = 2; i <= n; i++)
{
if (!flag[i])
{
prime[cnt++] = i;
phi[i] = i - 1;
}
for (int j = 0; j < cnt && 1LL * i * prime[j] <= n; j++)
{
int k = i * prime[j];
flag[k] = 1;
if (i % prime[j] == 0)
{
phi[k] = phi[i] * prime[j];
break;
}
else
phi[k] = phi[i] * (prime[j] - 1);
}
}
}
ll cal(int x) { return 1LL * x * (x - 1) * (x - 2) / 6; }
ll sumA[MAXN], sumB[MAXN];
ll tmp[MAXN];
ll cal2(int a1, int d, int n) { return 1LL * n * a1 + 1LL * n * (n - 1) / 2 * d; }
int main()
{
int n, m;
scanf("%d%d", &n, &m);
if (n > m)
swap(n, m);
Euler(n);
ll ans = cal((n + 1) * (m + 1)) - 1LL * (m + 1) * cal(n + 1) - (n + 1) * cal(m + 1) + 2 * cal2(n, -1, n) * cal2(m, -1, m);
for (int T = 1; T <= n; T++)
ans -= 2 * 1LL * phi[T] * cal2(n - T + 1, -T, n / T) * cal2(m - T + 1, -T, m / T);
printf("%lld
", ans);
return 0;
}