MCTS精简版代码

https://int8.io/monte-carlo-tree-search-beginners-guide/

https://github.com/int8/monte-carlo-tree-search

http://tim.hibal.org/blog/alpha-zero-how-and-why-it-works/

MCTS在选择节点的时候实际上是MAB问题,一种算法是UCB(置信区间上界:upper confidence bound)算法,另外一种简单的算法是汤普森采样算法。

搜索代码 `search.py`,我在上面增加了部分注释

class MonteCarloTreeSearch(object):

    def __init__(self, node):
        """
        MonteCarloTreeSearchNode
        Parameters
        ----------
        node : mctspy.tree.nodes.MonteCarloTreeSearchNode
        """
        self.root = node

    def best_action(self, simulations_number):
        """

        Parameters
        ----------
        simulations_number : int
            number of simulations performed to get the best action

        Returns
        -------

        """
        for _ in range(0, simulations_number):
            # 选择一个没有完全展开的节点
            v = self._tree_policy()
            # 以当前节点以随机的逻辑展开,得到一个reward
            reward = v.rollout()
            # 将这个reward反向传播到所有父亲节点上
            v.backpropagate(reward)
            # to select best child go for exploitation only
        return self.root.best_child(c_param=0.)

    def _tree_policy(self):
        """
        selects node to run rollout/playout for

        Returns
        -------

        """
        # 找到一个没有完全展开的节点
        # 如果该节点完全展开的话,那么选择best_child往下
        current_node = self.root
        while not current_node.is_terminal_node():
            if not current_node.is_fully_expanded():
                return current_node.expand()
            else:
                current_node = current_node.best_child()
        return current_node

树节点代码 `node.py`, 我在上面也增加了注释. 我理解alpha-zero框架是MCTS,但是在两个函数上做了优化:

class MonteCarloTreeSearchNode(ABC):

    def __init__(self, state, parent=None):
        """
        Parameters
        ----------
        state : mctspy.games.common.TwoPlayersAbstractGameState
        parent : MonteCarloTreeSearchNode
        """
        self.state = state
        self.parent = parent # 父亲节点,用于反向传播
        self.children = [] # 叶子节点,用于评估展开是选择哪个节点

    def is_fully_expanded(self): # 是否完全扩展
        return len(self.untried_actions) == 0

    def best_child(self, c_param=1.4): # UCB算法选择最优子节点
        choices_weights = [
            (c.q / c.n) + c_param * np.sqrt((2 * np.log(self.n) / c.n))
            for c in self.children
        ]
        return self.children[np.argmax(choices_weights)]

    def rollout_policy(self, possible_moves): # 展开策略:随机选择一个子节点展开
        return possible_moves[np.random.randint(len(possible_moves))]


class TwoPlayersGameMonteCarloTreeSearchNode(MonteCarloTreeSearchNode):

    def __init__(self, state, parent=None):
        super().__init__(state, parent)
        self._number_of_visits = 0.
        self._results = defaultdict(int)
        self._untried_actions = None

    @property
    def untried_actions(self):
        if self._untried_actions is None:
            self._untried_actions = self.state.get_legal_actions()
        return self._untried_actions

    @property
    def q(self):
        wins = self._results[self.parent.state.next_to_move]
        loses = self._results[-1 * self.parent.state.next_to_move]
        return wins - loses

    @property
    def n(self):
        return self._number_of_visits

    def expand(self):
        action = self.untried_actions.pop()
        next_state = self.state.move(action)
        child_node = TwoPlayersGameMonteCarloTreeSearchNode(
            next_state, parent=self
        )
        self.children.append(child_node)
        return child_node

    def is_terminal_node(self):
        return self.state.is_game_over()

    def rollout(self): # 不断展开直到终局,现实中可以使用评估函数在一定深度上cut-off
        current_rollout_state = self.state
        while not current_rollout_state.is_game_over():
            possible_moves = current_rollout_state.get_legal_actions()
            action = self.rollout_policy(possible_moves)
            current_rollout_state = current_rollout_state.move(action)
        return current_rollout_state.game_result

    def backpropagate(self, result): # 反向传播,改节点访问多少次,该节点的胜率如何
        self._number_of_visits += 1.
        self._results[result] += 1.
        if self.parent:
            self.parent.backpropagate(result)