A Minimal MCTS Implementation
https://int8.io/monte-carlo-tree-search-beginners-guide/
https://github.com/int8/monte-carlo-tree-search
http://tim.hibal.org/blog/alpha-zero-how-and-why-it-works/
When MCTS selects which node to descend into, it is really solving a multi-armed bandit (MAB) problem. One algorithm for this is UCB (upper confidence bound), which adds an exploration bonus to each arm's mean reward; another simple algorithm is Thompson sampling.
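As a minimal standalone sketch (not part of the repo; the arm probabilities are made up), here is how the two strategies look on three Bernoulli bandit arms. UCB1 plays the arm maximizing mean reward plus `c * sqrt(2 * ln(N) / n_i)`, while Thompson sampling draws from a per-arm Beta posterior and plays the best draw:

```python
import numpy as np

rng = np.random.default_rng(0)
true_p = [0.3, 0.5, 0.7]  # hypothetical win rates, unknown to the agent

def pull(arm):
    return rng.random() < true_p[arm]

# UCB1: play each arm once, then maximize mean + sqrt(2 ln N / n_i)
wins, pulls = np.zeros(3), np.zeros(3)
for t in range(1, 1001):
    if t <= 3:
        arm = t - 1  # initial round: try every arm once
    else:
        arm = int(np.argmax(wins / pulls + np.sqrt(2 * np.log(t) / pulls)))
    wins[arm] += pull(arm)
    pulls[arm] += 1

# Thompson sampling: Beta(1, 1) priors, play the arm with the best posterior draw
a, b = np.ones(3), np.ones(3)
for _ in range(1000):
    arm = int(np.argmax(rng.beta(a, b)))
    if pull(arm):
        a[arm] += 1  # one more observed win for this arm
    else:
        b[arm] += 1  # one more observed loss

print("UCB1 empirical means:", wins / pulls)
print("Thompson posterior means:", a / (a + b))
```

Both strategies concentrate their pulls on the best arm over time; MCTS applies the same idea at every interior node of the search tree.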
The search code, `search.py`, to which I have added some comments:
```python
class MonteCarloTreeSearch(object):

    def __init__(self, node):
        """
        Parameters
        ----------
        node : mctspy.tree.nodes.MonteCarloTreeSearchNode
            root of the search tree
        """
        self.root = node

    def best_action(self, simulations_number):
        """
        Parameters
        ----------
        simulations_number : int
            number of simulations performed to get the best action

        Returns
        -------
        best child node of the root
        """
        for _ in range(0, simulations_number):
            # select a node that is not yet fully expanded
            v = self._tree_policy()
            # play out from that node with a random policy to get a reward
            reward = v.rollout()
            # propagate the reward back through all ancestor nodes
            v.backpropagate(reward)
        # to select best child go for exploitation only (c_param=0)
        return self.root.best_child(c_param=0.)

    def _tree_policy(self):
        """
        Selects a node to run the rollout/playout from.

        Returns
        -------
        a node that is not fully expanded, or a terminal node
        """
        # walk down until we find a node that is not fully expanded;
        # whenever the current node is fully expanded, descend via best_child
        current_node = self.root
        while not current_node.is_terminal_node():
            if not current_node.is_fully_expanded():
                return current_node.expand()
            else:
                current_node = current_node.best_child()
        return current_node
```
My understanding is that the AlphaZero framework is still MCTS, but it optimizes two of the functions found in the tree-node code below:
- `best_child`: each child move carries a different prior probability of being played, and that probability can be folded into the UCB term when estimating a node's future value.
- `rollout`: the playout does not have to run to a terminal state; the position can be evaluated directly at a fixed depth instead.

Both changes are sketched right after this list.
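Here is a hypothetical sketch of the two modifications. The `prior` attribute and the `value_fn` callable are my own stand-ins for AlphaZero's policy/value network; neither exists in the repo's code.

```python
import numpy as np

def puct_child(node, c_param=1.4):
    # AlphaZero-style PUCT: scale the exploration term by the prior
    # probability (assumed `prior` attribute) assigned to each child's move
    scores = [
        (c.q / c.n if c.n > 0 else 0.0)
        + c_param * c.prior * np.sqrt(node.n) / (1 + c.n)
        for c in node.children
    ]
    return node.children[int(np.argmax(scores))]

def truncated_rollout(node, value_fn, max_depth=20):
    # stop the random playout after max_depth plies and let value_fn
    # (e.g. a value network) score the position instead of playing to the end
    state, depth = node.state, 0
    while not state.is_game_over() and depth < max_depth:
        moves = state.get_legal_actions()
        state = state.move(moves[np.random.randint(len(moves))])
        depth += 1
    # a game that actually finished still returns the true result
    return state.game_result if state.is_game_over() else value_fn(state)
```

With that in mind, the tree-node code, `node.py`, again with my comments added: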
```python
from abc import ABC
from collections import defaultdict

import numpy as np


class MonteCarloTreeSearchNode(ABC):

    def __init__(self, state, parent=None):
        """
        Parameters
        ----------
        state : mctspy.games.common.TwoPlayersAbstractGameState
        parent : MonteCarloTreeSearchNode
        """
        self.state = state
        self.parent = parent  # parent node, used for backpropagation
        self.children = []    # child nodes, consulted when choosing where to descend or expand

    def is_fully_expanded(self):  # have all legal actions been tried?
        return len(self.untried_actions) == 0

    def best_child(self, c_param=1.4):  # pick the best child by the UCB1 score
        choices_weights = [
            (c.q / c.n) + c_param * np.sqrt((2 * np.log(self.n) / c.n))
            for c in self.children
        ]
        return self.children[np.argmax(choices_weights)]

    def rollout_policy(self, possible_moves):  # rollout policy: pick a uniformly random move
        return possible_moves[np.random.randint(len(possible_moves))]


class TwoPlayersGameMonteCarloTreeSearchNode(MonteCarloTreeSearchNode):

    def __init__(self, state, parent=None):
        super().__init__(state, parent)
        self._number_of_visits = 0.
        self._results = defaultdict(int)
        self._untried_actions = None

    @property
    def untried_actions(self):
        if self._untried_actions is None:
            self._untried_actions = self.state.get_legal_actions()
        return self._untried_actions

    @property
    def q(self):
        # wins minus losses, scored from the perspective of the player
        # to move at the parent (the player who chose this node)
        wins = self._results[self.parent.state.next_to_move]
        loses = self._results[-1 * self.parent.state.next_to_move]
        return wins - loses

    @property
    def n(self):
        return self._number_of_visits

    def expand(self):
        action = self.untried_actions.pop()
        next_state = self.state.move(action)
        child_node = TwoPlayersGameMonteCarloTreeSearchNode(
            next_state, parent=self
        )
        self.children.append(child_node)
        return child_node

    def is_terminal_node(self):
        return self.state.is_game_over()

    def rollout(self):
        # play random moves until the game ends; in practice an evaluation
        # function could cut the playout off at a fixed depth instead
        current_rollout_state = self.state
        while not current_rollout_state.is_game_over():
            possible_moves = current_rollout_state.get_legal_actions()
            action = self.rollout_policy(possible_moves)
            current_rollout_state = current_rollout_state.move(action)
        return current_rollout_state.game_result

    def backpropagate(self, result):
        # backpropagation: bump this node's visit count and win/loss tally,
        # then recurse up through the parents to the root
        self._number_of_visits += 1.
        self._results[result] += 1.
        if self.parent:
            self.parent.backpropagate(result)
```
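Putting it together, a usage sketch along the lines of the repo's README, using its bundled tic-tac-toe game (the module paths are taken from the linked repo and may differ across versions, so treat them as assumptions):

```python
import numpy as np

from mctspy.tree.nodes import TwoPlayersGameMonteCarloTreeSearchNode
from mctspy.tree.search import MonteCarloTreeSearch
from mctspy.games.examples.tictactoe import TicTacToeGameState

# empty 3x3 board, player 1 to move
board_state = TicTacToeGameState(state=np.zeros((3, 3)), next_to_move=1)

root = TwoPlayersGameMonteCarloTreeSearchNode(state=board_state)
mcts = MonteCarloTreeSearch(root)
best_node = mcts.best_action(10000)  # run 10000 simulations, then pick greedily
```

Note that `best_action` passes `c_param=0.` to `best_child` at the end: exploration only matters while building the tree, so the final move is chosen purely by exploitation.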