AtCoder Regular Contest 099F - Eating Symbols Hard
题目大意
- 给定一个下标从负无穷到正无穷且初始全为
0
0
0的序列和初始在
0
0
0的指针,
N
N
N个支持指针左/右移或当前位置加/减
1
1
1的操作,求操作序列有多少个子区间,使得执行此区间内操作的结果和执行所有操作后的结果相同(最后指针位置可以不同)。
-
N
≤
250000
N≤250000
N≤250000
题解
- 首先可以发现,序列的下标范围再大,超出了
±
N
±N
±N的部分都是不会被修改的,那么只有在
[
−
N
,
N
]
[-N,N]
[−N,N]内才可能变动,也就是说只用考虑这个范围。
- 求区间个数,想到用分治解决,
- 当前区间
[
l
,
r
]
[l,r]
[l,r],特判
l
=
r
l=r
l=r的情况,其他情况再统计左端点在
[
l
,
m
i
d
]
[l,mid]
[l,mid]且右端点在
(
m
i
d
,
r
]
(mid,r]
(mid,r]的方案数,
- 一个序列用
N
N
N进制数表示,
- 先倒着扫一遍左半区间,算出每个以每个位置作为区间起点时,还需要修改的数是多少才能使最终序列满足条件,把它扔进哈希里,
- 接着顺着扫右半区间,在哈希里查找对应的数,统计答案。
- 要注意指针位置变化的操作,也就是说序列压成的数可能需要整个左/右移若干位。
- 用双哈希更保险,单哈希也可以过。
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 250010
#define ll long long
#define mo 998244353
#define mo1 9999973
#define md 9999973
char st[N];
int n, id = 0, ok;
ll f[N * 2], g[N * 2], ff[N * 2], gg[N * 2], a[N * 2], t = 0, tt = 0, ans = 0;
struct {
int i, s; ll S1, S2;
}h[2 * md + 5];
void put(ll S1, ll S2) {
int x = S1 * S2 % (md * 2) + 1;
while(h[x].i == id) {
if(h[x].S1 == S1 && h[x].S2 == S2) {
h[x].s++;
return;
}
x = x % (md * 2) + 1;
}
h[x].i = id, h[x].S1 = S1, h[x].S2 = S2, h[x].s = 1;
}
void find(ll S1, ll S2) {
int x = S1 * S2 % (md * 2) + 1;
while(h[x].i == id) {
if(h[x].S1 == S1 && h[x].S2 == S2) {
ans += h[x].s;
return;
}
x = x % (md * 2) + 1;
}
}
ll ksm(ll x, ll y, ll P) {
if(!y) return 1;
ll l = ksm(x, y / 2, P);
if(y % 2) return l * l % P * x % P;
return l * l % P;
}
ll F(ll x, int p) {
if(p < 0) return x * ff[-p] % mo;
return x * f[p] % mo;
}
ll G(ll x, int p) {
if(p < 0) return x * gg[-p] % mo1;
return x * g[p] % mo1;
}
void solve(int l, int r) {
if(l == r) {
if(ok && a[n] == 0 && st[l] == '<') ans++;
if(ok && a[n] == 0 && st[l] == '>') ans++;
if(ok && a[n] == 1 && st[l] == '+') ans++;
if(ok && a[n] == -1 && st[l] == '-') ans++;
return;
}
int mid = (l + r) / 2;
solve(l, mid), solve(mid + 1, r);
int p = 0; ll ts1 = 0, ts2 = 0;
id++;
for(int i = mid + 1; i <= r; i++) {
if(st[i] == '>') p++;
if(st[i] == '<') p--;
if(st[i] == '+') {
ts1 = (ts1 + f[n - p]) % mo;
ts2 = (ts2 + g[n - p]) % mo1;
}
if(st[i] == '-') {
ts1 = (ts1 - f[n - p] + mo) % mo;
ts2 = (ts2 - g[n - p] + mo1) % mo1;
}
put(ts1, ts2);
}
ts1 = ts2 = p = 0;
for(int i = mid; i >= l; i--) {
if(st[i] == '<') p++;
if(st[i] == '>') p--;
if(st[i] == '+') {
ts1 = (ts1 + f[n - p]) % mo;
ts2 = (ts2 + g[n - p]) % mo1;
}
if(st[i] == '-') {
ts1 = (ts1 - f[n - p] + mo) % mo;
ts2 = (ts2 - g[n - p] + mo1) % mo1;
}
find(F((t - F(ts1, p) + mo) % mo, -p), G((tt - G(ts2, p) + mo1) % mo1, -p));
}
}
int main() {
int i;
scanf("%d
", &n);
scanf("%s", st + 1);
int p = 0;
f[0] = g[0] = 1;
for(i = 1; i <= 2 * n; i++) {
f[i] = f[i - 1] * (2 * n + 1) % mo, g[i] = g[i - 1] * (2 * n + 1) % mo1;
ff[i] = ksm(f[i], mo - 2, mo), gg[i] = ksm(g[i], mo1 - 2, mo1);
}
for(i = 1; i <= n; i++) {
if(st[i] == '+') a[p + n]++;
if(st[i] == '-') a[p + n]--;
if(st[i] == '<') p--;
if(st[i] == '>') p++;
}
t = tt = 0; ok = 1;
for(i = 2 * n; i >= 0; i--) {
t = (t + f[2 * n - i] * a[i] % mo + mo) % mo, tt = (tt + g[2 * n - i] * a[i] % mo1 + mo1) % mo1;
if(i != n && a[i]) ok = 0;
}
solve(1, n);
printf("%lld
", ans);
return 0;
}