1074. Number of Submatrices That Sum to Target

https://leetcode.com/problems/number-of-submatrices-that-sum-to-target/

这题给我最大的启发就是子矩阵的遍历模式:

所以一般来说,想要遍历子矩阵求解某些特性的话,时间复杂度都在O(n^2*m)上面,某些问题可以缩减到O(n*m),但是肯定没有办法再缩减了。


具体到这道题目的话,遵循O(n^2*m)的时间复杂度。在处理每一列的时候,使用hashmap去查找匹配元素的个数。

class Solution:
    def numSubmatrixSumTarget(self, matrix: List[List[int]], target: int) -> int:
        n, m = len(matrix), len(matrix[0])
        from collections import Counter

        if n > m:
            matrix = list(zip(*matrix))
            n, m = m, n

        # O(n^2 * m)
        self.ans = 0
        for i in range(n):
            # print('===== {} ===='.format(i))
            dp = [0] * m
            for j in range(i, n):
                tmp = 0
                c = Counter()
                c[tmp] = 1
                for k in range(m):
                    tmp += matrix[j][k]
                    dp[k] += tmp

                    lookup = dp[k] - target
                    self.ans += c[lookup]
                    c[dp[k]] += 1
                # print(dp)
        return self.ans


还有一个类似的题目 https://leetcode.com/problems/maximum-side-length-of-a-square-with-sum-less-than-or-equal-to-threshold/

如果使用上面算法O(n^2m)去套的话会出现超时。这题目解决办法是,使用二分来估算size。

一旦固定好size之后,回到上面的模式,我们可以不用遍历r1了,那么检查算法就是O(nm). 而size的上线是O(min(n,m)), 所以时间复杂度可以是 O(lgn * nm).

其实这题目使用O(n^2m)在n<=300,m<=300上应该也是可以的,但是作者应该还是想规定使用二分吧。

class Solution:
    def maxSideLength(self, mat: List[List[int]], threshold: int) -> int:
        n = len(mat)
        m = len(mat[0])

        if n > m:
            n, m = m, n
            mat = list(zip(*mat))

        dp = [[0] * m for _ in range(n + 1)]
        for i in range(n):
            for j in range(m):
                dp[i + 1][j] = dp[i][j] + mat[i][j]

        def check_size(sz):
            for i in range(n - sz + 1):
                j = i + sz - 1
                # use dp[j+1] - dp[i]
                tmp = [0] * m
                for k in range(m):
                    tmp[k] = dp[j + 1][k] - dp[i][k]

                for k in range(1, m):
                    tmp[k] += tmp[k - 1]

                for k in range(sz - 1, m):
                    # use [k-sz+1 .. k]
                    v = tmp[k - sz] if k - sz >= 0 else 0
                    if (tmp[k] - v) <= threshold:
                        return True
            return False

        s, e = 1, min(n, m)
        ans = 0
        while s <= e:
            mid = (s + e) // 2
            if check_size(mid):
                ans = mid
                s = mid + 1
            else:
                e = mid - 1
        return ans