MCTS for Nim
I remember reading an article a while back about how AlphaZero works under the hood. Its main driving framework is MCTS (Monte Carlo tree search), and inside that framework there are two networks: a policy network (which mainly computes prior probabilities for moves) and a value network (which mainly evaluates the board position).
Over the past couple of days I had GPT walk me through the MCTS framework again. I think I roughly get it now, and I wrote a Nim program following that framework. The results are not particularly good; GPT's analysis is that my `simulation` phase still needs work, since a poor simulation policy hurts the efficiency of the search.
The problem to solve: there is a pile of stones, players A and B take turns, each player takes 1-5 stones per turn, the player who takes the last stone wins, and A moves first. This problem actually has a closed-form solution: if the pile size is a multiple of 6, B wins; otherwise A is guaranteed to win, because whatever the opponent takes, the other player can always respond so the pile returns to a multiple of 6.
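The closed-form claim is easy to check mechanically. Below is a throwaway dynamic-programming sketch (illustrative only, not part of the program; `wins` is just a name I made up for this check) that confirms the player to move wins exactly when the pile size is not a multiple of 6.

```python
def wins(n: int) -> bool:
    """Return True iff the player to move wins with n stones left, under optimal play."""
    win = [False] * (n + 1)  # win[0] = False: no stones left, the previous player already won
    for k in range(1, n + 1):
        # A position is winning iff some move leaves the opponent in a losing position.
        win[k] = any(not win[k - t] for t in range(1, min(k, 5) + 1))
    return win[n]

for n in range(1, 30):
    assert wins(n) == (n % 6 != 0)  # A (the player to move) wins exactly when n % 6 != 0
```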
The program needs roughly these classes:
- Move. Represents an action: how many stones to take.
- State. Represents the current state: how many stones remain, and whose turn it is.
- Node. Represents a node in the search tree; child weights are computed with UCB1: \(UCB1 = \frac{w_i}{n_i} + c \times \sqrt{\frac{\ln (N+1)}{n_i}}\).
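To get a feel for the formula, take some made-up numbers: \(w_i = 3\), \(n_i = 5\), parent visits \(N = 20\), \(c = 1.4\). The exploitation term is \(3/5 = 0.6\), the exploration term is \(1.4 \times \sqrt{\ln 21 / 5} \approx 1.09\), so \(UCB1 \approx 1.69\). A rarely visited child gets a large exploration bonus and is tried again even if its average value is low so far.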
```python
from typing import List

import numpy as np

MAX_TAKEN = 5  # each player may take at most 5 stones per turn


class Move:
    """An action: how many stones to take."""

    def __init__(self, taken):
        self.taken = taken

    def __eq__(self, other: 'Move'):
        return self.taken == other.taken

    def __hash__(self):
        return hash(self.taken)

    def __repr__(self):
        return f'(t={self.taken})'


class State:
    """Current position: stones remaining and whose turn it is."""

    def __init__(self, n, player=0):
        self.player = player
        self.n = n

    def get_moves(self):
        return [Move(x) for x in range(1, min(self.n, MAX_TAKEN) + 1)]

    def is_terminal(self):
        return self.n == 0

    def apply_move(self, move: Move):
        return State(self.n - move.taken, 1 - self.player)

    def get_result(self, me):
        # In a terminal state the player to move is the loser:
        # the previous player just took the last stone.
        assert self.is_terminal()
        if self.player == me:
            return -100
        else:
            return 1

    def __repr__(self):
        return f'(p={self.player},n={self.n})'


class Node:
    """A node of the search tree; child weights use UCB1."""

    def __init__(self, state: State, parent: 'Node', move: 'Move'):
        self.state: State = state
        self.parent: Node = parent
        self.move: Move = move
        self.children: list[Node] = []
        self.expand_moves: set[Move] = set()
        self.visit = 0
        self.value = 0

    def is_fully_expanded(self) -> bool:
        return len(self.children) == len(self.state.get_moves())

    def weights(self, explore_weight: float) -> List[float]:
        # UCB1 = exploitation (average value) + exploration bonus.
        weights = [
            (child.value / (child.visit + 1e-6)) +
            explore_weight * np.sqrt(np.log(self.visit + 1) / (child.visit + 1e-6))
            for child in self.children
        ]
        return weights

    def best_child(self, explore_weight: float) -> 'Node':
        weights = self.weights(explore_weight)
        return self.children[np.argmax(weights)]

    def __repr__(self):
        return f'node(state={self.state},move={self.move},visit={self.visit},value={self.value})'
```
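A quick sanity check of how these classes compose (illustrative only; assumes `MAX_TAKEN = 5` as defined above):

```python
s = State(7)               # 7 stones, player 0 to move
print(s.get_moves())       # [(t=1), (t=2), (t=3), (t=4), (t=5)]
s = s.apply_move(Move(5))  # player 0 takes 5
print(s)                   # (p=1,n=2)
s = s.apply_move(Move(2))  # player 1 takes the last 2 stones and wins
print(s.is_terminal())     # True
print(s.get_result(0))     # -100: a loss from player 0's point of view
print(s.get_result(1))     # 1: a win from player 1's point of view
```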
The MCTS routine is below; the general idea is:
- If a node is not fully expanded, i.e. some child move has not been tried yet, try an untried child first.
- If it is fully expanded, pick the most promising child and descend into it.
- From that child, play out a simulated game to get a game result; this simulation policy matters a lot.
- Treat that result as the reward for the child and backpropagate it, updating the statistics along the path.
```python
import random

DEBUG_SIM = True  # print per-child statistics at the root after the search


def mcts(init_state: State, iter_max: int, explore_weight: float, rnd: random.Random):
    init_node = Node(init_state, None, None)

    def simulation(state, rnd: random.Random):
        """Simulation (rollout) phase: play moves until the game ends."""
        while not state.is_terminal():
            moves = state.get_moves()
            # Rollout policies tried so far; uniform random is the current one.
            # move = max(moves, key=lambda m: m.taken) if rnd.random() > 0.8 else rnd.choice(moves)
            # move = min(moves, key=lambda m: m.taken)
            move = rnd.choice(moves)
            # move = max(moves, key=lambda m: m.taken)
            state = state.apply_move(move)
        return state

    def iterate(root: Node):
        # Selection: descend via best_child only while the node is
        # fully expanded and not terminal.
        while root.is_fully_expanded() and not root.state.is_terminal():
            child = root.best_child(explore_weight)
            root = child
        # Expansion: add one untried child of the current node.
        assert root
        if not root.state.is_terminal():
            assert not root.is_fully_expanded()
            moves = root.state.get_moves()
            for move in moves:
                if move in root.expand_moves:
                    continue
                new_state = root.state.apply_move(move)
                new_child = Node(new_state, root, move)
                root.children.append(new_child)
                root.expand_moves.add(move)
                root = new_child
                break
        # Simulation.
        state = root.state
        if not state.is_terminal():
            state = simulation(state, rnd)
        # Backpropagation: the result is scored from the root player's
        # point of view and added unchanged to every node on the path.
        result = state.get_result(init_state.player)
        while root:
            root.visit += 1
            root.value += result
            root = root.parent

    for _ in range(iter_max):
        iterate(init_node)

    best_move = init_node.best_child(0).move
    if DEBUG_SIM:
        print(init_node)
        weights = init_node.weights(0)
        for c, w in zip(init_node.children, weights):
            print(' - ', c, w)
        print(best_move)
    return best_move
```
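A minimal driver sketch for trying it out (illustrative; `optimal_move` and the `iter_max=2000`, `explore_weight=1.4` values are just example choices, not tuned): MCTS plays as A against the closed-form strategy "take `n % 6` when that is nonzero, otherwise take 1".

```python
def optimal_move(n: int) -> int:
    # Closed-form strategy: realign the pile to a multiple of 6 whenever possible.
    return n % 6 if n % 6 != 0 else 1

rnd = random.Random(0)
state = State(13)  # 13 is not a multiple of 6, so A can win with perfect play
while not state.is_terminal():
    if state.player == 0:
        move = mcts(state, iter_max=2000, explore_weight=1.4, rnd=rnd)
    else:
        move = Move(optimal_move(state.n))
    state = state.apply_move(move)
print('winner:', 1 - state.player)  # the player who took the last stone
```

The commented-out lines in `simulation` are the rollout policies I tried. One obvious direction, given the closed form, would be a rollout that itself takes `n % 6` when that is nonzero, which should make the playout results much sharper; this matches GPT's point that the simulation phase is what needs improving.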