LC 2836. 在传球游戏中最大化函数值

https://leetcode.cn/problems/maximize-value-of-function-in-a-ball-passing-game/description/

因为要不断地传下去,所以路径上肯定存在环。我最开始的思路就是,针对环进行优化。但是一个问题就是,如果计算每个节点的循环K的距离。

一个方法就是,找到那个循环入口点,然后根据这个循环入口点向前推算:每前进一个,那么就需要减去尾部一个,所以我们需要保存尾部的值。

为了可以比较高效地保存尾部值(同时考虑我们可能已经出环了),那么需要设计一个数据结构,就是我代码中的 State. 有了这个结构然后不断地向前推就好了。

这个代码写正确还真花费了一些时间,主要复杂性都是在这个 State 对象里面。另外就是算法需要找到正确的环,就是 findRoot这个部分。

class Solution:
    def getMaxFunctionValue(self, receiver: List[int], k: int) -> int:
        n = len(receiver)
        parent = [[] for _ in range(n)]
        for i in range(n):
            x = receiver[i]
            parent[x].append(i)

        def findRoot():
            root = []
            visited = set()
            for i in range(n):
                loop = set()
                while i not in visited:
                    visited.add(i)
                    loop.add(i)
                    i = receiver[i]
                if i in loop:
                    root.append(i)
            return root

        class State:
            def __init__(self, loop, k):
                self.loop = loop
                self.history = []

                rep = k // len(self.loop)
                rem = k % len(self.loop)
                self.value = sum(self.loop) * rep
                self.value += sum(self.loop[:rem])
                self.loopIdx = (k - 1 + len(self.loop)) % len(self.loop)
                self.loopK = k
                self.hisIdx = -1

            def push(self, x):
                self.history.append(x)
                saved = (self.loopIdx, self.loopK, self.hisIdx, self.value)

                self.value += x
                if self.hisIdx != -1:
                    self.value -= self.history[self.hisIdx]
                    self.hisIdx += 1
                else:
                    self.value -= self.loop[self.loopIdx]
                    self.loopIdx -= 1
                    self.loopK -= 1
                    if self.loopIdx < 0 and self.loopK > 0:
                        self.loopIdx += len(self.loop)
                    if self.loopK == 0:
                        self.hisIdx = 0
                return saved

            def pop(self, saved):
                self.loopIdx, self.loopK, self.hisIdx, self.value = saved
                self.history.pop()

        root = findRoot()
        ans = [-1] * n

        def visitRoot(r):
            loop = []
            visit = set()
            x = r
            while x not in visit:
                loop.append(x)
                visit.add(x)
                x = receiver[x]
            st = State(loop, k + 1)

            def dfs(r, st: State):
                ans[r] = st.value
                for p in parent[r]:
                    if ans[p] == -1:
                        saved = st.push(p)
                        dfs(p, st)
                        st.pop(saved)

            dfs(r, st)

        for r in root:
            visitRoot(r)
        # print(ans)
        return max(ans)

后来看了题解,觉得这个倍增算法很简单很容易理解。

class Solution:
    def getMaxFunctionValue(self, receiver: List[int], k: int) -> int:
        n = len(receiver)
        N = 36

        f = [[-1] * N for _ in range(n)]
        w = [[-1] * N for _ in range(n)]
        for i in range(n):
            x = receiver[i]
            f[i][0] = x
            w[i][0] = x

        for j in range(1, N):
            for i in range(n):
                f[i][j] = f[f[i][j - 1]][j - 1]
                w[i][j] = w[i][j - 1] + w[f[i][j - 1]][j - 1]

        ans = 0
        for i in range(n):
            c = i
            x = i
            for j in range(N):
                if k & (1 << j):
                    y = f[x][j]
                    c += w[x][j]
                    x = y
            ans = max(ans, c)
        return ans