题目链接
题意
给定一个长度为 (n) 的数列 (a_1,...,a_n) 与 (q) 个询问 (x_1,...,x_q),对于每个 (x_i) 回答有多少对 ((l,r)) 满足( (1leq lleq rleq n)) 且 (gcd(a_l,a_{l+1},...,a_r)=x_i)
思路
对于固定的右端点 (i),将左端点从右 ((i)) 向左 ((1)) 延伸,(gcd) 值是递减的,且变化次数不超过 (logC) ((C)为数列中最大值)
下面讲述两种方法,第一种效率高一些,而第二种也提供了一些新的见解。
法一:滚动数组 —— 更新分段信息
枚举右端点,将由左端点划分出的 (gcd) 值分段。每次用新加进来的 (a_i) 去与刚刚的若干段再取 (gcd) 并更新分段信息,更新的同时统计数目。
保存与更新分段信息 可用滚动数组实现,统计数目 则显然用map(要注意的一点是:需要用map<int, LL>
,因为数目可能会爆(int))。
法二:二分 + ST表 —— 找gcd值变化位置
参考自 hzwer.
如果说上一种做法是极大程度地利用了 上一次的信息,那么这一种做法就是抓住了 gcd值具有单调性 这个特点。
因此,确定分段位置可以直接采用二分查找,而如何快速地获取某一段的 (gcd) 值呢?就靠 (ST) 表大显身手了。
// 学到两点:
// 1. ST表适用的范围不仅局限于区间极值问题
// 2. 系统自带的log是真的慢...
Code
Ver. 1 : 171ms
#include <bits/stdc++.h>
#define maxn 100010
using namespace std;
typedef long long LL;
int a[maxn];
int gcd(int a, int b) { return b ? gcd(b, a % b) : a; }
struct node { int x, p; };
map<int, LL> mp;
vector<node> v[2];
int main() {
int n;
scanf("%d", &n);
for (int i = 0; i < n; ++i) scanf("%d", &a[i]);
for (int i = 0; i < n; ++i) {
bool me = i & 1,
op = !me;
v[me].clear();
v[me].push_back({a[i], i});
int last = a[i];
for (auto nd : v[op]) {
int temp = gcd(nd.x, a[i]);
if (temp == last) v[me][v[me].size()-1].p = nd.p;
else v[me].push_back({temp, nd.p}), last = temp;
}
int now = i;
for (auto nd : v[me]) {
int pre = nd.p;
mp[nd.x] += now - pre + 1;
now = pre - 1;
}
}
int q, x;
scanf("%d", &q);
while (q--) {
scanf("%d", &x);
printf("%I64d
", mp[x]);
}
return 0;
}
Ver. 2 : 296ms
#include <bits/stdc++.h>
#define maxn 100010
using namespace std;
typedef long long LL;
int gcd[maxn][32], a[maxn], n, Log[maxn], bin[32];
map<int, LL> mp;
int Gcd(int a, int b) { return b ? Gcd(b, a%b) : a; }
void rmqInit() {
Log[0] = -1; bin[0] = 1;
for (int i = 1; i < 20; ++i) bin[i] = bin[i-1] << 1;
for (int i = 1; i <= n; ++i) Log[i] = Log[i>>1] + 1, gcd[i][0] = a[i];
for (int j = 1; bin[j] <= n; ++j) {
for (int i = 1; i + bin[j-1] - 1 <= n; ++i) {
gcd[i][j] = Gcd(gcd[i][j-1], gcd[i + bin[j-1]][j-1]);
}
}
}
int query(int l, int r) {
int k = Log[r-l+1];
return Gcd(gcd[l][k], gcd[r-bin[k]+1][k]);
}
int bi(int i, int l, int r, int x) {
while (r-l>1) {
int mid = l+r >> 1, val = query(i, mid);
if (val >= x) l = mid;
else r = mid - 1;
}
return query(i, r) == x ? r : l;
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
rmqInit();
for (int i = 1; i <= n; ++i) {
int l = i;
while (true) {
if (l == n+1) break;
int val = query(i, l);
int r = bi(i, l, n, val);
mp[val] += r-l+1;
l = r+1;
}
}
int q, x;
scanf("%d", &q);
while (q--) {
scanf("%d", &x);
printf("%I64d
", mp[x]);
}
return 0;
}