LC 2836. 在传球游戏中最大化函数值
https://leetcode.cn/problems/maximize-value-of-function-in-a-ball-passing-game/description/
因为要不断地传下去,所以路径上肯定存在环。我最开始的思路就是,针对环进行优化。但是一个问题就是,如果计算每个节点的循环K的距离。
一个方法就是,找到那个循环入口点,然后根据这个循环入口点向前推算:每前进一个,那么就需要减去尾部一个,所以我们需要保存尾部的值。
为了可以比较高效地保存尾部值(同时考虑我们可能已经出环了),那么需要设计一个数据结构,就是我代码中的 State. 有了这个结构然后不断地向前推就好了。
这个代码写正确还真花费了一些时间,主要复杂性都是在这个 State 对象里面。另外就是算法需要找到正确的环,就是 findRoot这个部分。
class Solution: def getMaxFunctionValue(self, receiver: List[int], k: int) -> int: n = len(receiver) parent = [[] for _ in range(n)] for i in range(n): x = receiver[i] parent[x].append(i) def findRoot(): root = [] visited = set() for i in range(n): loop = set() while i not in visited: visited.add(i) loop.add(i) i = receiver[i] if i in loop: root.append(i) return root class State: def __init__(self, loop, k): self.loop = loop self.history = [] rep = k // len(self.loop) rem = k % len(self.loop) self.value = sum(self.loop) * rep self.value += sum(self.loop[:rem]) self.loopIdx = (k - 1 + len(self.loop)) % len(self.loop) self.loopK = k self.hisIdx = -1 def push(self, x): self.history.append(x) saved = (self.loopIdx, self.loopK, self.hisIdx, self.value) self.value += x if self.hisIdx != -1: self.value -= self.history[self.hisIdx] self.hisIdx += 1 else: self.value -= self.loop[self.loopIdx] self.loopIdx -= 1 self.loopK -= 1 if self.loopIdx < 0 and self.loopK > 0: self.loopIdx += len(self.loop) if self.loopK == 0: self.hisIdx = 0 return saved def pop(self, saved): self.loopIdx, self.loopK, self.hisIdx, self.value = saved self.history.pop() root = findRoot() ans = [-1] * n def visitRoot(r): loop = [] visit = set() x = r while x not in visit: loop.append(x) visit.add(x) x = receiver[x] st = State(loop, k + 1) def dfs(r, st: State): ans[r] = st.value for p in parent[r]: if ans[p] == -1: saved = st.push(p) dfs(p, st) st.pop(saved) dfs(r, st) for r in root: visitRoot(r) # print(ans) return max(ans)
后来看了题解,觉得这个倍增算法很简单很容易理解。
class Solution: def getMaxFunctionValue(self, receiver: List[int], k: int) -> int: n = len(receiver) N = 36 f = [[-1] * N for _ in range(n)] w = [[-1] * N for _ in range(n)] for i in range(n): x = receiver[i] f[i][0] = x w[i][0] = x for j in range(1, N): for i in range(n): f[i][j] = f[f[i][j - 1]][j - 1] w[i][j] = w[i][j - 1] + w[f[i][j - 1]][j - 1] ans = 0 for i in range(n): c = i x = i for j in range(N): if k & (1 << j): y = f[x][j] c += w[x][j] x = y ans = max(ans, c) return ans