8
8
from lib .utils import _prepare_state , sample_rotation
9
9
10
10
11
- class Node ():
12
-
11
+ class Node :
13
12
def __init__ (self , parent = None , proba = None , move = None ):
14
13
"""
15
14
p : probability of reaching that node, given by the policy net
@@ -29,19 +28,14 @@ def __init__(self, parent=None, proba=None, move=None):
29
28
def update (self , v ):
30
29
""" Update the node statistics after a playout """
31
30
32
- self .w = self .w + float (v )
33
- if self .n > 0 :
34
- self .q = self .w / self .n
35
- else :
36
- self .q = 0
31
+ self .w = self .w + v
32
+ self .q = self .w / self .n if self .n > 0 else 0
37
33
38
34
39
35
def is_leaf (self ):
40
36
""" Check whether node is a leaf or not """
41
37
42
- if self .childrens and len (self .childrens ) > 0 :
43
- return False
44
- return True
38
+ return len (self .childrens ) == 0
45
39
46
40
47
41
def expand (self , probas ):
@@ -86,55 +80,63 @@ def _opt_select(nodes, c_puct):
86
80
return equals [0 ]
87
81
88
82
89
-
90
83
class EvaluatorThread (threading .Thread ):
91
- def __init__ (self , player , eval_queue , condition_search , condition_eval ):
84
+ def __init__ (self , player , eval_queue , result_queue , condition_search , condition_eval , condition_mix ):
92
85
threading .Thread .__init__ (self )
93
86
self .eval_queue = eval_queue
87
+ self .result_queue = result_queue
94
88
self .player = player
95
89
self .condition_search = condition_search
96
90
self .condition_eval = condition_eval
91
+ self .condition_mix = condition_mix
97
92
98
93
def run (self ):
99
94
total_sim = MCTS_SIM
100
95
while total_sim > 0 :
101
- self .condition_search .acquire ()
102
- while (len (self .eval_queue .values ()) != BATCH_SIZE_EVAL or \
103
- (len (self .eval_queue .values ()) < BATCH_SIZE_EVAL and \
104
- len (self .eval_queue .values ()) != total_sim )) or \
105
- (len (self .eval_queue .values ()) == BATCH_SIZE_EVAL and \
106
- not all (isinstance (state , np .ndarray ) for state in self .eval_queue .values ())):
107
- self .condition_search .wait ()
108
-
109
- self .condition_search .release ()
110
96
self .condition_eval .acquire ()
97
+ while (len (self .eval_queue .values ()) < BATCH_SIZE_EVAL or \
98
+ len (self .result_queue .values ()) > 0 ):
99
+ print ("notifying in evaluator, current len: %d" % len (self .eval_queue .values ()))
100
+ self .condition_eval .wait ()
101
+ self .condition_eval .release ()
102
+
111
103
states = []
112
104
for idx , state in self .eval_queue .items ():
113
105
states .append (sample_rotation (state , num = 1 ))
106
+ print ('states len: %d' % len (states ))
114
107
states = _prepare_state (states )
115
108
feature_maps = self .player .extractor (states [0 ])
116
109
117
110
## Policy and value prediction
118
111
probas = self .player .policy_net (feature_maps )
119
112
v = self .player .value_net (feature_maps )
120
- for idx in range (BATCH_SIZE_EVAL ):
121
- self .eval_queue [idx ] = (probas [idx ].cpu ().data .numpy (), v [idx ])
122
- self .condition_eval .notifyAll ()
123
- self .condition_eval .release ()
113
+ keys = list (self .eval_queue .keys ())
114
+ idx = 0
115
+ for key in keys :
116
+ self .result_queue [key ] = (probas [idx ].cpu ().data .numpy (), float (v [idx ]))
117
+ del self .eval_queue [key ]
118
+ idx += 1
119
+ del probas , v , feature_maps
120
+ self .condition_mix .acquire ()
121
+ self .condition_mix .notifyAll ()
122
+ self .condition_mix .release ()
124
123
total_sim -= BATCH_SIZE_EVAL
125
124
126
125
127
126
128
127
class SearchThread (threading .Thread ):
129
- def __init__ (self , mcts , game , eval_queue , thread_id , lock , condition_search , condition_eval ):
128
+ def __init__ (self , mcts , game , eval_queue , result_queue , thread_id , \
129
+ lock , condition_search , condition_eval , condition_mix ):
130
130
threading .Thread .__init__ (self )
131
+ self .result_queue = result_queue
131
132
self .eval_queue = eval_queue
132
133
self .mcts = mcts
133
134
self .game = game
134
135
self .lock = lock
135
136
self .thread_id = thread_id
136
137
self .condition_eval = condition_eval
137
138
self .condition_search = condition_search
139
+ self .condition_mix = condition_mix
138
140
139
141
140
142
def run (self ):
@@ -153,18 +155,25 @@ def run(self):
153
155
## Predict the probas
154
156
if not done :
155
157
self .condition_search .acquire ()
156
- self .eval_queue [self .thread_id ] = state
157
- self .condition_search .notify ()
158
-
158
+ while len (self .eval_queue .values ()) < BATCH_SIZE_EVAL and \
159
+ len (self .result_queue .values ()) == 0 :
160
+ print ("trying to release in thread id %d" % self .thread_id )
161
+ self .condition_search .wait ()
162
+ print ("added move in thread %d" % self .thread_id )
159
163
self .condition_search .release ()
164
+
165
+ self .eval_queue [self .thread_id ] = state
160
166
self .condition_eval .acquire ()
161
- self .condition_eval .wait ()
162
-
163
- res = self .eval_queue [self .thread_id ]
164
- probas = np .array (res [0 ])
165
- v = float (res [1 ])
167
+ self .condition_eval .notify ()
166
168
self .condition_eval .release ()
167
169
170
+ self .condition_mix .acquire ()
171
+ self .condition_mix .wait ()
172
+ self .condition_mix .release ()
173
+
174
+ probas = np .array (self .result_queue [self .thread_id ][0 ], copy = True )
175
+ v = float (self .result_queue [self .thread_id ][1 ])
176
+ del self .result_queue [self .thread_id ]
168
177
169
178
## Add noise in the root node
170
179
if not current_node .parent :
@@ -186,6 +195,7 @@ def run(self):
186
195
current_node .update (v )
187
196
current_node = current_node .parent
188
197
self .lock .release ()
198
+ print ("done in thread id %d" % self .thread_id )
189
199
190
200
191
201
class MCTS :
@@ -228,18 +238,22 @@ def search(self, current_game, player, competitive=False):
228
238
threads = []
229
239
condition_eval = threading .Condition ()
230
240
condition_search = threading .Condition ()
241
+ condition_mix = threading .Condition ()
231
242
lock = threading .Lock ()
232
243
eval_queue = {}
233
- evaluator = EvaluatorThread (player , eval_queue , condition_search , condition_eval )
244
+ result_queue = {}
245
+ evaluator = EvaluatorThread (player , eval_queue , result_queue , condition_search , \
246
+ condition_eval , condition_mix )
234
247
evaluator .start ()
235
248
for sim in range (MCTS_SIM // MCTS_PARALLEL ):
236
249
eval_queue .clear ()
237
250
for idx in range (MCTS_PARALLEL ):
238
- threads .append (SearchThread (self , current_game , eval_queue , idx ,
239
- lock , condition_search , condition_eval ))
251
+ threads .append (SearchThread (self , current_game , eval_queue , result_queue , idx ,
252
+ lock , condition_search , condition_eval , condition_mix ))
240
253
threads [- 1 ].start ()
241
254
for thread in threads :
242
255
thread .join ()
256
+ evaluator .join ()
243
257
244
258
action_scores = np .zeros ((current_game .board_size ** 2 + 1 ,))
245
259
for node in self .root .childrens :
@@ -251,6 +265,7 @@ def search(self, current_game, player, competitive=False):
251
265
break
252
266
253
267
self .root = self .root .childrens [idx ]
268
+ print (final_probas , final_move )
254
269
return final_probas , final_move
255
270
256
271
0 commit comments