11
11
12
12
13
13
@jit
14
- def _opt_select (nodes , c_puct ):
15
- """ Optimized version of the selection """
14
+ def _opt_select (nodes , c_puct = C_PUCT ):
15
+ """ Optimized version of the selection based of the PUCT formula """
16
16
17
17
total_count = 0
18
18
for i in range (nodes .shape [0 ]):
@@ -38,16 +38,6 @@ def dirichlet_noise(probas):
38
38
return new_probas
39
39
40
40
41
- def _select (nodes , c_puct = C_PUCT ):
42
- """
43
- Select the move that maximises the mean value of the next state +
44
- the result of the PUCT function
45
- """
46
-
47
- return nodes [_opt_select (np .array ([[node .q , node .n , node .p ] \
48
- for node in nodes ]), c_puct )]
49
-
50
-
51
41
class Node :
52
42
53
43
def __init__ (self , parent = None , proba = None , move = None ):
@@ -105,19 +95,23 @@ def run(self):
105
95
106
96
## Wait for the eval_queue to be filled by new positions to evaluate
107
97
self .condition_search .acquire ()
108
- while len (self .eval_queue ) < BATCH_SIZE_EVAL :
98
+ while len (self .eval_queue ) < BATCH_SIZE_EVAL and \
99
+ len (self .eval_queue ) != MCTS_PARALLEL - BATCH_SIZE_EVAL or \
100
+ len (self .eval_queue ) == 0 :
109
101
self .condition_search .wait ()
110
102
self .condition_search .release ()
111
103
112
104
self .condition_eval .acquire ()
105
+ keys = list (self .eval_queue .keys ())
106
+
113
107
## Predict the feature_maps, policy and value
114
- states = torch .tensor (np .array (list (self .eval_queue .values ()))[0 :BATCH_SIZE_EVAL ],
115
- dtype = torch .float , device = DEVICE )
108
+ states = torch .tensor (np .array (list (self .eval_queue .values ()))[0 :len ( keys ) ],
109
+ dtype = torch .float , device = DEVICE )
116
110
v , probas = self .player .predict (states )
117
111
118
112
## Replace the state with the result in the eval_queue
119
113
## and notify all the threads that the result are available
120
- for idx , i in zip (list ( self . eval_queue . keys ()) , range (BATCH_SIZE_EVAL )):
114
+ for idx , i in zip (keys , range (len ( keys ) )):
121
115
del self .eval_queue [idx ]
122
116
self .result_queue [idx ] = (probas [i ].cpu ().data .numpy (), v [i ])
123
117
@@ -150,7 +144,10 @@ def run(self):
150
144
151
145
## Traverse the tree until leaf
152
146
while not current_node .is_leaf () and not done :
153
- current_node = _select (current_node .childrens )
147
+ ## Select the action that maximizes the PUCT algorithm
148
+ current_node = current_node .childrens [_opt_select ( \
149
+ np .array ([[node .q , node .n , node .p ] \
150
+ for node in current_node .childrens ]))]
154
151
155
152
## Virtual loss since multithreading
156
153
self .lock .acquire ()
@@ -161,7 +158,8 @@ def run(self):
161
158
162
159
if not done :
163
160
164
- ## Add current leaf state to the evaluation queue
161
+ ## Add current leaf state with random dihedral transformation
162
+ ## to the evaluation queue
165
163
self .condition_search .acquire ()
166
164
self .eval_queue [self .thread_id ] = sample_rotation (state , num = 1 )
167
165
self .condition_search .notify ()
@@ -243,7 +241,6 @@ def search(self, current_game, player, competitive=False):
243
241
Search the best moves through the game tree with
244
242
the policy and value network to update node statistics
245
243
"""
246
- threads = []
247
244
248
245
## Locking for thread synchronization
249
246
condition_eval = threading .Condition ()
@@ -256,6 +253,7 @@ def search(self, current_game, player, competitive=False):
256
253
evaluator = EvaluatorThread (player , eval_queue , result_queue , condition_search , condition_eval )
257
254
evaluator .start ()
258
255
256
+ threads = []
259
257
## Do exactly the required number of simulation per thread
260
258
for sim in range (MCTS_SIM // MCTS_PARALLEL ):
261
259
for idx in range (MCTS_PARALLEL ):
0 commit comments