算法详解
很长时间内都没有能够很理解KMP算法的精髓,尤其是很多书上包括《算法导论》没有把next函数(亦或 π函数)讲解的很透彻。
今天去看了
matrix67大牛博客中关于kmp部分的讲解,有点儿醍醐灌顶的感觉,当然也只是理解了一点浅层次罢了。
我尝试着用自己的语言说一下自己的理解,顺便锻炼一下自己渣一般的逻辑组织能力。。。。。。
下面开始正题吧~~~
我们知道单模字符串匹配基本就是三种方法:
一、朴素枚举。最坏时间复杂度O(mn)。
二、Rabin-Karp。需要O(m)的预处理。虽然最坏时间复杂度也是O(mn),但出现最坏情况的几率比朴素法小很多,所以这种方法实际应用还是比较广泛的。
三、
Knuth-Morris-Pratt。即
KMP算法。O(m)的预处理时间,O(n)的匹配时间,非常高效。
先从朴素枚举法说起吧。枚举法就是从字符串的第一位开始,把每一位都作为开头来与模式串逐位匹配一次,如果匹配失败则开始下一位做开头试,直到找到匹配为止。
i = 1 2 3 4
5 6 ……
A = a a a b
a a …
B = a a a b
a c b
j = 1 2 3 4
5 6 7
我们知道,
要优化一个算法就要知道它哪里做了多余的事情。从上面情况来看,在i == j == 6时当前匹配失败,朴素法会把模式串试着与i == 2做开头匹配,但是从我们前面已经得到的信息已经可以得出,i == 2做开头是不可能匹配成功的,所以朴素枚举在这里就做了无用功(也可以说 i 指针做了无用的回溯)。(而且我们要发现在某种坏的情况下它做了相当多的无用功!比如这种情况:000000000000000000000001,模式串00000000001)
那么我们应该想到一个优化了。抽象一下就是,当我们前面已经匹配A[i-j..i-1] == B[1..j],而遇到A[i] != B[j+1]时,我们可以快速(O(1)时间内)找到一个新的j,使得新的A[i-j..i-1] == B[1..j],并且A[i] == B[j+1],这样我们的 i 指针就可以不动,紧接着向下匹配就好了。为了能够快速找到这样的j,我们使用一个辅助的next[]数组(π数组),next[j]记录当B[j+1] != A[i]时,新的j的位置。
这样我们的KMP算法的大致代码就是:
[cpp]
string T,P;
bool KMP()
{
bool flag = false;
int n = T.length();
int m = P.length();
int j = -1;
for (int i = 0; i < n; i ++)
{
while(j > -1 && B[j+1] != A[i]) j = next[j];
if (B[j+1] == A[i]) j ++;
if (j == m - 1)
{
flag = true; //匹配成功
break;
//j = next[j];
}
}
return flag;
}[/cpp]
前面说过,next[]的值是由前面匹配时已经得到的信息得出来的,那么我们发现next[]的值只与模式串有关。(因为前面已经匹配的地方A、B串是一样的,所以只由一个B串是可以得到信息的),这样我们就可以通过对模式串(B串)进行预处理来求出next[]。
那么怎么样根据模式串求出next[]呢?举个例子来说明:
i = 1 2 3 4
5 6 ……
A =
a b a b a b …
B =
a b a b a c b
j = 1 2 3 4 5 6 7
当i == 6 ,j == 5时A[i] != B[j+1],那么这时我们应该调整j了,手动调整一下可以发现新的j == 3,即next[5] = 3。
i = 1 2 3 4 5 6 ……
A = a b
a b a b …
B =
a b a b
a c b
j = 1 2 3 4
5 6 7
发现没有,新的j'要满足情况,则需要满足B[1..j‘]与B[j-j'+1..j]匹配。
那么我们可以得出一个朴素的O(m
2)的求next[]的方法了。对于每一个j,枚举j',找出符合条件的最大的j'就行了。但是明显这个算法是不够优的(如果m很大怎么办?)
实际上上面一句话已经启示我们到底该怎么求next[]了,“满足B[1..j']与B[j-j'+1..j]匹配”,这就是一个模式串自身匹配的过程啊!
我们看个求next[](这里是π[])的图:
我们看看next[j]可不可以由next[1..j]的信息得出。假设当前要求next[6]。由j == next[5] == 3可得到,B[1..3] == B[3..5],而此时B[4] != B[6],所以此时匹配失败,需要回溯重新匹配了。但我们不想回溯啊!前面说了回溯是做无用功啊!那么我们是不是要找一个新的j'使得B[1..j'] == B[5-j'+1..5]。那么怎么找呢?由B[1..3] == B[3..5]可以知道我们可以这么求:B[1..j'] == B[3-j'+1..3](想想为什么可以这样,可以看看上面那个图),那不就是next[j]了么~~~
写一下代码来加深理解:
[cpp]
void GetNext()
{
next[0] = -1;
int j = -1;
int m = p.length();
for (int i = 1; i < m; i ++)
{
while(j > -1 && p[j+1] != p[i]) j = next[j];
if (p[j+1] == p[i]) j++;
next[i] = j;
}
}
[/cpp]
这样我们的KMP算法就描述完了~~~
专题训练:
♦
POJ 3461 Oulipo(入门题)
光看样例就一脸纯模板题的样儿。。。。。。
#include
#include
#include
#include
using namespace std;
const int N = 1000100;
string A, B;
int pi[N];
void getp()
{
memset(pi,0,sizeof(pi));
int m = B.length();
pi[0] = -1;
int j = -1;
for (int i = 1;i < m; i ++)
{
while(j > -1 &&B[j+1] != B[i]) j = pi[j];
if (B[j+1] == B[i]) j++;
pi[i] = j;
}
}
int kmp()
{
int res = 0;
getp();
int n = A.length();
int m = B.length();
int j = -1;
for (int i = 0; i < n; i ++)
{
while(j > -1 && A[i] !=B[j+1]) j = pi[j];
if (A[i] == B[j+1]) j++;
if (j == m - 1)
{
res ++;
j = pi[j];
}
}
return res;
}
int main()
{
int tt;
scanf("%d",&tt);
while (tt--)
{
cin>>B;
cin>>A;
cout<
♦POJ 1226 Substrings (最长公共子串。KMP + 二分)
数据范围很小,随便搞吧?=。=。其实就是二分答案长度,枚举出该长度的每一个串,然后用KMP验证。总复杂度O(n*len2*log(len))吧,可以接受的。
但是我怎么可以这么弱?各种小错误啊调试了2个小时了吧靠靠靠靠这么个水题。。。。。。
#include
#include
#include
#include
using namespace std;
char s[110][110];
char p[110];
int pi[110];
void getpi()
{
int m = strlen(p);
int j = -1;
pi[0] = -1;
for (int i = 1; i < m; i ++)
{
while(j > -1 && p[j+1] != p[i]) j = pi[j];
if (p[j+1] == p[i]) j++;
pi[i] = j;
}
}
bool kmp(int x)
{
getpi();
int n = strlen(s[x]);
int m = strlen(p);
int j = -1;
for (int i = 0; i < n; i ++)
{
while(j > -1 && s[x][i] != p[j+1]) j = pi[j];
if (s[x][i] == p[j+1]) j++;
if (j == m - 1)
{
return true;
}
}
return false;
}
int BS(int n)
{
if (n == 1) return strlen(s[0]);
int h = 0, t = strlen(s[0]) + 1;
while(h <= t)
{
memset(p,0,sizeof(p));
int fg = 1;
int mid = (h + t) >> 1;
for (int i = 0; i < strlen(s[0]) - mid + 1; i ++)
{
if (fg == n) break;
else fg = 1;
for (int k = 1; k < n; k ++)
{
for (int j = 0; j < mid; j ++)
p[j] = s[0][i + j];
if (kmp(k)) fg++;
else
{
for (int j = 0; j < mid; j ++)
p[mid - j - 1] = s[0][i + j];
if (kmp(k)) fg++;
else break;
}
}
}
if (fg == n)
h = mid + 1;
else t = mid - 1;
}
return h - 1;
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
memset(p,0,sizeof(p));
int n;
scanf("%d",&n);
for (int i = 0; i < n; i ++)
scanf("%s",s[i]);
printf("%d\n",BS(n));
}
return 0;
}
♦POJ 2406 Power String (最小周期(重复)子串。加深对next函数的理解)
首先要理解next数组表示的是字符串前缀的“对称”程度。
然后记住这个结论:对于一个字符串s,如果len是(len - next[len])的倍数,那么len - next[len]就是s的最小周期子串。
证明一下(解释不太清楚>.<……):如果len是len-next[len]的倍数,假设m = len-next[len] ,那么str[1-m] = str[m-2*m],……,以此类推下去,m肯定是str的最小重复单元的长度。假如len不是len-next[len]的倍数, 如果前缀和后缀重叠,那么最小重复单元肯定str本身了,如果前缀和后缀不重叠,那么str[m-2*m] != str[len-m,len],所以str[1-m] != str[m-2*m] ,最终肯定可以推理出最小重复单元是str本身,因为只要不断递增m证明即可。
还是自己在纸上好好推演一下比较好。
#include
#include
#include
using namespace std;
const int N = 1000010;
int pi[N];
char p[N];
int getpi()
{
int m = strlen(p);
pi[0] = -1;
int j = -1;
for (int i = 1; i < m; i ++)
{
while(j > -1 && p[j+1] != p[i]) j = pi[j];
if (p[j+1] == p[i]) j++;
pi[i] = j;
}
int x = m - 1 - pi[m - 1];
if (m % x == 0)
return x;
else return m;
}
int main()
{
while(scanf("%s",p)!=EOF)
{
if (p[0] == '.') break;
int l = strlen(p);
printf("%d\n",l / getpi());
}
return 0;
}
♦HDU 3336 Count the String (KMP+DP)
问题抽象:求所有前缀在字符串中出现的次数。
暴力枚举会达到O(n3)是不行的,枚举前缀然后KMP也会达到O(n2)。当然对于前缀的情况我们应该利用好“前缀数组”------next数组。实际上现在很多题也都不是考KMP而是考next数组的灵活运用。(KMP裸模板题有什么好考的。。。)
我们把问题分成几个子问题来看,用DP解决:f[i]表示以第i位为结尾的字符串匹配数。则sum = ∑f[i] 。
怎么利用next数组呢?我们知道next[i] = j表示串B[1...j] == B[i-j+1...i],那么一部分串(B[i-j+1...i]的后缀串)与前缀的匹配是可以通过j来求出来的,因为相等关系,所以这部分f[i]等价于f[next[i]]。这只是一部分以i结尾的啊,那么以[1...i-j]某处开头、以 i 结尾的串有没有可能呢?答案是不可能的,如果与前缀匹配成功那么next[i]就不是j了(想想是不是~),当然要加上他本身(B[1..i])是整个串前缀的情况,所以得出f[i] = f[next[i]] + 1。然后再算出sum就行了~~~
#include
#include
#include
#include
using namespace std;
const int N = 200010;
string s;
int f[N],pi[N];
void getpi()
{
int m = s.length();
pi[0] = -1;
int j = -1;
for (int i = 1; i < m; i ++)
{
while(j > -1 && s[j+1] != s[i]) j = pi[j];
if (s[j+1] == s[i]) j++;
pi[i] = j;
}
}
int ff()
{
getpi();
int sum = 1;
int m = s.length();
f[0] = 1;
for (int i = 1; i < m; i ++)
{
f[i] = f[pi[i]] + 1;
sum += f[i];
sum %= 10007;
}
return sum;
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
int n;
scanf("%d",&n);
cin>>s;
printf("%d\n",ff()%10007);
}
return 0;
}
(未完待续。。。)