LC 3257. 放三个车的价值之和最大 II

https://leetcode.cn/problems/maximum-value-sum-by-placing-three-rooks-ii/description/

这题最开始我使用的是剪枝,从上到下进行选择:

使用普通的剪枝方法还过来。然后还有一个优化就是每一行里面只考虑最大的3个值,这样的话可以让剪枝更加高效。代码比较长

class Solution:
    def maximumValueSum(self, board: List[List[int]]) -> int:
        n, m = len(board), len(board[0])
        MAX_VAL = [[0] * m for _ in range(n)]
        for i in reversed(range(n)):
            for j in range(m):
                MAX_VAL[i][j] = board[i][j]
                if (i + 1) < n:
                    MAX_VAL[i][j] = max(MAX_VAL[i][j], MAX_VAL[i + 1][j])

        MAX_COL = [[] for _ in range(n)]
        COLS = set()
        for i in range(n):
            js = list(range(m))
            js.sort(key=lambda x: board[i][x], reverse=True)
            MAX_COL[i] = js[:3]
            COLS.update(js[:3])

        COLS = list(COLS)

        import functools
        @functools.cache
        def getmax(st, i, k):
            xs = []
            for j in COLS:
                if st & (1 << j) == 0:
                    xs.append(MAX_VAL[i][j])
            xs.sort()
            return sum(xs[-k:])

        inf = 1 << 63
        ans = -inf

        def dfs(st, i, k, now):
            nonlocal ans
            if i == n: return
            bound = getmax(st, i, k)
            if k == 1:
                ans = max(ans, bound + now)
                return
            # cut branch
            if now + bound <= ans:
                return

            for ii in range(i, n):
                for j in MAX_COL[ii]:
                    if st & (1 << j): continue
                    v = board[ii][j]
                    dfs(st | (1 << j), ii + 1, k - 1, now + v)

        dfs(0, 0, 3, 0)
        return ans

题解里面给的方法更加简单:对于这种3个的题目,可以枚举中间的点,同时更新上下两个点。上下两个点也是只需要维护最大的3个位置就好了。

class Solution:
    def maximumValueSum(self, board: List[List[int]]) -> int:
        n, m = len(board), len(board[0])

        MAX_VAL = [[0] * m for _ in range(n)]
        for i in reversed(range(n)):
            for j in range(m):
                MAX_VAL[i][j] = board[i][j]
                if (i + 1) < n:
                    MAX_VAL[i][j] = max(MAX_VAL[i][j], MAX_VAL[i + 1][j])

        DOWN = MAX_VAL

        MAX_VAL = [[0] * m for _ in range(n)]
        for i in range(n):
            for j in range(m):
                MAX_VAL[i][j] = board[i][j]
                if (i - 1) >= 0:
                    MAX_VAL[i][j] = max(MAX_VAL[i][j], MAX_VAL[i - 1][j])
        UP = MAX_VAL

        ans = -(1 << 63)
        for i in range(1, n - 1):
            up = list(range(m))
            up.sort(key=lambda x: UP[i - 1][x], reverse=True)
            up = up[:3]
            down = list(range(m))
            down.sort(key=lambda x: DOWN[i + 1][x], reverse=True)
            down = down[:3]

            for j in range(m):
                for a in up:
                    for b in down:
                        if a == j or b == j or a == b: continue
                        r = board[i][j] + UP[i - 1][a] + DOWN[i + 1][b]
                        ans = max(ans, r)
        return ans