1595. 连通两组点的最小成本

这题牵涉到两个状态变量的动态规划,size1和size2两边的点集选择。

最开始我把动态规划状态方程写成了 dp[st | (1 << i) | (1 << j)] = dp[st] + cost[i][j]. 这样的结果是,状态空间不大 O(2^(n+m)), 但是时间复杂度却是 O(n*m*2^(n+m)). 面对题目给的数据量铁定是超时的,不管怎么进行局部优化或者是用Java来重写。

其实虽然这题涉及到了两个状态变量,但其实只需要将一个变量设计成为状态,而另外一个变量设计成为顺序,状态类似 dp[n][st] 这样的。具体到这题目上, 状态方程其实可以是 dp[i][st | st0] = dp[i-1][st] + sum(cost(i, j) for j in st0). 对于size1这边的点我们顺序算法,而对于size2这边的点 我们则可以选择的状态。然后在状态更新的时候,可以考虑使用size2里面的那些点来和i进行匹配。

实现下来,空间是O(n*2^m)),时间是O(m*2^(2*m)). 然后需要做一定的预处理。这里面遍历其余点集的代码很有意思:

class Solution:
    def connectTwoGroups(self, cost: List[List[int]]) -> int:
        n, m = len(cost), len(cost[0])
        inf = 1 << 30
        dp = [[inf] * (1 << m) for _ in range(n + 1)]
        dp[0][0] = 0

        C = [[0] * (1 << m) for _ in range(n)]
        for i in range(n):
            for st in range(1 << m):
                c = 0
                for j in range(m):
                    if (st >> j) & 0x1:
                        c += cost[i][j]
                C[i][st] = c
        # print(C)

        for i in range(n):
            for st in range(1 << m):
                val = dp[i][st]
                # 选择至少一个元素
                for j in range(m):
                    st2 = st | (1 << j)
                    dp[i + 1][st2] = min(dp[i + 1][st2], val + cost[i][j])

                # 尝试多个元素去匹配,但是如果已经选择的话就不需要在选择了
                x = R = (1 << m) - 1 - st
                while x:
                    c = C[i][x]
                    st2 = st | x
                    dp[i + 1][st2] = min(dp[i + 1][st2], val + c)
                    x = (x - 1) & R
        ans = dp[n][(1 << m) - 1]
        return ans

在评论区里面还有另外一个解法,我觉得也挺有意思的,而且更加高效。我们不是每次从size2里面选择一个点集来覆盖,而只是选择一个点来覆盖。 这样求解得到最后的结果是,覆盖完成了size1里面所有点的最小代价,但结果可能size2并没有完全覆盖完成。没关系,对于那些没有覆盖完成的点, 我们只选择cost最小的连接就行。空间复杂度是O(n*2^m), 但是时间复杂度缩减到了O(n*m*2^m).

class Solution:
    def connectTwoGroups(self, cost: List[List[int]]) -> int:
        n, m = len(cost), len(cost[0])
        inf = 1 << 30
        dp = [[inf] * (1 << m) for _ in range(n + 1)]
        C = [min(cost[j][i] for j in range(n)) for i in range(m)]
        dp[0][0] = 0

        for i in range(n):
            for st in range(1 << m):
                val = dp[i][st]
                # 选择至少一个元素, 确保i匹配上
                for j in range(m):
                    st2 = st | (1 << j)
                    dp[i + 1][st2] = min(dp[i + 1][st2], val + cost[i][j])

        ans = inf
        for st in range(1 << m):
            c = 0
            for i in range(m):
                if (st >> i) & 0x1: continue
                c += C[i]
            ans = min(ans, dp[n][st] + c)
        return ans