AC 自动机(Aho-Corasick Automaton),也是一种字符串匹配算法。主要用于解决多个模式串匹配主串的问题,它的本质是用 Trie + KMP 算法。
原理
与 KMP 算法类似,主要步骤:
-
将所有模式串构建成一棵 Trie 树;
-
在 Trie 上构造所有节点的前缀指针;
-
利用前缀指针对主串进行匹配。
模式串与主串的匹配
与 KMP 算法完全相同。
匹配步骤:
-
如果当前字符匹配(
ch[j][S[i + 1] - 'a'] != 0
),则继续匹配下一个字符(i++, j = ch[j][S[j + 1] - 'a']
); -
如果当前字符失配(
ch[j][S[i + 1] - 'a'] == 0
),则重新对齐(j = nxt[j]
)直到匹配或者找到了 0。
构造节点的前缀指针
构建方式与 KMP 算法类似。
在 KMP 中,如果在模式串的一个位置失配,则需要回到模式串的前面一个位置继续匹配。从位置 $i$ 处失配后回到 $j$ 位置 ,记作 $fail(i) = j$ 。
考虑 $fail(i) = j$ 的条件:串的前 j个字符组成的前缀,是前 个字符组成前缀的后缀。理论依据是,这样可以保证每一时刻已匹配的字符尽量多,避免遗漏。
现在将问题转化为,在一棵 Trie 上,求一个节点 $j$,使得从根到 $j$ 的路径组成的串是从根到 $i$ 的路径组成串的后缀。
如图(图片来自 Menci 的《AC 自动机学习笔记》):
设 $i$ 父节点为 $i'$, 的入边上的字母为 $c$。
一个显然的结论是,如果 $fail(i')$ 有字母 $c$ 的出边,则该出边指向的点即为 $fail(i)$。例如,上图中 $fail(7) = 1, fail(8) = 2$。
如果 $fail(i')$ 没有字母 $c$ 的出边,则沿着失配函数继续向上找,找到 $fail(fail(i'))$ …… 直到找到根为止,如果找不到一个符合条件的节点,则 $fail(i)$ 为根。 例如,上图中 $fail(3) = 0$。
BFS 优化
if (!ch[u][i]) ch[u][i] = ch[nxt[u]][i];
如果当前节点 u
不存在 i
的转移边,则创建对应儿子,并让它的指向更短前缀,即指向该节点前缀指针。
原来判断节点的转移边 i
是否存在,存在就直接赋值,否则缩短前缀继续判断直到匹配(while (v > 0 && !ch[v][i]) v = nxt[v]
)。根据 BFS 遍历特性,浅层的节点已经构建好了前缀指针,如果让节点不存在的转移边 i
直接指向的更短前缀。这样就无需判断,直接赋值即可。(如果节点前缀指针也没有 i
的转移边怎么办。其实前缀指针的 i
的转移边已经在浅层遍历时指向更短前缀,达到了上面 while
语句的效果。)
通过 BFS 构建步骤:
-
初始化所有与根节点相连的转移边(
ch[0][c] = 1
); -
由浅到深遍历每个节点,每到达一个节点遍历它的所有转移边
i
; -
如果节点
u
存在i
的转移边(ch[u][i] != 0
),则该节点转移边进队,且它的儿子ch[u][i]
前缀指针指向该节点前缀指针(nxt[ch[u][i]] = ch[nxt[i]][i]
); -
如果节点
u
不存在i
的转移边(ch[u][i] == 0
),则让节点不存在的转移边i
指向该节点前缀指针(ch[u][i] = ch[nxt[u]][i]
)。
模板
给定 $n$ 个模式串 $s_i$ 和一个文本串 $t$,求有多少个不同的模式串在文本串里出现过。
两个模式串不同当且仅当他们编号不同。
const int MAXN = 1000005;
int book[MAXN];
int ch[MAXN][30];
int nxt[MAXN], tot;
int que[MAXN], l, r;
char p[MAXN], s[MAXN];
void init() {
tot = 0, l = 1, r = 1;
memset(ch, 0, sizeof(ch));
memset(nxt, 0, sizeof(nxt));
memset(book, 0, sizeof(book));
memset(que, 0, sizeof(que));
}
void insert(char *s) {
int u = 0;
int len = strlen(s);
for (int i = 0; i < len; i++) {
int c = s[i] - 'a';
if (!ch[u][c]) ch[u][c] = ++tot;
u = ch[u][c];
}
book[u]++;
}
void build(){
for (int i = 0; i < 26; i++){
if (ch[0][i]) {
nxt[ch[0][i]] = 0;
que[r++] = ch[0][i];
}
}
while (l < r) {
int u = que[l++];
for (int i = 0; i < 26; i++) {
if(!ch[u][i]) ch[u][i] = ch[nxt[u]][i];
else {
que[r++] = ch[u][i];
nxt[ch[u][i]] = ch[nxt[u]][i];
}
}
}
}
int query(char *s) {
int res = 0;
int len = strlen(s), u = 0;
for(int i = 0; i < len; i++){
u = ch[u][s[i] - 'a'];
for(int k = u; k && ~book[k]; k = nxt[k]) {
res += book[k];
book[k] = -1;
}
}
return res;
}
在 query
函数中:
for(int k = u; k && ~book[k]; k = nxt[k]) {
res += book[k];
book[k] = -1;
}
当 k != 0
且 book[k]
未被标记(book[k] != -1
)时执行。
原理是按照 Trie 的方式去匹配。当匹配到一个模式串时,累加它被标记的次数,并缩短后缀,继续判断是否在匹配串集中,直到根节点。每次访问直接把 book[k]
累加到 res
即可,为了避免重复访问,访问过后标记为 -1
。