-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathTreeNode.py
70 lines (60 loc) · 3.13 KB
/
TreeNode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# -*- coding: utf-8 -*-
import numpy as np
'''
Node of MCTS Searching Tree
'''
class TreeNode(object):
    """A node in the MCTS search tree.

    Each node stores its parent, its children (keyed by action), a visit
    count, a running-average value Q, a prior probability P, and the
    exploration bonus u computed on the most recent call to ``_get_value``.
    """

    def __init__(self, parent, prior_p):
        self._parent = parent  # parent node (None for the root)
        self._children = {}  # child nodes: a map from action to TreeNode
        self._n_visits = 0  # visit count
        self._Q = 0  # Q value: running average of backed-up leaf values
        self._u = 0  # exploration bonus, based on visit counts and prior probability
        self._P = prior_p  # prior probability supplied by the caller (e.g. a policy network)

    def expand(self, action_priors):
        """Expand the tree by creating new children.

        action_priors -- an iterable of (action, prior_probability) tuples,
            e.g. the output of a policy function. Actions that already have a
            child node are left untouched.
        """
        for action, prob in action_priors:
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)

    def select(self, c_puct=5.0, epsilon=0.0, alpha=0.3):
        """Select the child that maximises Q + u, optionally with Dirichlet noise.

        When epsilon > 0, the prior used for each child is
        (1 - epsilon) * P + epsilon * eta, where eta is ONE joint sample from a
        symmetric Dirichlet(alpha) distribution over all children, so the noise
        components across the children sum to 1.

        c_puct -- exploration constant in (0, inf).
        epsilon -- fraction of Dirichlet noise mixed into the prior (0 = no noise).
        alpha -- concentration parameter of the Dirichlet noise.
        Returns:
            A tuple of (action, next_node). Ties go to the first child in
            insertion order.
        Raises:
            ValueError if this node has no children.
        """
        children = list(self._children.items())
        if epsilon > 0 and children:
            # Draw the noise jointly across siblings; a one-component
            # Dirichlet sample would always be exactly 1.0.
            noise = np.random.dirichlet([alpha] * len(children))
        else:
            noise = [0] * len(children)
        best_child, _ = max(
            zip(children, noise),
            key=lambda pair: pair[0][1]._get_value(c_puct, epsilon, alpha, pair[1]),
        )
        return best_child

    def _get_value(self, c_puct, epsilon=0, alpha=0.3, noise=None):
        """Compute and return Q + u for this node, caching u in ``self._u``.

        u = c_puct * ((1 - epsilon) * P + epsilon * noise)
            * sqrt(parent visits) / (1 + own visits)

        c_puct -- a number in (0, inf) controlling the relative impact of the
            value Q and the prior probability P on this node's score.
        epsilon -- the fraction of Dirichlet noise mixed into the prior;
            1 - epsilon is the fraction of the prior probability P.
        alpha -- the concentration parameter of the Dirichlet noise.
        noise -- this node's component of a joint Dirichlet sample over its
            siblings (as supplied by ``select``). If None and epsilon > 0, a
            fresh joint sample over the siblings is drawn and one component
            is used.
        Requires a parent node (it reads the parent's visit count).
        """
        if noise is None:
            noise = 0
            if epsilon > 0:
                # Standalone call: one component of a joint draw over the
                # siblings has the correct marginal distribution.
                n_siblings = max(len(self._parent._children), 1)
                noise = np.random.dirichlet([alpha] * n_siblings)[0]
        self._u = (c_puct * ((1 - epsilon) * self._P + epsilon * noise)
                   * np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
        return self._Q + self._u

    def backup(self, leaf_value):
        """Like a call to update(), but applied recursively to all ancestors.

        leaf_value -- the value credited to this node. The parent is updated
            with -leaf_value, so the sign flips at every level up the tree.
        """
        # If this is not the root, update the parent (and its ancestors) first.
        if self._parent is not None:
            self._parent.backup(-leaf_value)
        self._n_visits += 1
        # Incremental running average (AlphaGo Zero, Methods section):
        # W = W_old + leaf_value; W_old = (n-1) * Q_old; Q = W / n
        #   = ((n-1) * Q_old + leaf_value) / n
        #   = Q_old + (leaf_value - Q_old) / n
        self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits

    def is_leaf(self):
        """Return True if no child nodes have been expanded below this node."""
        return not self._children

    def is_root(self):
        """Return True if this node has no parent."""
        return self._parent is None