MCTS for Nim
I remember reading an article a while back that explained how AlphaZero works. The driving framework there is MCTS (Monte Carlo tree search), and inside that framework there are two networks: a policy network, which computes prior probabilities for moves, and a value network, which evaluates the board position.
Over the last couple of days I had GPT walk me through the MCTS framework again. I think I mostly understand it now, and I wrote a Nim program following that framework. The results are not particularly good; GPT's diagnosis is that my `simulation` phase still needs work, since a weak simulation policy hurts the efficiency of the search.
The problem to solve: there is a pile of stones, players A and B take turns removing stones, each turn a player takes between 1 and 5 stones, the player who takes the last stone wins, and A moves first. This game has a closed-form solution: if the pile size is a multiple of 6, B wins; otherwise A wins. The reason is that whenever the opponent takes k stones, you can answer by taking 6 - k, so you can always push the pile back onto a multiple of 6.
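To make that concrete, the optimal policy fits in a few lines; this is just a reference sketch, and the helper name `optimal_move` is mine, it is not used by the MCTS program below:
def optimal_move(n: int) -> int:
    # If n % 6 != 0, taking n % 6 stones (between 1 and 5) leaves the opponent
    # on a multiple of 6. If n % 6 == 0 every move loses against perfect play,
    # so take 1 as an arbitrary fallback.
    return n % 6 if n % 6 != 0 else 1

# optimal_move(14) == 2: 14 -> 12, and from then on we can always restore a multiple of 6.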
The program needs roughly these classes:
- Move. Represents an action: how many stones are taken.
- State. Represents the current position: how many stones remain and whose turn it is.
- Node. Represents the search tree; the weights of its children are computed with UCB1, \(UCB1 = \frac{w_i}{n_i} + c \times \sqrt{\frac{\ln (N+1)}{n_i}}\), where \(w_i\) and \(n_i\) are the child's accumulated value and visit count, \(N\) is the parent's visit count, and \(c\) is the exploration weight (a quick numeric example follows this list).
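For intuition about the formula: with \(w_i = 3\), \(n_i = 5\), parent visits \(N = 20\) and \(c = 1.4\), the score is \(3/5 + 1.4\sqrt{\ln 21 / 5} \approx 0.6 + 1.09 \approx 1.69\); a child with the same average value but only one visit scores \(0.6 + 1.4\sqrt{\ln 21} \approx 3.04\), so rarely visited children get a boost toward exploration.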
import random

import numpy as np

MAX_TAKEN = 5      # each player may take between 1 and 5 stones per move
DEBUG_SIM = False  # if True, print the root's children and their weights after the search


class Move:
    def __init__(self, taken):
        self.taken = taken
    def __eq__(self, other: 'Move'):
        return self.taken == other.taken
    def __hash__(self):
        return hash(self.taken)
    def __repr__(self):
        return f'(t={self.taken})'
class State:
    def __init__(self, n, player=0):
        self.player = player  # whose turn it is: 0 = A, 1 = B
        self.n = n            # stones remaining
    def get_moves(self):
        return [Move(x) for x in range(1, min(self.n, MAX_TAKEN) + 1)]
    def is_terminal(self):
        return self.n == 0
    def apply_move(self, move: Move):
        return State(self.n - move.taken, 1 - self.player)
    def get_result(self, me):
        # At a terminal state self.player is the side to move, i.e. the one
        # who did NOT take the last stone, so that side has lost.
        assert self.is_terminal()
        if self.player == me:
            return -100
        else:
            return 1
    def __repr__(self):
        return f'(p={self.player},n={self.n})'
class Node:
    def __init__(self, state: State, parent: 'Node', move: 'Move'):
        self.state: State = state
        self.parent: Node = parent
        self.move: Move = move
        self.children: list[Node] = []
        self.expand_moves: set[Move] = set()
        self.visit = 0
        self.value = 0
    def is_fully_expanded(self) -> bool:
        return len(self.children) == len(self.state.get_moves())
    def weights(self, explore_weight: float) -> list[float]:
        # UCB1 for each child: exploitation term plus exploration term.
        weights = [
            (child.value / (child.visit + 1e-6)) + explore_weight * np.sqrt(
                np.log(self.visit + 1) / (child.visit + 1e-6))
            for child in self.children
        ]
        return weights
    def best_child(self, explore_weight: float) -> 'Node':
        weights = self.weights(explore_weight)
        return self.children[np.argmax(weights)]
    def __repr__(self):
        return f'node(state={self.state},move={self.move},visit={self.visit},value={self.value})'
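Before wiring these into the search, a quick throwaway check of the classes (the concrete numbers here are just an example I picked):
s = State(7)                # 7 stones, player A (0) to move
print(s.get_moves())        # [(t=1), (t=2), (t=3), (t=4), (t=5)]
s = s.apply_move(Move(1))   # A takes 1 -> (p=1,n=6)
s = s.apply_move(Move(5))   # B takes 5 -> (p=0,n=1)
s = s.apply_move(Move(1))   # A takes the last stone -> terminal
print(s.is_terminal(), s.get_result(0), s.get_result(1))  # True 1 -100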
The MCTS routine follows; the rough idea is:
- If a node is not fully expanded, i.e. some child move has never been tried, try an untried child first.
- If it is fully expanded, pick the most promising child and descend into it.
- From that child, play out a simulated game; the simulation policy matters a lot here, and it yields a game result.
- That result is treated as the value contribution of the child, and it is backpropagated to update the statistics of every node on the path.
def mcts(init_state: State, iter_max: int, explore_weight: float, rnd: random.Random):
    init_node = Node(init_state, None, None)

    def simulation(state, rnd: random.Random):
        """Improved simulation phase: play out moves until the game ends."""
        while not state.is_terminal():
            moves = state.get_moves()
            # move = max(moves, key=lambda m: m.taken) if rnd.random() > 0.8 else rnd.choice(moves)
            # move = min(moves, key=lambda m: m.taken)
            move = rnd.choice(moves)
            # move = max(moves, key=lambda m: m.taken)
            state = state.apply_move(move)
        return state

    def iterate(root: Node):
        # Selection: descend via best_child only while the current node is
        # fully expanded and not terminal.
        while root.is_fully_expanded() and not root.state.is_terminal():
            child = root.best_child(explore_weight)
            root = child
        # Expansion: expand one untried move of the node we stopped at.
        assert root
        if not root.state.is_terminal():
            assert not root.is_fully_expanded()
            moves = root.state.get_moves()
            for move in moves:
                if move in root.expand_moves: continue
                new_state = root.state.apply_move(move)
                new_child = Node(new_state, root, move)
                root.children.append(new_child)
                root.expand_moves.add(move)
                root = new_child
                break
        # Simulation: play out from the new node down to a terminal state.
        state = root.state
        if not state.is_terminal():
            state = simulation(state, rnd)
        # Backpropagation: score the playout from the root player's perspective
        # and add it to every node on the path back to the root.
        result = state.get_result(init_state.player)
        while root:
            root.visit += 1
            root.value += result
            root = root.parent

    for _ in range(iter_max):
        iterate(init_node)
    best_move = init_node.best_child(0).move
    if DEBUG_SIM:
        print(init_node)
        weights = init_node.weights(0)
        for c, w in zip(init_node.children, weights):
            print(' - ', c, w)
        print(best_move)
    return best_move
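To check whether the search actually recovers the n % 6 strategy, a minimal self-play driver could look like the sketch below; the pile size of 14, the 2000 iterations, and the exploration weight of 1.4 are arbitrary choices of mine, not tuned values from the program above:
if __name__ == '__main__':
    rnd = random.Random(0)
    state = State(14)  # not a multiple of 6, so player A (0) should win with perfect play
    while not state.is_terminal():
        move = mcts(state, 2000, 1.4, rnd)
        print(f'player {state.player} takes {move.taken} from a pile of {state.n}')
        state = state.apply_move(move)
    # at the terminal state, state.player is the side left with nothing to take, i.e. the loser
    print(f'player {1 - state.player} wins')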