在介绍KMP算法之前先引入一个问题:字符串匹配问题
现在有一个文本串 $S$ 和一个模式串 $P$ ,查找P在S中出现的位置(用第一个字母位置表示)
朴素字符串匹配算法
有以下一个非常简单的算法,枚举模式串 $P$ 的第一个字母在文本串 $S$ 中出现位置,然后依次检查模式串 $P$ 每一个字母与文本串 $S$ 对应位置是否匹配
不难发现,这样的算法时间复杂度是 $O(MN)$ 的
KMP算法
KMP算法就是用来解决单字符串匹配问题的一种算法,其时间复杂度为 $O(M+N)$
KMP算法较朴素字符串匹配算法不同的是引入了失配数组 $Next$ 的概念
我们观察朴素字符串匹配算法,不难发现,每一次失配后,模式串仅仅是向后移动一位,然后重新开始从头匹配,这就使得算法十分低效,那么有什么办法可以直接跳过一部分显然失配的字母呢?看一个KMP匹配的例子
此时模式串 $P$ 在检查最后一个字母时失配
KMP算法不会朴素的将模式串 $P$ 向后移动一位
不难发现,在这次失配过程中,文本串对应匹配成功的字母是 $ABCDAB$ ,我们可以直接跳到下一个 $AB$ 处,直接与模式串前2位对应完成,跳过中间必定失配的字母。
这就是KMP算法的主要思想,那我们先来看一下KMP算法流程
假设现在文本串 $S$ 匹配到 $i$ 位置,模式串 $P$ 匹配到 $j$ 位置
如果 $j == -1$ ,或者当前字符匹配成功(即 $S[i] == P[j]$ ),都令 $i++$ , $j++$ ,继续匹配下一个字符
如果 $j \neq -1$ ,且当前字符匹配失败(即 $S[i] \neq P[j]$ ),则令 $i$ 不变,$j = next[j]$。此举意味着失配时,模式串$P$ 相对于文本串S向右移动了 $j – next [j]$ 位。
现在关键问题转变为如何求出 $next$ 数组,实际上这是由模式串 $P$ 的最长前缀后缀元素决定的
寻找最长前缀后缀
此时,如果根据最长前缀后缀表进行失配跳转,则失配时,模式串向右移动的位数为:
由最长前缀后缀得到 $next$ 数组
但是有一个很不优美的地方,减去的是上一位字符对应的最大长度,那么Next数组不就是最长前缀后缀表元素向右平移一个单位,第一个值令为-1吗
下面介绍具体实现过程
对于模式串 $P$ 的前 $j+1$ 个序列字符:
若 $p[k] == p[j]$ ,则 $next[j+1] = next[j] + 1 = k + 1 $
若 $p[k] \neq p[j]$ ,如果此时 $p[next[k]] == p[j]$ ,则 $next[j+1] = next[k] + 1$ ,否则继续递归前缀索引$k = next[k]$,而后重复此过程。
代码
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<vector>
#define LL long long
using namespace std;
inline int read(){
int x = 0;
int p = 1;
char ch = getchar();
while(ch < '0' || ch > '9'){
if(ch == '-')
p = -1;
ch = getchar();
}
while('0' <= ch && ch <= '9'){
x = x*10 + ch - '0';
ch = getchar();
}
return x*p;
}
const int MAXN = 1000000 + 10;
int T;
int n,m;
int s1[MAXN];
int s2[MAXN];
int Next[MAXN];
inline void getNext(){
int j = 0, k;
Next[0] = k = -1;
while(j < m){
if(k==-1 || s2[j] == s2[k]){
j++;
k++;
Next[j] = k;
}else
k = Next[k];
}
}
inline int KMP(){
int i=0;
int j=0;
while(i<n && j<m){
if(j==-1 || s1[i] == s2[j]){
i++;
j++;
}else
j = Next[j]; //失配
if(j==m)
return i-m+1; //匹配到模式串最后一位,成功
}
return -1;
}
int main(){
T = read();
while(T--){
memset(Next,0,sizeof(Next));
memset(s1,0,sizeof(s1));
memset(s2,0,sizeof(s2));
n = read();
m = read();
for(int i=0; i<n; i++)
s1[i] = read();
for(int i=0; i<m; i++)
s2[i] = read();
getNext();
printf("%d\n",KMP());
}
return 0;
}