算法详解
很长时间内都没有能够很理解KMP算法的精髓,尤其是很多书上包括《算法导论》没有把next函数(亦或 π函数)讲解的很透彻。
今天去看了matrix67大牛博客中关于kmp部分的讲解,有点儿醍醐灌顶的感觉,当然也只是理解了一点浅层次罢了。
我尝试着用自己的语言说一下自己的理解,顺便锻炼一下自己渣一般的逻辑组织能力。。。。。。
下面开始正题吧~~~
我们知道单模字符串匹配基本就是三种方法:
一、朴素枚举。最坏时间复杂度O(mn)。
二、Rabin-Karp。需要O(m)的预处理。虽然最坏时间复杂度也是O(mn),但出现最坏情况的几率比朴素法小很多,所以这种方法实际应用还是比较广泛的。
三、Knuth-Morris-Pratt。即KMP算法。O(m)的预处理时间,O(n)的匹配时间,非常高效。
先从朴素枚举法说起吧。枚举法就是从字符串的第一位开始,把每一位都作为开头来与模式串逐位匹配一次,如果匹配失败则开始下一位做开头试,直到找到匹配为止。
i = 1 2 3 4 5 6 ……
A = a a a b a a …
B = a a a b a c b
j = 1 2 3 4 5 6 7
我们知道,要优化一个算法就要知道它哪里做了多余的事情。从上面情况来看,在i == j == 6时当前匹配失败,朴素法会把模式串试着与i == 2做开头匹配,但是从我们前面已经得到的信息已经可以得出,i == 2做开头是不可能匹配成功的,所以朴素枚举在这里就做了无用功(也可以说 i 指针做了无用的回溯)。(而且我们要发现在某种坏的情况下它做了相当多的无用功!比如这种情况:000000000000000000000001,模式串00000000001)
那么我们应该想到一个优化了。抽象一下就是,当我们前面已经匹配A[i-j..i-1] == B[1..j],而遇到A[i] != B[j+1]时,我们可以快速(O(1)时间内)找到一个新的j,使得新的A[i-j..i-1] == B[1..j],并且A[i] == B[j+1],这样我们的 i 指针就可以不动,紧接着向下匹配就好了。为了能够快速找到这样的j,我们使用一个辅助的next[]数组(π数组),next[j]记录当B[j+1] != A[i]时,新的j的位置。
这样我们的KMP算法的大致代码就是:
string T,P; bool KMP() { bool flag = false; int n = T.length(); int m = P.length(); int j = -1; for (int i = 0; i < n; i ++) { while(j > -1 && B[j+1] != A[i]) j = next[j]; if (B[j+1] == A[i]) j ++; if (j == m - 1) { flag = true; //匹配成功 break; //j = next[j]; } } return flag; }
前面说过,next[]的值是由前面匹配时已经得到的信息得出来的,那么我们发现next[]的值只与模式串有关。(因为前面已经匹配的地方A、B串是一样的,所以只由一个B串是可以得到信息的),这样我们就可以通过对模式串(B串)进行预处理来求出next[]。
那么怎么样根据模式串求出next[]呢?举个例子来说明:
i = 1 2 3 4 5 6 ……
A = a b a b a b …
B = a b a b a c b
j = 1 2 3 4 5 6 7
当i == 6 ,j == 5时A[i] != B[j+1],那么这时我们应该调整j了,手动调整一下可以发现新的j == 3,即next[5] = 3。
i = 1 2 3 4 5 6 ……
A = a b a b a b …
B = a b a b a c b
j = 1 2 3 4 5 6 7
发现没有,新的j'要满足情况,则需要满足B[1..j‘]与B[j-j'+1..j]匹配。
那么我们可以得出一个朴素的O(m2)的求next[]的方法了。对于每一个j,枚举j',找出符合条件的最大的j'就行了。但是明显这个算法是不够优的(如果m很大怎么办?)
实际上上面一句话已经启示我们到底该怎么求next[]了,“满足B[1..j']与B[j-j'+1..j]匹配”,这就是一个模式串自身匹配的过程啊!
我们看个求next[](这里是π[])的图:
我们看看next[j]可不可以由next[1..j]的信息得出。假设当前要求next[6]。由j == next[5] == 3可得到,B[1..3] == B[3..5],而此时B[4] != B[6],所以此时匹配失败,需要回溯重新匹配了。但我们不想回溯啊!前面说了回溯是做无用功啊!那么我们是不是要找一个新的j'使得B[1..j'] == B[5-j'+1..5]。那么怎么找呢?由B[1..3] == B[3..5]可以知道我们可以这么求:B[1..j'] == B[3-j'+1..3](想想为什么可以这样,可以看看上面那个图),那不就是next[j]了么~~~
写一下代码来加深理解:
void GetNext() { next[0] = -1; int j = -1; int m = p.length(); for (int i = 1; i < m; i ++) { while(j > -1 && p[j+1] != p[i]) j = next[j]; if (p[j+1] == p[i]) j++; next[i] = j; } }
这样我们的KMP算法就描述完了~~~
专题训练:
♦POJ 3461 Oulipo(入门题)
光看样例就一脸纯模板题的样儿。。。。。。
#include <iostream> #include <cstdio> #include <string> #include <cstring> using namespace std; const int N = 1000100; string A, B; int pi[N]; void getp() { memset(pi,0,sizeof(pi)); int m = B.length(); pi[0] = -1; int j = -1; for (int i = 1;i < m; i ++) { while(j > -1 &&B[j+1] != B[i]) j = pi[j]; if (B[j+1] == B[i]) j++; pi[i] = j; } } int kmp() { int res = 0; getp(); int n = A.length(); int m = B.length(); int j = -1; for (int i = 0; i < n; i ++) { while(j > -1 && A[i] !=B[j+1]) j = pi[j]; if (A[i] == B[j+1]) j++; if (j == m - 1) { res ++; j = pi[j]; } } return res; } int main() { int tt; scanf("%d",&tt); while (tt--) { cin>>B; cin>>A; cout<<kmp()<<endl; } return 0; }
♦POJ 1226 Substrings (最长公共子串。KMP + 二分)
数据范围很小,随便搞吧?=。=。其实就是二分答案长度,枚举出该长度的每一个串,然后用KMP验证。总复杂度O(n*len2*log(len))吧,可以接受的。
但是我怎么可以这么弱?各种小错误啊调试了2个小时了吧靠靠靠靠这么个水题。。。。。。
#include <cstdio> #include <iostream> #include <string> #include <cstring> using namespace std; char s[110][110]; char p[110]; int pi[110]; void getpi() { int m = strlen(p); int j = -1; pi[0] = -1; for (int i = 1; i < m; i ++) { while(j > -1 && p[j+1] != p[i]) j = pi[j]; if (p[j+1] == p[i]) j++; pi[i] = j; } } bool kmp(int x) { getpi(); int n = strlen(s[x]); int m = strlen(p); int j = -1; for (int i = 0; i < n; i ++) { while(j > -1 && s[x][i] != p[j+1]) j = pi[j]; if (s[x][i] == p[j+1]) j++; if (j == m - 1) { return true; } } return false; } int BS(int n) { if (n == 1) return strlen(s[0]); int h = 0, t = strlen(s[0]) + 1; while(h <= t) { memset(p,0,sizeof(p)); int fg = 1; int mid = (h + t) >> 1; for (int i = 0; i < strlen(s[0]) - mid + 1; i ++) { if (fg == n) break; else fg = 1; for (int k = 1; k < n; k ++) { for (int j = 0; j < mid; j ++) p[j] = s[0][i + j]; if (kmp(k)) fg++; else { for (int j = 0; j < mid; j ++) p[mid - j - 1] = s[0][i + j]; if (kmp(k)) fg++; else break; } } } if (fg == n) h = mid + 1; else t = mid - 1; } return h - 1; } int main() { int t; scanf("%d",&t); while(t--) { memset(p,0,sizeof(p)); int n; scanf("%d",&n); for (int i = 0; i < n; i ++) scanf("%s",s[i]); printf("%d\n",BS(n)); } return 0; }
♦POJ 2406 Power String (最小周期(重复)子串。加深对next函数的理解)
首先要理解next[]表示的是字符串前缀与后缀的重复程度。
然后记住这个结论:对于一个字符串s,如果len是(len - next[len])的倍数,那么len - next[len]就是s的最小周期子串。
证明一下(解释不太清楚>.<……):如果len是len-next[len]的倍数,假设m = len-next[len] ,那么str[1-m] = str[m-2*m],……,以此类推下去,m肯定是str的最小重复单元的长度。假如len不是len-next[len]的倍数, 如果前缀和后缀重叠,那么最小重复单元肯定str本身了,如果前缀和后缀不重叠,那么str[m-2*m] != str[len-m,len],所以str[1-m] != str[m-2*m] ,最终肯定可以推理出最小重复单元是str本身,因为只要不断递增m证明即可。
还是自己在纸上好好推演一下比较好。
#include <cstdio> #include <cstring> #include <string> using namespace std; const int N = 1000010; int pi[N]; char p[N]; int getpi() { int m = strlen(p); pi[0] = -1; int j = -1; for (int i = 1; i < m; i ++) { while(j > -1 && p[j+1] != p[i]) j = pi[j]; if (p[j+1] == p[i]) j++; pi[i] = j; } int x = m - 1 - pi[m - 1]; if (m % x == 0) return x; else return m; } int main() { while(scanf("%s",p)!=EOF) { if (p[0] == '.') break; int l = strlen(p); printf("%d\n",l / getpi()); } return 0; }
♦HDU 3336 Count the String (KMP+DP)
问题抽象:求所有前缀在字符串中出现的次数。
暴力枚举会达到O(n3)是不行的,枚举前缀然后KMP也会达到O(n2)。当然对于前缀的情况我们应该利用好“前缀数组”------next数组。实际上现在很多题也都不是考KMP而是考next数组的灵活运用。(KMP裸模板题有什么好考的。。。)
我们把问题分成几个子问题来看,用DP解决:f[i]表示以第i位为结尾的字符串匹配数。则sum = ∑f[i] 。
怎么利用next数组呢?我们知道next[i] = j表示串B[1...j] == B[i-j+1...i],那么一部分串(B[i-j+1...i]的后缀串)与前缀的匹配是可以通过j来求出来的,因为相等关系,所以这部分f[i]等价于f[next[i]]。这只是一部分以i结尾的啊,那么以[1...i-j]某处开头、以 i 结尾的串有没有可能呢?答案是不可能的,如果与前缀匹配成功那么next[i]就不是j了(想想是不是~),当然要加上他本身(B[1..i])是整个串前缀的情况,所以得出f[i] = f[next[i]] + 1。然后再算出sum就行了~~~
#include <cstdio> #include <iostream> #include <string> #include <cstring> using namespace std; const int N = 200010; string s; int f[N],pi[N]; void getpi() { int m = s.length(); pi[0] = -1; int j = -1; for (int i = 1; i < m; i ++) { while(j > -1 && s[j+1] != s[i]) j = pi[j]; if (s[j+1] == s[i]) j++; pi[i] = j; } } int ff() { getpi(); int sum = 1; int m = s.length(); f[0] = 1; for (int i = 1; i < m; i ++) { f[i] = f[pi[i]] + 1; sum += f[i]; sum %= 10007; } return sum; } int main() { int t; scanf("%d",&t); while(t--) { int n; scanf("%d",&n); cin>>s; printf("%d\n",ff()%10007); } return 0; }
(未完待续。。。)