对于已经学了3,4遍kmp的我,竟然感觉AC自动机不太难。
kmp是解决单模式串匹配的问题,就是只能判断一个字符串是否为另一个字符串的子串。AC自动机是解决多模式串匹配的问题,能判断多个字符串是否同是为一个字符串的子串。
先从最暴力的想法入手:首先肯定要把所有模式串建成一棵trie树,然后枚举查询字符串的起始点 i,看从 i 开始的子串是否能出现模式串。这就像kmp的暴力算法:寻找最大的 j,使模式串b[1 ~ j] = b[i - j + 1 ~ j]。然后我们减少无用的寻找,尽量继承上一次寻找的答案来降低复杂度,就是线性的kmp。
AC自动机就是在trie树上进行匹配,所以可以理解为trie树上的kmp。当我们匹配到某一个节点u时,它代表的是模式串中的一些串的前缀b[1 ~ j],同也要满足是查询字符串的一部分a[i - j + 1 ~ i]。这个 j 也可以理解为trie树上这个节点的深度。当失配的时候,我们要找次长的前缀b[1 ~ k](也就是深度比 j 小的点),看他能否等于查询串的a[i - k + 1 ~ i],这个k不仅仅是还是一个模式串b的前缀,可能也在trie树上的其他链上的,但一定是最大的k,满足b[1 ~ k] = a[i - k + 1 ~ i]。
接着考虑具体怎么实现:kmp的状态转移是从i - 1转移过来的,所以在trie树上就是从其父亲节点转移,因此要保证对于u,深度比u小的点的失配指针都已经构造好了。所以可以从根节点开始bfs,对于当前队首的节点u,构造u的所有儿子节点的失配指针:如果失配,就沿着父亲的失配指针往上跳,直到找到一个节点的下一个状态和当前节点一样。
但正是因为构造每一个节点的fail时要挑好多次,所以可能会超时,解决的办法就是构造虚拟节点:对于u的一个不存在的儿子ch[u][i],让ch[u][i] = ch[fail[u]][i],也就是直接指向一个最长的前缀b[1 ~ j],使节点u的下一个字符如果是 i 的话,保证b[1 ~ j]= b[i - j + 1 ~ i]。这样如果v = ch[now][i]存在的话,fail[v] = ch[fail[now]][i]。
有人会问,如果ch[fail[u]][i]也不存在呢?这是不可能的,因为如果ch[fail[u]][i]不存在的话,那么他存的值是ch[ch[fail[u][i]][i]。
查询其实因题而异,比如这个板子。只用统计出现次数,那么跑ac自动机的时候如果这个模板串已经找出来了,就改了他的val,避免重复统计。
具体实现就是像kmp一样,沿着失配边往后跳,沿路统计所有节点的val,并改成-1,证明已经走过了。那么终止条件要么是跳到头了,要么是val[j] = -1。因为如果val[j] = -1,那么fail[j], fail[fail[j]]一定都是-1.
放一个这道题的代码
1 #include<cstdio> 2 #include<iostream> 3 #include<algorithm> 4 #include<cmath> 5 #include<cstring> 6 #include<cstdlib> 7 #include<cctype> 8 #include<stack> 9 #include<queue> 10 #include<vector> 11 using namespace std; 12 #define enter puts("") 13 #define space putchar(' ') 14 #define Mem(a, x) memset(a, x, sizeof(a)) 15 #define rg register 16 typedef long long ll; 17 typedef double db; 18 const int INF = 0x3f3f3f3f; 19 const db eps = 1e-8; 20 const int maxn = 1e6 + 5; 21 inline ll read() 22 { 23 ll ans = 0; 24 char ch = getchar(), las = ' '; 25 while(!isdigit(ch)) las = ch, ch = getchar(); 26 while(isdigit(ch)) ans = ans * 10 + ch - '0', ch = getchar(); 27 if(las == '-') ans = -ans; 28 return ans; 29 } 30 inline void write(ll x) 31 { 32 if(x < 0) putchar('-'), x = -x; 33 if(x >= 10) write(x / 10); 34 putchar(x % 10 + '0'); 35 } 36 37 int n; 38 char s[maxn]; 39 40 int ch[maxn][30], cnt = 0, val[maxn], f[maxn]; 41 42 int getnum(char c) 43 { 44 return c - 'a'; 45 } 46 void insert(char *s) 47 { 48 int m = strlen(s); 49 int now = 0; 50 for(int i = 0; i < m; ++i) 51 { 52 int c = getnum(s[i]); 53 if(!ch[now][c]) ch[now][c] = ++cnt; 54 now = ch[now][c]; 55 } 56 val[now]++; 57 } 58 void build() 59 { 60 queue<int> q; 61 for(int i = 0; i < 26; ++i) 62 if(ch[0][i]) q.push(ch[0][i]); 63 while(!q.empty()) 64 { 65 int now = q.front(); q.pop(); 66 for(int i = 0; i < 26; ++i) 67 { 68 if(ch[now][i]) f[ch[now][i]] = ch[f[now]][i], q.push(ch[now][i]); 69 else ch[now][i] = ch[f[now]][i]; 70 } 71 } 72 } 73 int query(char *s) 74 { 75 int len = strlen(s), now = 0, ans = 0; 76 for(int i = 0; i < len; ++i) 77 { 78 int c = getnum(s[i]); 79 now = ch[now][c]; 80 for(int j = now; j && val[j] != -1; j = f[j]) ans += val[j], val[j] = -1; 81 } 82 return ans; 83 } 84 85 int main() 86 { 87 n = read(); 88 for(int i = 1; i <= n; ++i) 89 { 90 scanf("%s", s); 91 insert(s); 92 } 93 build(); 94 scanf("%s", s); 95 write(query(s)); enter; 96 return 0; 97 }