CLYZ 学长学姐们留下来的题,感谢 + 膜拜。
Description
题目背景
本题中合法括号串的定义如下:
()
是合法括号串;- 如果
A
是合法括号串,则(A)
是合法括号串。 - 如果
A
,B
是合法括号串,则AB
是合法括号串。
本题中子串与不同的子串的定义如下:
- 字符串 \(S\) 的子串是 \(S\) 中连续的任意个字符组成的字符串。\(S\) 的子串可用起始位置 \(l\) 与终止位置 \(r\) 来表示,记为 \(S (l, r)\)(\(1 \le l \le r \le |S|\),\(|S|\) 表示 \(S\) 的长度)。
- \(S\) 的两个子串视作不同当且仅当它们本质不同。
题目描述
一个大小为 \(n\) 的树包含 \(n\) 个结点和 \(n - 1\) 条边,每条边连接两个结点,且任意两个结点间有且仅有一条简单路径互相可达。
小 Q 是一个充满好奇心的小朋友,有一天他在上学的路上碰见了一个大小为 \(n\) 的树,树上结点从 \(1 \sim n\) 编号,\(1\) 号结点为树的根。除 \(1\) 号结点外,每个结点有一个父亲结点,\(u\)(\(2 \le u \le n\))号结点的父亲为 \(f_u\)(\(1 \le f_u < u\))号结点。
小 Q 发现这个树的每个结点上恰有一个括号,可能是 (
或 )
。小 Q 定义 \(s_i\) 为:将根结点到 \(i\) 号结点的简单路径上的括号,按结点经过顺序依次排列组成的字符串。
显然 \(s_i\) 是个括号串,但不一定是合法括号串,因此现在小 Q 想对所有的 \(i\)(\(1 \le i \le n\))求出,\(s_i\) 中有多少个互不相同的子串是合法括号串。
这个问题难倒了小 Q,他只好向你求助。设 \(s_i\) 共有 \(k_i\) 个不同子串是合法括号串,你只需要告诉小 Q 所有 \(i \times k_i\) 的异或和,即:
其中 \(\text{xor}\) 是位异或运算。
数据范围:\(1 \leq n \leq 5 \times 10^5\)。
时空限制:\(5000 \ \mathrm{ms} / 512 \ \mathrm{MiB}\)。
Solution
采用增量法。对树进行 DFS,每次计算以 \(u\) 为最低点时的答案。
由于要求本质不同,对于一个串 \(s[u..t_u]\),如果其在 \(s[\mathrm{fa}_u..1]\) 中作为子串出现,则这样的串是不能计入答案的。
考虑求一个深度最小的 \(t_u\),相当于是在 \(\mathrm{fa}_u\) 到 \(1\) 的路径上选一个点 \(v\),使得 \(\mathrm{LCP}(s[u..1], s[v..1])\) 最大。
在树上做一遍 SA,这相当于是在 \(\mathrm{fa}_u\) 到 \(1\) 的路径上,找 \(\mathrm{rk}_u\) 的前驱后继。用 std::set
维护即可。
考虑求在 \(\mathrm{fa}_{t_u}\) 到 \(1\) 的路径上,有多少个点 \(v\) 使得 \(s[u..v]\) 是一个合法括号串。
将左右括号转换为 \(\pm 1\) 做一个前缀和。用倍增找出可行范围后,在 std::vector
上二分维护即可。
时间复杂度 \(\mathcal{O}(n \log n)\)。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <set>
#include <vector>
typedef long long s64;
using std::set;
typedef set<int>::iterator iter;
using std::vector;
template <class T>
inline void read(T &x) {
static char s;
while (s = getchar(), s < '0' || s > '9');
x = s - '0';
while (s = getchar(), s >= '0' && s <= '9') x = x * 10 + s - '0';
}
void tense(int &x, const int &y) {
if (x < y) x = y;
}
void relax(int &x, const int &y) {
if (x > y) x = y;
}
const int N = 500100;
int n;
char s[N];
int anc[19][N];
int tot, head[N], ver[N], Next[N];
void add_edge(int u, int v) {
ver[++ tot] = v; Next[tot] = head[u]; head[u] = tot;
}
int pre[N];
int dep[N];
int g[19][N];
void dfs1(int u) {
pre[u] = pre[anc[0][u]] + (s[u] == '(' ? -1 : 1);
dep[u] = dep[anc[0][u]] + 1;
for (int i = 1; i <= 18; i ++) anc[i][u] = anc[i - 1][anc[i - 1][u]];
g[0][u] = pre[anc[0][u]];
for (int i = 1; i <= 18; i ++) tense(g[i][u] = g[i - 1][u], g[i - 1][anc[i - 1][u]]);
for (int i = head[u]; i; i = Next[i]) {
int v = ver[i];
dfs1(v);
}
}
namespace SA {
int m = 2;
int sa[N], rk[N], height[N];
int cnt[N], id[N], px[N];
int anc_rk[19][N];
int get_LCP(int x, int y) {
int cur = 0;
for (int i = 18; i >= 0; i --) {
if (std::min(dep[x], dep[y]) < (1 << i)) continue;
if (anc_rk[i][x] ^ anc_rk[i][y]) continue;
x = anc[i][x], y = anc[i][y], cur ^= (1 << i);
}
return cur;
}
bool same(int x, int y, int k) {
int p = anc[k][x] ? anc_rk[k][anc[k][x]] : -1;
int q = anc[k][y] ? anc_rk[k][anc[k][y]] : -1;
return anc_rk[k][x] == anc_rk[k][y] && p == q;
}
void build() {
for (int i = 1; i <= n; i ++) rk[i] = (s[i] == '(' ? 1 : 2);
for (int i = 1; i <= n; i ++) cnt[rk[i]] ++;
for (int i = 1; i <= m; i ++) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; i --) sa[cnt[rk[i]] --] = i;
for (int k = 0, p = 0; (1 << k) < n; k ++, m = p) {
for (int i = 0; i <= m; i ++) cnt[i] = 0;
for (int i = 1; i <= n; i ++) cnt[px[i] = rk[anc[k][i]]] ++;
for (int i = 1; i <= m; i ++) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; i --) id[cnt[px[i]] --] = i;
for (int i = 0; i <= m; i ++) cnt[i] = 0;
for (int i = 1; i <= n; i ++) cnt[px[i] = rk[id[i]]] ++;
for (int i = 1; i <= m; i ++) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; i --) sa[cnt[px[i]] --] = id[i];
for (int i = 1; i <= n; i ++) anc_rk[k][i] = rk[i];
p = 0;
for (int i = 1; i <= n; i ++) rk[sa[i]] = same(sa[i - 1], sa[i], k) ? p : ++ p;
}
for (int i = 1; i <= n; i ++) rk[sa[i]] = i;
for (int i = 2; i <= n; i ++) height[i] = get_LCP(sa[i - 1], sa[i]);
}
}
namespace ST {
int logx[N];
int f[19][N];
void build() {
logx[0] = -1;
for (int i = 1; i <= n; i ++) logx[i] = logx[i >> 1] + 1;
for (int i = 1; i <= n; i ++) f[0][i] = SA::height[i];
for (int j = 1; j <= 18; j ++)
for (int i = 1; i <= n - (1 << j) + 1; i ++)
relax(f[j][i] = f[j - 1][i], f[j - 1][i + (1 << (j - 1))]);
}
int query(int l, int r) {
int k = logx[r - l + 1];
return std::min(f[k][l], f[k][r - (1 << k) + 1]);
}
}
set<int> G;
int find_h(int x) {
iter w = G.lower_bound(x);
int cur = 0;
if (w != G.end())
tense(cur, ST::query(x + 1, *w));
if (w != G.begin())
tense(cur, ST::query(*(-- w) + 1, x));
return cur;
}
int find_m(int u) {
int x = u;
for (int i = 18; i >= 0; i --)
if (dep[x] >= (1 << i) && pre[u] >= g[i][x]) x = anc[i][x];
return x;
}
vector<int> pos[N * 2];
s64 ans[N];
void dfs2(int u) {
int l = dep[find_m(u)] + 1, r = dep[u] - find_h(SA::rk[u]);
pos[n + pre[anc[0][u]]].push_back(dep[u]);
G.insert(SA::rk[u]);
ans[u] = ans[anc[0][u]];
if (s[u] == ')') {
if (l <= r) {
vector<int> &V = pos[n + pre[u]];
ans[u] += upper_bound(V.begin(), V.end(), r) - lower_bound(V.begin(), V.end(), l);
}
}
for (int i = head[u]; i; i = Next[i]) {
int v = ver[i];
dfs2(v);
}
pos[n + pre[anc[0][u]]].pop_back();
G.erase(SA::rk[u]);
}
int main() {
read(n);
scanf("%s", s + 1);
for (int i = 2; i <= n; i ++)
read(anc[0][i]), add_edge(anc[0][i], i);
dfs1(1);
SA::build();
ST::build();
dfs2(1);
s64 res = 0;
for (int i = 1; i <= n; i ++) res ^= (ans[i] * i);
printf("%lld\n", res);
return 0;
}