突然发现自己还不会AC自动姬,就去学了一下下。。。
自动姬de用途:
给你n个模式串,1个文本串,求出有多少模式串在文本串中出现过。(详见洛谷P3808)
大体思路:
我先前貌似介绍过一种叫做KMP的神奇东西,还有叫做trie树的逆天玩意儿,嗯,今天要讲的AC自动姬就是两者相结合,诞生的产品啦。
(图片来自百度)
嗯,看到上面的图片,让我们手动过滤掉那些虚线,剩下的就是一颗加入了he,she,his,hers,四个字符串的trie树了。
基础trie树的构建(代码):
void insert(char *s){ int now=0; for (int i=0;s[i];i++){ if (!trie[now][s[i]-'a']) trie[now][s[i]-'a']=++tot; now=trie[now][s[i]-'a']; } end[now]++; }
Fail指针:
那么那些虚线又是什么鬼呢?
其实,它就是自动姬的精髓所在,我们称它为Fail指针(也叫失配指针)。
我觉得失配指针这个名字取的特别形象,因为fail指针的作用就是在文本串与当前串失去匹配后的退路。
什么意思呢?
比如说:文本串在she的h处与模式串失去了匹配(也就是说文本串的下一个字符不是‘e’),也就是说,我们现在被字符串“she”给炒鱿鱼了,那么我们应该怎么做呢,无家可归了吗?
不要慌——既然“she”不要你了,但是我们“his”还是要你的,毕竟你有字符‘h’在,和我们‘his’的前一个字符是相符合的,你可以来我们这里试试看,假如说合适的话,就留下好了。
这样就很容易理解了吧!
我们fail指针的作用就是为了找到这样一条“后路”。
那么假如说我们真的沦落到了没人要的地步,也不用担心,我们“家里蹲”(家,是最温暖的港湾)还是会收留你的,嗯,没错,真的没人要的话,只要把fail指针指向根节点就行了。
我们可以采用BFS的方式来求出这个fail指针。
求出当前节点的fail指针的条件是:我们已经求出了之前所有节点的fail指针。 这里的“之前”指的就是所有深度小于当前节点的点,所以才要用BFS来实现嘛。
为什么会有这个前置条件呢,换句话说,为什么我们知道了这些就能求出当前节点的fail指针呢?
这还得从fail指针的定义出发——失配后的退路,也就是说,我们要找到当前串的部分后缀与其它模式串的前缀完全相同的节点,fail指针便是指向这个节点的。
所以,当前的fail指针肯定会指向深度小于等于当前节点的节点。
具体我们可以怎么做呢,其实很简单。
因为要和当前串的部分后缀完全相同,说明肯定要和前一个字符为止的部分后缀完全相同嘛,所以我们就先要找到fail[s[i-1]]指向的节点,假如说它有s[i]这个儿子的话,那么fail指针就指向它的这个儿子了,如果没有的话,我们就要顺着这个节点的fail指针接着向上找,一直到找到或者到达根节点为止。
我们会很清楚的感觉到:在顺着节点的fail指针不断往上找的过程中,与当前串的后缀完全相同的前缀的部分的长度是在不断减小的。(貌似很拗口,但是如果你有这种感觉的话,就说明你对AC自动姬的学习已经有点感觉了)
Fail指针的构建(代码):
void build(){ queue <int> q; for (int i=0;i<26;i++) if (trie[0][i]) q.push(trie[0][i]); while (!q.empty()){ int now=q.front(); q.pop(); for (int i=0;i<26;i++) if (trie[now][i]) fail[trie[now][i]]=trie[fail[now]][i],q.push(trie[now][i]); else trie[now][i]=trie[fail[now]][i];//注意了,这行代码非常精髓,把原本我们可能要跳多次的环节,直接省略到了一次,因为这句话在执行一个“路径压缩”的操作,把一些没用的(跳了之后发现没有想要儿子的)跳跃全部都省去了 } }
与文本串的匹配:
有了构建Fail指针时的路径压缩之后,查询操作就显的简单多了,因为路径压缩后,我们把一些前缀与当前串的部分后缀相同的节点都连接到了这个节点下方,比如“she”和“her”,以为she的后缀he是与her的前缀he完全相同的,所以路径压缩操作会给当前节点(当前节点是指通过she到达的节点,也就是5号点)新加入一条边‘r’(图中红色的那条),连到沿着her走到达的节点(也就是8号点)。所以我们就可以实现在自动姬上反复横跳啦。
(哔——————————————————————————————————
这样的话,我们就可以无脑在自动姬上跳来跳去,一直顺着文本串在自动姬上走就行了,然后再走的过程中,我们把fail指针能知道的模式串一路遍历过去,统计他们的出现就好。
查询操作(代码):
int query(char *s){ int now=0,res=0; for (int i=0;s[i];i++){ now=trie[now][s[i]-'a']; for (int j=now;j&&~end[j];j=fail[j]) res+=end[j],end[j]=-1; } return res; }
完整代码:
#include <bits/stdc++.h> using namespace std; const int maxn=1000005; char s[maxn]; int trie[maxn][30],fail[maxn],end[maxn],n,tot; void insert(char *s){ int now=0; for (int i=0;s[i];i++){ if (!trie[now][s[i]-'a']) trie[now][s[i]-'a']=++tot; now=trie[now][s[i]-'a']; } end[now]++; } void build(){ queue <int> q; for (int i=0;i<26;i++) if (trie[0][i]) q.push(trie[0][i]); while (!q.empty()){ int now=q.front(); q.pop(); for (int i=0;i<26;i++) if (trie[now][i]) fail[trie[now][i]]=trie[fail[now]][i],q.push(trie[now][i]); else trie[now][i]=trie[fail[now]][i]; } } int query(char *s){ int now=0,res=0; for (int i=0;s[i];i++){ now=trie[now][s[i]-'a']; for (int j=now;j&&~end[j];j=fail[j]) res+=end[j],end[j]=-1; } return res; } int main(){ scanf("%d",&n); for (int i=1;i<=n;i++) scanf("%s",s),insert(s); build(); scanf("%s",s); printf("%d ",query(s)); return 0; }