目录

使用Rabin-Karp算法替代KMP


Rabin-Karp算法

Rabin-Karp算法是利用滑动窗口的方式来计算哈希值,并通过该哈希值直接进行比较来判断字符串是否匹配,也就是消减了字符串比较的过程来实现 O(n) 的时间复杂度。

废话不多说了,直接上干货。

首先,我问你一个很基础的问题,给你输入一个字符串形式的正整数,如何把它转化成数字的形式?很简单,下面这段代码就可以做到:

string s = "8264";
int number = 0;
for (int i = 0; i < s.size(); i++) {
    // 将字符转化成数字
    number = 10 * number + (s[i] - '0');
    print(number);
}
// 打印输出:
// 8
// 82
// 826
// 8264

可以看到这个算法的核心思路就是不断向最低位(个位)添加数字,同时把前面的数字整体左移一位(乘以 10)。

为什么是乘以 10?因为我们默认探讨的是十进制数。这和我们操作二进制数的时候是一个道理,左移一位就是把二进制数乘以 2,右移一位就是除以 2。

上面这个场景是不断给数字添加最低位,那如果我想删除数字的最高位,怎么做呢?比如说我想把 8264 变成 264,应该如何运算?其实也很简单,让 8264 减去 8000 就得到 264 了。

这个 8000 是怎么来的?是 8 x 10^3 算出来的。8 是最高位的数字,10 是因为我们这里是十进制数,3 是因为 8264 去掉最高位后还剩三位数。

上述内容主要探讨了如何在数字的最低位添加数字以及如何删除数字的最高位,用R表示数字的进制数,用L表示数字的位数,就可以总结出如下公式:

/* 在最低位添加一个数字 */
int number = 8264;
// number 的进制
int R = 10;
// 想在 number 的最低位添加的数字
int appendVal = 3;
// 运算,在最低位添加一位
number = R * number + appendVal;
// 此时 number = 82643

/* 在最高位删除一个数字 */
int number = 8264;
// number 的进制
int R = 10;
// number 最高位的数字
int removeVal = 8;
// 此时 number 的位数
int L = 4;
// 运算,删除最高位数字
number = number - removeVal * R^(L-1);
// 此时 number = 264

如果你能理解这两个公式,那么 Rabin-Karp 算法就没有任何难度,算法就是这样,再高大上的技巧,都是在最简单最基本的原理之上构建的。不过在讲 Rabin-Karp 算法之前,我们先来看一道简单的力扣题目。

高效寻找重复子序列

看下力扣第 187 题「重复的 DNA 序列」,我简单描述下题目:

DNA 序列由四种碱基A, G, C, T组成,现在给你输入一个只包含A, G, C, T四种字符的字符串s代表一个 DNA 序列,请你在s中找出所有重复出现的长度为 10 的子字符串。

比如下面的测试用例:

输入:s = "AAAAACCCCCAAAAACCCCCCAAAAAGGGTTT"
输出:["AAAAACCCCC","CCCCCAAAAA"]
解释:子串 "AAAAACCCCC"  "CCCCCAAAAA" 都重复出现了两次。

输入:s = "AAAAAAAAAAAAA"
输出:["AAAAAAAAAA"]

函数签名如下:

vector<string> findRepeatedDnaSequences(string s);

这道题的拍脑袋解法比较简单粗暴,我直接穷举所有长度为 10 的子串,然后借助哈希集合寻找那些重复的子串就行了,代码如下:

// 暴力解法
vector<string> findRepeatedDnaSequences(string s) {
    int n = s.size();
    // 记录出现过的子串
    unordered_set<string> seen;
    // 记录那些重复出现多次的子串
    // 注意要用哈希集合,防止记录重复的结果
    unordered_set<string> res;

    for (int i = 0; i + 10 <= n; i++) {
        auto subStr = s.substr(i, 10);
        if (seen.count(subStr)){
            // 之前出现过,找到一个重复的
            res.insert(subStr);
        } else {
            // 之前没出现过,加入集合
            seen.insert(subStr);
        }
    }
    return {res.begin(),res.end()};
}

这个算法肯定是没问题的,只是时间复杂度略高。假设s的长度为N,目标子串的长度为L(本题L = 10),for 循环遍历sO(N)个字符,对每个字符都要截取长度为L的子字符串,所以这个算法的时间复杂是O(NL)

遍历整个s肯定是免不了的,问题是我们能不能不要每次都调用substr去截取子字符串

你注意我们这个匹配过程实际上就是维护了一个长度为L = 10的定长窗口在从左向右滑动,是否可以使用滑动窗口中的做法,只维护left, right指针来划定子字符串区间?

其实可以的,直接套用前文给出的滑动窗口算法框架写出伪码思路:

int L = 10;
map<string> seen;

// 滑动窗口代码框架
CharWindow window;
int left = 0, right = 0;
while (right < s.size()) {
    // 扩大窗口,移入字符
    window.addRight(s[right]);
    right++;
    
    // 当子串的长度达到要求
    if (right - left == L) {
        // 把窗口中的字符变成字符串,还是需要 O(L) 的时间
        String windowStr = window.toString();
        if (seen.contains(windowStr)) {
            print("找到一个重复子串: ", windowStr);
        } else {
            seen.add(windowHash);
        }

        // 缩小窗口,移出字符
        window.removeLeft();
        left++;
    }
}

这段伪码直接套用了滑动窗口算法框架,你应该不难理解。但你注意这个解法依然需要将窗口中的字符转化成字符串然后去seen集合判断是否存在重复,你一旦想把字符转化成字符串,就难免需要O(L)的时间来操作。所以这个算法的时间复杂度还是没有降低,依然是O(NL)

所以优化的关键在于,我们能不能不要真的把子字符串生成出来,而是用一些其他形式的唯一标识来表示滑动窗口中的子字符串,并且还能在窗口滑动的过程中快速更新

有办法的,回想一下本文开头我们讨论的那两个公式,现在你应该明白的用意了。

你把AGCT四种字符等价为0123四个数字,那么长度为L = 10的一个碱基序列其实就可以等价为一个十位数,这个数字可以唯一标识一个子串。而且窗口移动的过程,其实就是给这个数字的最低位添加数字,并删除最高位数字的过程,回顾之前的讲解,添加和删除数字的运算就是两个公式,可以在O(1)的时间完成。

然后,我们不要在哈希集合中直接存储子串了,而是存储子串对应的十位数。因为一个十位数可以唯一标识一个子串,所以也可以起到识别重复的作用。

这样,我们就避免了直接生成子串存入集合,而是生成一个十位数来表示子串,而且生成这个十位数的时间花费为O(1),从而降低了匹配算法的时间复杂度。

其实你想下,你把一个字符串对象转化成了一个数字,这是什么?这就是你设计的一个哈希算法,生成的数字就可以认为是字符串的哈希值。在滑动窗口中快速计算窗口中元素的哈希值,叫做滑动哈希技巧

上述优化思路的伪码思路如下:

int L = 10;
// 集合中不要存储字符串了,而是存储字符串对应的哈希值
HashSet<Integer> seen;

// 滑动窗口代码框架
CharWindow window;
int left = 0, right = 0;
while (right < s.size()) {
    // 扩大窗口,移入字符
    window.addRight(s[right]);
    right++;
    
    // 当子串的长度达到要求
    if (right - left == L) {
        // 获取当前窗口内字符串的哈希值,时间 O(1)
        int windowHash = window.hash();
        // 根据哈希值判断是否曾经出现过相同的子串
        if (seen.contains(windowHash)) {
            // 当前窗口中的子串是重复出现的
            print("找到一个重复子串: ", window.toString());
        } else {
            // 当前窗口中的子串之前没有出现过,记下来
            seen.add(windowHash);
        }

        // 缩小窗口,移出字符
        window.removeLeft();
        left++;
    }
}

进一步,我们用一个 10 位数来标识一个长度为 10 的碱基字符序列,这需要 long 类型存储,int 存不下 10 位数。但你注意这个 10 位数中所有的数字只会局限于0,1,2,3,是不是有些浪费?

换句话说,我们需要存储的其实只是一个四进制下的十位数(共包含 4^10 个数字),却用了十进制的十位数(可以包含 10^10 个数字)来保存,显然是有些浪费的。

因为 4^10 = 1048576 < 10^8,所以只要我们在四进制的运算规则下进行运算,十进制的八位数就能存下,这样的话 int 类型就够用了,不需要 long 类型。

具体来说,只要改变我们之前那两个公式的进制R就行了:

/* 在最低位添加一个数字 */
// number 的进制
int R = 4;
// 想在 number 的最低位添加的数字
int appendVal = 0~3 中的任意数字;
// 运算,在最低位添加一位
number = R * number + appendVal;

/* 在最高位删除一个数字 */
// number 的进制
int R = 4;
// number 最高位的数字
int removeVal = 0~3 中的任意数字;
// 此时 number 的位数
int L = ?;
// 运算,删除最高位数字
number = number - removeVal * R^(L-1);

结合数字最高/最低位的处理技巧和滑动窗口代码框架,我们就可以轻松地写出最终的解法代码:

vector<string> findRepeatedDnaSequences(string s) {
    auto getNumber = [](char x)->int{
        switch(x){
            case 'A':
                return 0;
            case 'G':
                return 1;
            case 'C':
                return 2;
            case 'T':
               return 3;
        }
        return -1;
    }
    // 记录重复出现的哈希值
    unordered_set<int> seen;
    // 记录重复出现的字符串结果
    unordered_set<string> res;
    
    // 数字位数
    int L = 10;
    // 进制
    int R = 4;
    // 存储 R^(L - 1) 的结果
    int RL = (int) pow(R, L - 1);
    // 维护滑动窗口中字符串的哈希值
    int windowHash = 0;
    
    // 滑动窗口代码框架,时间 O(N)
    int left = 0, right = 0;
    while (right < nums.length) {
        // 扩大窗口,移入字符,并维护窗口哈希值(在最低位添加数字)
        windowHash = R * windowHash + getNumber(s[right++]);

        // 当子串的长度达到要求
        if (right - left == L) {
            // 根据哈希值判断是否曾经出现过相同的子串
            if (seen.count(windowHash)) {
                // 当前窗口中的子串是重复出现的
                res.insert(s.substr(left, L));
            } else {
                // 当前窗口中的子串之前没有出现过,记下来
                seen.insert(windowHash);
            }
            // 缩小窗口,移出字符,并维护窗口哈希值(删除最高位数字)
            windowHash = windowHash - nums[left] * RL;
            left++;
        }
    }
    // 转化成题目要求的 List 类型
    return {res.begin(),res.end()};
}

滑动窗口算法本身的时间复杂度是O(N),而窗口滑动的过程中的操作耗时都是O(1)(给res添加子串的过程不算),所以整体的时间复杂度是O(N),这样就把算法的复杂度降低到线性级别了。

Rabin-Karp 算法

有了上面由浅入深的铺垫,你理解 Rabin-Karp 算法就非常容易了,因为上面这道题目的本质就是一个字符串匹配的问题。

字符串匹配算法大家都很熟悉,让你在文本串txt中搜索模式串pat的起始索引,暴力字符串匹配算法是这样的:

// 在文本串 txt 中搜索模式串 pat 的起始索引
int search(string txt, string pat) {
    int N = txt.size(), L = pat.size();

    for (int i = 0; i + L <= N; i++) {
        auto subStr = txt.substr(i,  L);
        if (subStr == pat){
            // 在 txt 中找到模式串 pat,返回起始索引
            return i;
        }
    }
    // txt 中不存在模式串 pat
    return -1;
}

你可以发现,这个逻辑和上面讲的那道题的暴力解法非常类似,总的时间复杂度是O(LN),优化的核心也是子串subStr和模式串pat匹配的部分。

那么借鉴上面的思路,我们不要每次都去一个字符一个字符地比较子串和模式串,而是维护一个滑动窗口,运用滑动哈希算法一边滑动一边计算窗口中字符串的哈希值,拿这个哈希值去和模式串的哈希值比较,这样就可以避免截取子串,从而把匹配算法降低为O(N),这就是 Rabin-Karp 指纹字符串查找算法的核心逻辑

那你可能会问,刚才我们处理的题目给你输入的只有AGCT四种字符,所以可以转化成数字,但面对五花八门的字符串,如何把他们转化成数字计算哈希值呢?其实很简单,字符本质上就是编码,而编码其实就是数字。

比方说以 ASCII 码为例,ASCII 码其实就是 0~255 这 256 个数字,分别对应所有英文字符和英文符号。那么一个长度为L的 ASCII 字符串,我们就可以等价理解成一个L位的 256 进制的数字,这个数字就可以唯一标识这个字符串,也就可以作为哈希值。

有了这个想法,我们就可以直接复制粘贴上一道题的大部分代码,写出 Rabin-Karp 算法的主要逻辑:

// 文本串
string txt;
// 模式串
string pat;


// 需要寻找的子串长度为模式串 pat 的长度
int L = pat.size();
// 仅处理 ASCII 码字符串,可以理解为 256 进制的数字
int R = 256;
// 存储 R^(L - 1) 的结果
int RL = (int) pow(R, L - 1);
// 维护滑动窗口中字符串的哈希值
int windowHash = 0;
// 计算模式串的哈希值
long patHash = 0;
for (int i = 0; i < pat.size(); i++) {
    patHash = R * patHash + pat[i];
}

// 滑动窗口代码框架
int left = 0, right = 0;
while (right < txt.length()) {
    // 扩大窗口,移入字符(在最低位添加数字)
    windowHash = R * windowHash + txt[right];
    right++;

    // 当子串的长度达到要求
    if (right - left == L) {
        // 根据哈希值判断窗口中的子串是否匹配模式串 pat
        if (patHash == windowHash) {
            // 找到模式串
            print("找到模式串,起始索引为", left);
            return left;
        }
        
        // 缩小窗口,移出字符(删除最高位数字)
        windowHash = windowHash - txt[left] * RL;
        left++;
    }
}
// 没有找到模式串
return -1;

相对上一道题的解法,这段代码将进制数R改为了 256,同时计算了模式串pat的哈希值patHash用来和窗口中字符串的哈希值windowHash做对比,以便判断是否找到了模式串,其他的代码部分完全相同。

不过呢,这段代码实际运行的时候会有一个严重的问题,那就是整型溢出

你想,上一道题给定的 DNA 序列字符串只包含AGCT四种字符,所以我们可以把 DNA 序列抽象成四进制的数字,即算法中R = 4。相同位数下,四进制包含的数字数量是小于十进制包含的数字数量的。比方说L = 10时,4^10 = 1048576 < 10^8,即 10 位四进制数字用 8 位十进制数字就可以存下了。

但现在输入为 ASCII 码字符串,我们不得不把字符串抽象成 256 进制的数字,即算法中R = 256。而相同位数下,256 进制包含的数字数量显然是远大于十进制包含的数字数量的。比如L = 10时,256^10 = 1.2 x 10^24 < 10 ^25,所以你需要一个 25 位的十进制数才能表示一个 10 位的 256 进制数。

可想而知,如果你真的把字符串对应的 256 进制数字的十进制表示作为该字符串的哈希值,那恐怕L稍微大一点,像RL, windowHash, patHash这些变量就超级大了,long 类型都远远装不下。

所以解决办法是什么呢?如何把一个很大的数字映射到一个较小的范围内呢?答案是求模(余数)

无论一个数字多大,你让它除以Q,余数一定会落在[0, Q-1]的范围内。所以我们可以设置一个Q,用求模的方式让windowHashpatHash保持在[0, Q-1]之间,就可以有效避免整型溢出。

具体来说,对于一个字符串,我们不需要把完整的 256 进制数字存下来,而是对这个巨大的 256 进制数求Q的余数,然后把这个余数作为该字符串的哈希值即可。

好,整型溢出的问题倒是解决了,但新的问题又来了:求模之后的哈希值不能和原始字符串一一对应了,可能出现一对多的情况,即哈希冲突

比方说 10 % 7 等于 3,而 17 % 7 也等于 3,所以如果你得到余数 3,你能确定原始数字就一定是 10 么?不能。

类似的,如果你发现windowHash == patHash,你也不敢完全肯定窗口中的字符串一定就和模式串pat匹配,有可能它俩不匹配,但恰好求模算出来的哈希值一样,这就产生了是「哈希冲突」。

对于 Rabin-Karp 算法来说,当发现windowHash == patHash时,使用暴力匹配算法检查一下窗口中的字符串和pat是否相同就可以避免哈希冲突了。因为希冲突出现的概率比较小,所以偶尔用一下暴力匹配算法是不影响总体的时间复杂度的。

明白了这些问题的解决方案,你就能很自然地写出 Rabin-Karp 算法了:

// Rabin-Karp 指纹字符串查找算法
int rabinKarp(string str, string pattern) {
   const int64_t mod = (1ll << 53) - 1;
    auto    getNumber   = [](char x) { return x - 'a'; };
   int     R           = 26;
   int     L           = (int)pattern.size();
   int64_t RL          = 1;
   int64_t patternHash = 0;
   int64_t windowHash  = 0;
   for (int i = 1; i < L; i++) RL = (RL * R) % mod;
   for (int i = 0; i < L; i++)
      patternHash = ((patternHash * R) % mod + getNumber(pattern[i])) % mod;
   int left = 0, right = 0;

   while (right < str.size())
   {
      windowHash = ((windowHash * R) % mod + getNumber(str[right++])) % mod;
      if (right - left == L)
      {
         if (windowHash == patternHash)
         {
            if (str.substr(left, L) == pattern) return left;
         }
         else
         {
            windowHash =
              (windowHash - (getNumber(str[left++]) * RL) % mod + mod) % mod;
         }
      }
   }
   return -1;
}

上述代码可以在力扣28题验证。

有之前那么多铺垫,算法逻辑应该没啥可说的,就说一下模运算的两个运算法则吧:

X % Q == (X + Q) % Q
(X + Y) % Q == (X % Q + Y % Q) % Q

稍微想一想就能理解这两个运算法则,在代码中但凡涉及到乘法和加法,都可能产生很大的结果,所以一有机会就可以运用上述法则对结果进行求模,以避免造成溢出。

Rabin-Karp 算法的时间复杂度是O(N + L)N为文本串txt的长度,L为模式串pat的长度。当然,每次出现哈希冲突时会使用O(L)的时间进行暴力匹配,但考虑到只要Q设置的合理,哈希冲突的出现概率会很小,所以可以忽略不计。

最后说一下这个大素数Q的选择。

为什么要这个Q尽可能大呢?主要是为了降低哈希冲突的概率

因为代码中你把这个Q作为除数,余数(哈希值)一定落在[0, Q-1]之间,所以Q越大,哈希值的空间就越大,就越不容易出现哈希冲突,整个算法的效率就会高一些。

为什么这个Q要是素数呢?依然是为了降低哈希冲突的概率

举个极端一点的例子,你令Q = 100,那么无论一个数X再大,X % Q的结果必然是X的最后两位。换句话说X前面的那些位你根本没利用,可想而知你这个哈希算法存在某些规律性,不够随机,进而更容易导致哈希冲突,降低算法的效率。