LC 8020. 字符串转换

https://leetcode.cn/problems/string-transformation/description/

这题最开始我拆解问题思路是没有问题的:找出循环切分点,并且动态规划计算总数。但是两个子问题的求解方式都有点问题。

题解写的非常好 https://leetcode.cn/problems/string-transformation/solutions/2435348/kmp-ju-zhen-kuai-su-mi-you-hua-dp-by-end-vypf/

找出循环切分点我想使用hash算法来减少字符串匹配。但是hash算法只能快速过滤掉不一致的情况,如果hash value相同的话还需要做一次比较。最坏情况下类似 "aaa…aaa" 和 "aaaa..aaa" 这样的话就有大量的字符串匹配。 正确的方法是使用KMP来加速匹配(这个好像也是我遇到过的少有的题目,必须使用KMP算法来匹配字符串的)

动态规划计算这个想的就更偏了。我最开始尝试的算法是,假设起始点是i, 每次挪动距离可以是[1,n], 需要挪动k次,那么最终点落在0这个位置上有多少种方法。但是显然这种方法计算量巨大,因为这个转移状态太大了。

其实这个状态有点想的太细了,太细了才会导致计算量巨大。如果可以把整个状态做更大的抽象(或者是统一的话),那么这个状态就特别小。

很早之前看过课本上的kmp算法,感觉有点不太好理解,因为里面构建状态是类似backoff的状态,而不是上次match的状态。我觉得题解里面给出的模板挺好的,可以按照这个方式去理解一下。

def mat_mul(a, b, MOD):
    R, K, C = len(a), len(a[0]), len(b[0])
    res = [[0] * C for _ in range(R)]
    for k in range(K):
        for i in range(R):
            for j in range(C):
                res[i][j] += (a[i][k] * b[k][j]) % MOD
                res[i][j] %= MOD
    return res


def FindCutByKMP(s, t):
    n = len(s)
    def compute_max_match(pattern):
        match = [0] * len(pattern)
        c = 0
        for i in range(1, len(pattern)):
            v = pattern[i]
            while c and pattern[c] != v:
                c = match[c - 1]
            if pattern[c] == v:
                c += 1
            match[i] = c
        return match

    def kmp_search(text, pattern):
        match = compute_max_match(pattern)
        match_count = c = 0
        for i, v in enumerate(text):
            v = text[i]
            while c and pattern[c] != v:
                c = match[c - 1]
            if pattern[c] == v:
                c += 1
            if c == len(pattern):
                match_count += 1
                c = match[c - 1]
        return match_count

    cuts = kmp_search(s + s[:-1], t)
    return cuts


def ComputeMM(c, k, s, t):
    # f[i][0] after i operations, s == t
    # f[i][1] after i operations, s!= t

    # f[i][0] = f[i-1][0] * (c-1) + f[i-1][1] * c
    # f[i][1] = f[i-1][0] * (n-c) * f[i-1][1] * (n-1-c)

    # f[0][0] = 1 if s == t
    MOD = 10 ** 9 + 7
    n = len(s)
    base = [[c - 1, c], [n - c, n - 1 - c]]
    eq = 1 if (s == t) else 0
    T = [[eq], [1 - eq]]
    while k:
        if k & 0x1:
            T = mat_mul(base, T, MOD)
        base = mat_mul(base, base, MOD)
        k = k >> 1
    return T[0][0]


class Solution:
    def numberOfWays(self, s: str, t: str, k: int) -> int:
        cuts = FindCutByKMP(s, t)
        if cuts == 0: return 0
        return ComputeMM(cuts, k, s, t)