parent 树与 SAM 构造算法概述
设母串为 S, 标号为 0...n-1。S[l,r] 表示一个子串 S[l]...S[r]。
考虑一类集合 A, 其包含了所有满足长度是某个常数,且结尾位置集合相同的子串(子串 S[l...r] 的结尾位置是 r)。设 len(A) 是其中子串的长度。
考虑一个这种类别的集合 E, 它只包含了长度为 0 的子串 —— 空串, 这里将空串的出现位置定义为 [1,0]...[n-1,n-2],[n,n-1]。
考虑对于任意这种集合, 记为 X, 有一个在前面加字符的操作, 即把所有 S[l,r] ∈ X 变成 S[l-1,r](如果可以变的话), 那么这些新的串各自组成几个新的这类集合。
从 E 开始, 不断进行在前面加字符的操作, 直到不可以加为止, 显然在这个过程中 S 的所有子串都被遍历到了。
查看集合的变化轨迹, 可以发现以老集合为其生成的新集合的父亲,是一颗树的结构, 这东西就是常说的 parent 树。(没缩链的)
把往前加字符的操作换成往后加字符的操作,再把链缩一下, 就得到了 S 的后缀树。
流传比较广的 SAM 构造算法, 是 “blumer 的后缀自动机构造算法”, 即是一种通过往前加字符的增量法构造缩链后缀 Trie 和节点之间往前加字符的转移的算法。
所以从前往后用这个算法增量, 得到的是反串的后缀 Trie 及正串的往后加字符的转移(这同时也是反串的往前加字符的转移)。
blumer 后缀自动机构造算法
考虑一颗串 S 的后缀 Trie。
对一个串 s 定义 node(s) 为 s 对应的 Trie 中的节点,即 Trie(s) 转移到的节点,若没有,定义为 null。
对一个串 s 和一个字符 x 定义一种在前面加字符的操作, trans(node(s),x) = node(xs)。
考虑在 S 前面加一个字符 c, 要得到 cS 的后缀 Trie 及 trans 函数。对于加入节点的部分,只需要对 S 的所有前缀对应的节点 v 处理 trans(v, c) 即可。对于 trans 函数的部分, 对于所有新加入的点 v(原来没有的点), 显然对于所有的 x 都有 trans(v,x) = null。
注意到若 trans(node(sy),x) ≠ null, 那么有 trans(node(s),x) ≠ null, 所以需要更新 trans 函数的必然是 S 到某个 S 的前缀 p 对应的所有点。
以下是广为流传的 SAM 实现,加了适当的注释。
struct sam_node {
int t[26], f, len;
} a[...];
// 对于后缀树中的节点, 没有保留在后缀树中的转移边, 只保留了指向后缀树
// 中祖先的指针 f
// f 是往前加字符的转移
// len 是节点的最大深度(因为节点都是缩链的)
int tn=1, las=1; // 1 是根
void extnd(int c)
{
int p = las, np = ++tn; las = np;
a[np].len = a[p].len + 1;
for (; p && !a[p].t[c]; p = a[p].f) a[p].t[c] = np;
if (!p) a[np].f = 1; // 直接从根处分叉
else {
int v = a[p].t[c];
if (a[v].len == a[p].len + 1) a[np].f = v;// 直接覆盖了 v 这条链
else {// 在 v 的某处分叉了
int nv = ++tn; a[nv] = a[v];
a[nv].len = a[p].len + 1;
for (; p && a[p].t[c] == v; p = a[p].f) a[p].t[c] = nv;
a[np].f = a[v].f = nv;
}
}
}
洛谷模板题
考虑这个子串所在的节点, 由于一个节点里结束集合都相等, 所以这个子串必然是结束集合里最长的子串。
所以递推出所有节点的结束集合大小, 再枚举节点就行了。
#include<bits/stdc++.h>
using namespace std;
const int N = 1e6;
char S[N + 23];
int sa[N*2+23], c[N*2+3], siz[N*2+3];
struct sam_node {
int t[26], f, len;
} a[N*2 + 23];
int tot=1, las=1;
void extnd(char c) {
int np = ++tot, p = las; las = np;
a[np].len = a[p].len + 1;
for (; p && !a[p].t[c]; p = a[p].f) a[p].t[c] = np;
if(!p) a[np].f = 1;
else {
int v = a[p].t[c];
if(a[v].len == a[p].len + 1) a[np].f = v;
else {
int nv = ++tot; a[nv] = a[v];
a[nv].len = a[p].len + 1;
for(; p && a[p].t[c] == v; p = a[p].f) a[p].t[c] = nv;
a[np].f = a[v].f = nv;
}
}
siz[np] = 1;
}
int main() {
long long ans = 0ll;
scanf("%s", S);
int n = strlen(S);
for(int i = 0; i < n; ++i) extnd(S[i]-'a');
for(int i = 1; i <= tot; ++i) ++c[a[i].len];
for(int i = 1; i <= n; ++i) c[i] += c[i - 1];
for(int i = tot; i >= 1; --i) sa[c[a[i].len]--] = i;
for(int i = tot; i >= 1; --i) {
int p = sa[i];
siz[a[p].f] += siz[p];
if (siz[p] > 1) ans = max(ans, 1ll * a[p].len * siz[p]);
}
cout << ans;
return 0;
}