A minimal MCTS implementation
https://int8.io/monte-carlo-tree-search-beginners-guide/
https://github.com/int8/monte-carlo-tree-search
http://tim.hibal.org/blog/alpha-zero-how-and-why-it-works/
Node selection in MCTS is really a multi-armed bandit (MAB) problem. One algorithm for it is UCB (upper confidence bound); another simple one is Thompson sampling.
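A minimal sketch of both strategies on a toy three-armed Bernoulli bandit; the payout probabilities and helper names below are made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
true_probs = [0.3, 0.5, 0.7]  # hidden payout rate of each arm (made up)


def ucb_pick(wins, pulls, t, c=1.4):
    # UCB1: empirical mean plus an exploration bonus that shrinks
    # as an arm accumulates pulls
    scores = wins / pulls + c * np.sqrt(2 * np.log(t) / pulls)
    return int(np.argmax(scores))


def thompson_pick(wins, pulls):
    # Thompson sampling: sample a plausible payout rate for each arm
    # from its Beta posterior, then play the arm with the highest draw
    samples = rng.beta(wins + 1, pulls - wins + 1)
    return int(np.argmax(samples))


wins = np.zeros(3)
pulls = np.ones(3)  # pretend one losing pull per arm to avoid division by zero
for t in range(1, 1000):
    arm = ucb_pick(wins, pulls, t)  # or: thompson_pick(wins, pulls)
    reward = float(rng.random() < true_probs[arm])
    wins[arm] += reward
    pulls[arm] += 1

print(pulls)  # pulls should concentrate on the best arm (index 2)
```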
The search code is `search.py`; I have added some comments to it:
```python
class MonteCarloTreeSearch(object):

    def __init__(self, node):
        """
        MonteCarloTreeSearchNode
        Parameters
        ----------
        node : mctspy.tree.nodes.MonteCarloTreeSearchNode
        """
        self.root = node

    def best_action(self, simulations_number):
        """
        Parameters
        ----------
        simulations_number : int
            number of simulations performed to get the best action

        Returns
        -------
        the root's best child node (exploitation only)
        """
        for _ in range(0, simulations_number):
            # pick a node that is not yet fully expanded
            v = self._tree_policy()
            # play out from that node with a random policy to get a reward
            reward = v.rollout()
            # propagate the reward back up through all ancestor nodes
            v.backpropagate(reward)
        # to select best child go for exploitation only
        return self.root.best_child(c_param=0.)

    def _tree_policy(self):
        """
        selects node to run rollout/playout for

        Returns
        -------
        """
        # walk down from the root looking for a node that is not fully
        # expanded; whenever a node is fully expanded, descend via best_child
        current_node = self.root
        while not current_node.is_terminal_node():
            if not current_node.is_fully_expanded():
                return current_node.expand()
            else:
                current_node = current_node.best_child()
        return current_node
```
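For context, the repo's README runs this search on tic-tac-toe roughly as follows (module paths as published in the mctspy package; treat the exact signatures as approximate):

```python
import numpy as np

from mctspy.tree.nodes import TwoPlayersGameMonteCarloTreeSearchNode
from mctspy.tree.search import MonteCarloTreeSearch
from mctspy.games.examples.tictactoe import TicTacToeGameState

state = np.zeros((3, 3))  # empty 3x3 board
initial_board_state = TicTacToeGameState(state=state, next_to_move=1)

root = TwoPlayersGameMonteCarloTreeSearchNode(state=initial_board_state, parent=None)
mcts = MonteCarloTreeSearch(root)
best_node = mcts.best_action(10000)  # run 10000 simulations
```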
The tree-node code is `node.py`; I have added comments there as well. My understanding is that the AlphaZero framework is still MCTS, but with optimizations to two functions (a sketch of both follows the code below):

- `best_child`. Each child node has a different prior probability of being played, so that prior can be factored into the UCB score that estimates a node's future value.
- `rollout`. Instead of playing the game out to the end, the position can be evaluated directly once a certain depth is reached.
```python
import numpy as np
from abc import ABC
from collections import defaultdict


class MonteCarloTreeSearchNode(ABC):

    def __init__(self, state, parent=None):
        """
        Parameters
        ----------
        state : mctspy.games.common.TwoPlayersAbstractGameState
        parent : MonteCarloTreeSearchNode
        """
        self.state = state
        self.parent = parent  # parent node, used during backpropagation
        self.children = []    # child nodes, used when deciding which node to descend into

    def is_fully_expanded(self):
        # fully expanded once no untried actions remain
        return len(self.untried_actions) == 0

    def best_child(self, c_param=1.4):
        # pick the best child by the UCB1 score
        choices_weights = [
            (c.q / c.n) + c_param * np.sqrt((2 * np.log(self.n) / c.n))
            for c in self.children
        ]
        return self.children[np.argmax(choices_weights)]

    def rollout_policy(self, possible_moves):
        # rollout policy: choose a move uniformly at random
        return possible_moves[np.random.randint(len(possible_moves))]


class TwoPlayersGameMonteCarloTreeSearchNode(MonteCarloTreeSearchNode):

    def __init__(self, state, parent=None):
        super().__init__(state, parent)
        self._number_of_visits = 0.
        self._results = defaultdict(int)
        self._untried_actions = None

    @property
    def untried_actions(self):
        if self._untried_actions is None:
            self._untried_actions = self.state.get_legal_actions()
        return self._untried_actions

    @property
    def q(self):
        wins = self._results[self.parent.state.next_to_move]
        loses = self._results[-1 * self.parent.state.next_to_move]
        return wins - loses

    @property
    def n(self):
        return self._number_of_visits

    def expand(self):
        action = self.untried_actions.pop()
        next_state = self.state.move(action)
        child_node = TwoPlayersGameMonteCarloTreeSearchNode(
            next_state, parent=self
        )
        self.children.append(child_node)
        return child_node

    def is_terminal_node(self):
        return self.state.is_game_over()

    def rollout(self):
        # play out until the game ends; in practice an evaluation function
        # could cut the playout off at a fixed depth instead
        current_rollout_state = self.state
        while not current_rollout_state.is_game_over():
            possible_moves = current_rollout_state.get_legal_actions()
            action = self.rollout_policy(possible_moves)
            current_rollout_state = current_rollout_state.move(action)
        return current_rollout_state.game_result

    def backpropagate(self, result):
        # backpropagation: update this node's visit count and win tally,
        # then recurse up to the parent
        self._number_of_visits += 1.
        self._results[result] += 1.
        if self.parent:
            self.parent.backpropagate(result)
```
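Here is a sketch of the two AlphaZero-style modifications mentioned above, layered on the node class just shown. Everything new in it is an assumption for illustration: `prior`, `policy_prior`, and `value_fn` are hypothetical stand-ins for policy/value-network outputs, not part of mctspy or of AlphaZero's actual code.

```python
import numpy as np


def policy_prior(state, action):
    # hypothetical stand-in for a policy network: uniform over legal moves
    return 1.0 / max(1, len(state.get_legal_actions()))


def value_fn(state):
    # hypothetical stand-in for a value network: returns a score in [-1, 1];
    # a real system would run a neural-network forward pass here
    return 0.0


class PUCTNode(TwoPlayersGameMonteCarloTreeSearchNode):

    def __init__(self, state, parent=None, prior=1.0):
        super().__init__(state, parent)
        self.prior = prior  # P(s, a): the policy network's prior for this move

    def best_child(self, c_param=1.4):
        # PUCT instead of UCB1: the exploration bonus is scaled by the
        # child's prior, so moves the policy network favours are tried first
        scores = [
            (c.q / c.n if c.n > 0 else 0.)
            + c_param * c.prior * np.sqrt(self.n) / (1 + c.n)
            for c in self.children
        ]
        return self.children[np.argmax(scores)]

    def expand(self):
        # same as before, but each new child carries a policy prior
        action = self.untried_actions.pop()
        next_state = self.state.move(action)
        child = PUCTNode(next_state, parent=self,
                         prior=policy_prior(self.state, action))
        self.children.append(child)
        return child

    def rollout(self):
        # no random playout to a terminal state: evaluate the position
        # directly with the value network instead
        return value_fn(self.state)
```

Note that once `rollout` returns a continuous value rather than a discrete game result, `backpropagate` and `q` would also need to accumulate values instead of tallying wins and losses; that part is omitted to keep the sketch short.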