MCTS for Nim

我记得好久之前看过一篇文章,大概是说AlphaZero里面的实现原理,里面主要的驱动框架就是MCTS(monte carlo tree search). 然后再整个框架里面有两个network: policy network(策略网络,主要是计算move的先验概率)和value network(价值网络,主要看棋盘).

这两天我让gpt给我重新讲解了一下MCTS的框架,大概是看懂了,并且按照这个框架写了一个nim的程序。可以说效果不是特别好,按照gpt给我的分析是,我的模拟 `simulation` 部分还有待改进。较差的simulation会影响到搜索的效率。

需要解决的问题是:有一堆石子,AB分别取,每个人最多获取1-5个,最后拿走的人获胜,A开始先取。这个问题其实是有封闭解的: 只要是6的倍数,那么就是B获胜,否则就是一定是A获胜。因为不管对方怎么拿,都可以对齐到6的倍数上。

整个程序框架大概需要定义几个类:

class Move:
    def __init__(self, taken):
        self.taken = taken

    def __eq__(self, other: 'Move'):
        return self.taken == other.taken

    def __hash__(self):
        return hash(self.taken)

    def __repr__(self):
        return f'(t={self.taken})'


class State:
    def __init__(self, n, player=0):
        self.player = player
        self.n = n

    def get_moves(self):
        return [Move(x) for x in range(1, min(self.n, MAX_TAKEN) + 1)]

    def is_terminal(self):
        return self.n == 0

    def apply_move(self, move: Move):
        return State(self.n - move.taken, 1 - self.player)

    def get_result(self, me):
        assert self.is_terminal()
        if self.player == me:
            return -100
        else:
            return 1

    def __repr__(self):
        return f'(p={self.player},n={self.n})'


class Node:
    def __init__(self, state: State, parent: 'Node', move: 'Move'):
        self.state: State = state
        self.parent: Node = parent
        self.move: Move = move
        self.children: list[Node] = []
        self.expand_moves: set[Move] = set()
        self.visit = 0
        self.value = 0

    def is_fully_expanded(self) -> bool:
        return len(self.children) == len(self.state.get_moves())

    def weights(self, explore_weight: float) -> List[float]:
        weights = [
            (child.value / (child.visit + 1e-6)) + explore_weight * np.sqrt(
                np.log(self.visit + 1) / (child.visit + 1e-6))
            for child in self.children
        ]
        return weights

    def best_child(self, explore_weight: float) -> 'Node':
        weights = self.weights(explore_weight)
        return self.children[np.argmax(weights)]

    def __repr__(self):
        return f'node(state={self.state},move={self.move},visit={self.visit},value={self.value})'

然后MCTS程序如下,大致思路是:

def mcts(init_state: State, iter_max: int, explore_weight: float, rnd: random.Random):
    init_node = Node(init_state, None, None)

    def simulation(state, rnd: random.Random):
        """改进的 Simulation 阶段"""
        while not state.is_terminal():
            moves = state.get_moves()
            # move = max(moves, key=lambda m: m.taken) if rnd.random() > 0.8 else rnd.choice(moves)
            # move = min(moves, key = lambda m: m.taken)
            move = rnd.choice(moves)
            # move = max(moves, key=lambda m: m.taken)
            state = state.apply_move(move)
        return state

    def iterate(root: Node):
        # selection.
        # 如果当前节点不是terminal并且是fully expanded的话,才进行best child筛选
        while root.is_fully_expanded() and not root.state.is_terminal():
            child = root.best_child(explore_weight)
            root = child

        # expansion. 从当前没有fully expanded的节点去扩展一个节点出来
        assert root
        if not root.state.is_terminal():
            assert not root.is_fully_expanded()
            moves = root.state.get_moves()
            for move in moves:
                if move in root.expand_moves: continue
                new_state = root.state.apply_move(move)
                new_child = Node(new_state, root, move)
                root.children.append(new_child)
                root.expand_moves.add(move)
                root = new_child
                break

        # simulation.
        state = root.state
        if not state.is_terminal():
            state = simulation(state, rnd)

        # backprop
        result = state.get_result(init_state.player)
        while root:
            root.visit += 1
            root.value += result
            root = root.parent

    for _ in range(iter_max):
        iterate(init_node)

    best_move = init_node.best_child(0).move

    if DEBUG_SIM:
        print(init_node)
        weights = init_node.weights(0)
        for c, w in zip(init_node.children, weights):
            print(' - ', c, w)
        print(best_move)
    return best_move