LC 2836. 在传球游戏中最大化函数值
https://leetcode.cn/problems/maximize-value-of-function-in-a-ball-passing-game/description/
因为要不断地传下去,所以路径上肯定存在环。我最开始的思路就是,针对环进行优化。但是一个问题就是,如果计算每个节点的循环K的距离。
一个方法就是,找到那个循环入口点,然后根据这个循环入口点向前推算:每前进一个,那么就需要减去尾部一个,所以我们需要保存尾部的值。
为了可以比较高效地保存尾部值(同时考虑我们可能已经出环了),那么需要设计一个数据结构,就是我代码中的 State. 有了这个结构然后不断地向前推就好了。
这个代码写正确还真花费了一些时间,主要复杂性都是在这个 State 对象里面。另外就是算法需要找到正确的环,就是 findRoot这个部分。
class Solution:
def getMaxFunctionValue(self, receiver: List[int], k: int) -> int:
n = len(receiver)
parent = [[] for _ in range(n)]
for i in range(n):
x = receiver[i]
parent[x].append(i)
def findRoot():
root = []
visited = set()
for i in range(n):
loop = set()
while i not in visited:
visited.add(i)
loop.add(i)
i = receiver[i]
if i in loop:
root.append(i)
return root
class State:
def __init__(self, loop, k):
self.loop = loop
self.history = []
rep = k // len(self.loop)
rem = k % len(self.loop)
self.value = sum(self.loop) * rep
self.value += sum(self.loop[:rem])
self.loopIdx = (k - 1 + len(self.loop)) % len(self.loop)
self.loopK = k
self.hisIdx = -1
def push(self, x):
self.history.append(x)
saved = (self.loopIdx, self.loopK, self.hisIdx, self.value)
self.value += x
if self.hisIdx != -1:
self.value -= self.history[self.hisIdx]
self.hisIdx += 1
else:
self.value -= self.loop[self.loopIdx]
self.loopIdx -= 1
self.loopK -= 1
if self.loopIdx < 0 and self.loopK > 0:
self.loopIdx += len(self.loop)
if self.loopK == 0:
self.hisIdx = 0
return saved
def pop(self, saved):
self.loopIdx, self.loopK, self.hisIdx, self.value = saved
self.history.pop()
root = findRoot()
ans = [-1] * n
def visitRoot(r):
loop = []
visit = set()
x = r
while x not in visit:
loop.append(x)
visit.add(x)
x = receiver[x]
st = State(loop, k + 1)
def dfs(r, st: State):
ans[r] = st.value
for p in parent[r]:
if ans[p] == -1:
saved = st.push(p)
dfs(p, st)
st.pop(saved)
dfs(r, st)
for r in root:
visitRoot(r)
# print(ans)
return max(ans)
后来看了题解,觉得这个倍增算法很简单很容易理解。
class Solution:
def getMaxFunctionValue(self, receiver: List[int], k: int) -> int:
n = len(receiver)
N = 36
f = [[-1] * N for _ in range(n)]
w = [[-1] * N for _ in range(n)]
for i in range(n):
x = receiver[i]
f[i][0] = x
w[i][0] = x
for j in range(1, N):
for i in range(n):
f[i][j] = f[f[i][j - 1]][j - 1]
w[i][j] = w[i][j - 1] + w[f[i][j - 1]][j - 1]
ans = 0
for i in range(n):
c = i
x = i
for j in range(N):
if k & (1 << j):
y = f[x][j]
c += w[x][j]
x = y
ans = max(ans, c)
return ans