
Commit e335eb2

[MCTS] Really really bad threading, working on it
1 parent ba430fd commit e335eb2

File tree

7 files changed: +83 additions, -69 deletions

README.md

Lines changed: 2 additions & 1 deletion
@@ -70,7 +70,8 @@ Ongoing project.

* 0.1947162s / move - 0.003894324s / simulation with 2 threads and 2 batch_size_eval with 50 simulations
* 0.1360865s / move - 0.00272173s / simulation 4 threads 4 batch_size_eval 50 simulations
- * 0.1222489s / move - 0.002444978s / simulation 8 threads 8 batch_size eval 50 simulations
+ * 0.1222489s / move - 0.002444978s / simulation 8 threads 8 batch_size_eval 50 simulations
+ * 0.1372498 / move - 0.00214452812s / simulations 16 threads 16 batch_size_eval 64 simulations

### 19x19 board

const.py

Lines changed: 4 additions & 4 deletions
@@ -13,21 +13,21 @@
## Number of evaluation parallel games
PARALLEL_EVAL = 2
## MCTS parallel
- MCTS_PARALLEL = 2
+ MCTS_PARALLEL = 16


##### GLOBAL

## Size of the Go board
GOBAN_SIZE = 9
## Number of move to end a game
- MOVE_LIMIT = GOBAN_SIZE ** 2
+ MOVE_LIMIT = GOBAN_SIZE ** 2 * 2.5
## Number of last states to keep
HISTORY = 7
## Learning rate
LR = 0.01
## Number of MCTS simulation
- MCTS_SIM = 5
+ MCTS_SIM = 64
## Exploration constant
C_PUCT = 0.2
## L2 Regularization
@@ -41,7 +41,7 @@
## Alpha for Dirichlet noise
ALPHA = 0.03
## Batch size for evaluation during MCTS
- BATCH_SIZE_EVAL = 2
+ BATCH_SIZE_EVAL = 4
## Number of self-play before training
SELF_PLAY_MATCH = 2 * PARALLEL_SELF_PLAY
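
Note: the three changed knobs interact in the search loop shown in the models/mcts.py diff below. A minimal sketch of the arithmetic, using the values from this commit (constant names follow const.py, the loop structure follows MCTS.search; this is an illustration, not repo code):

MCTS_SIM = 64        ## simulations per move
MCTS_PARALLEL = 16   ## search threads launched per round
BATCH_SIZE_EVAL = 4  ## states the evaluator batches per forward pass

rounds = MCTS_SIM // MCTS_PARALLEL     ## 4 rounds of 16 parallel searches per move
batches = MCTS_SIM // BATCH_SIZE_EVAL  ## 16 batched network evaluations per move
print(rounds, batches)                 ## -> 4 16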

lib/play.py

Lines changed: 16 additions & 13 deletions
@@ -86,18 +86,20 @@ def self_play(current_time, loaded_version):
    queue, results = create_matches(player , \
                        cores=PARALLEL_SELF_PLAY, match_number=SELF_PLAY_MATCH)
    print("[PLAY] Starting to fetch fresh games")
-    queue.join()
-    for _ in range(SELF_PLAY_MATCH):
-        result = results.get()
-        if result:
-            collection.insert({
-                "game": result,
-                "id": game_id
-            })
-            game_id += 1
-    print("[PLAY] Done fetching")
-    queue.close()
-    results.close()
+    try:
+        queue.join()
+        for _ in range(SELF_PLAY_MATCH):
+            result = results.get()
+            if result:
+                collection.insert({
+                    "game": result,
+                    "id": game_id
+                })
+                game_id += 1
+        print("[PLAY] Done fetching")
+    finally:
+        queue.close()
+        results.close()


def play(player, opponent):
@@ -257,9 +259,9 @@ def __call__(self):
        while not done:
            ## Prevent cycling in 2 atari situations
            if moves > MOVE_LIMIT:
-                print("cc")
                return pickle.dumps((dataset, self.board.get_winner()))

+            ## Magic ratio for adaptative temperature
            if moves > MOVE_LIMIT / 24:
                comp = True

@@ -287,6 +289,7 @@ def __call__(self):
            print("[EVALUATION] Match %d done in eval" % self.id)
            self.opponent.passed = False
            return pickle.dumps([reward])
+
        self.player.passed = False
        return pickle.dumps((dataset, reward))
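
Note: the self_play change above is a plain try/finally so the result queues are always closed, even if results.get() or the database insert raises. A stripped-down sketch of the same pattern (create_matches and the Mongo collection are replaced by placeholders; only the queue handling mirrors the diff):

def drain(queue, results, expected):
    ## queue: a multiprocessing.JoinableQueue of matches, results: a multiprocessing.Queue
    ## Mirror of the self_play cleanup: fetch everything, always close the queues
    games = []
    try:
        queue.join()
        for _ in range(expected):
            result = results.get()
            if result:
                games.append(result)
    finally:
        queue.close()
        results.close()
    return games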

lib/train.py

Lines changed: 3 additions & 5 deletions
@@ -171,6 +171,7 @@ def new_agent(result):
            print("[EVALUATION] New best player saved !")
        else:
            nonlocal last_id
+            ## Force a new fetch in case the player didnt improve
            last_id = fetch_new_games(collection, dataset, last_id)

    ## Wait before the circular before is full
@@ -186,12 +187,10 @@ def new_agent(result):
    while True:
        batch_loss = []
        for batch_idx, (state, move, winner) in enumerate(dataloader):
-
            running_loss = []
-            ## Force the network to stop training the current network
-            ## since the new one is better (from the callback)
-
            lr, optimizer = update_lr(lr, optimizer, total_ite)
+
+            ## Evaluate a copy of the current network asynchronously
            if total_ite % TRAIN_STEPS == 0:
                pending_player = deepcopy(player)
                last_id = fetch_new_games(collection, dataset, last_id)
@@ -206,7 +205,6 @@ def new_agent(result):
                pool.apply_async(evaluate, args=(pending_player, best_player), \
                                 callback=new_agent)
    except Exception as e:
-        print(e)
        client.close()
        pool.terminate()
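
Note: the surrounding training loop hands evaluation to a worker pool with a callback, which is how new_agent can promote a better network without blocking training. A minimal sketch of that apply_async + callback pattern (function bodies are placeholders, only the control flow follows the diff):

from multiprocessing import Pool

def evaluate(pending_player, best_player):
    ## Placeholder: would play matches and return True if pending_player wins enough
    return True

def new_agent(result):
    ## Callback runs in the parent process once evaluate() returns
    if result:
        print("[EVALUATION] New best player saved !")

if __name__ == "__main__":
    pool = Pool(processes=1)
    pool.apply_async(evaluate, args=("pending", "best"), callback=new_agent)
    pool.close()
    pool.join()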

lib/utils.py

Lines changed: 5 additions & 9 deletions
@@ -81,8 +81,9 @@ def get_player(current_time, version):
    return player, checkpoint


- # @profile
def sample_rotation(state, num=8):
+    """ Apply a certain number of random transformation to the input state """
+
    dh_group = [(None, None), ((np.rot90, 1), None), ((np.rot90, 2), None),
                ((np.rot90, 3), None), (np.fliplr, None), (np.flipud, None),
                (np.flipud, (np.rot90, 1)), (np.fliplr, (np.rot90, 1))]
@@ -109,15 +110,10 @@ def sample_rotation(state, num=8):


def formate_state(state, probas, winner):
+    """ Repeat the probas and the winner to make every example identical after
+    the dihedral rotation have been applied """
+
    probas = np.reshape(probas, (1, probas.shape[0]))
    probas = np.repeat(probas, 8, axis=0)
    winner = np.full((8, 1), winner)
    return state, probas, winner
-
-
- if __name__ == "__main__":
-     pass
-
-
-
-
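
Note: the new sample_rotation docstring refers to the dihedral group listed in the context above, where each entry is up to two numpy transforms applied in sequence. A small illustration of applying one entry to a board plane (dh_group is copied from the diff; the apply_transform helper is illustrative, not part of the repo):

import numpy as np

dh_group = [(None, None), ((np.rot90, 1), None), ((np.rot90, 2), None),
            ((np.rot90, 3), None), (np.fliplr, None), (np.flipud, None),
            (np.flipud, (np.rot90, 1)), (np.fliplr, (np.rot90, 1))]

def apply_transform(plane, transform):
    ## Apply one (possibly composed) dihedral transformation to a 2D plane
    for op in transform:
        if op is None:
            continue
        if isinstance(op, tuple):   ## (np.rot90, k): rotate by k quarter turns
            plane = op[0](plane, op[1])
        else:                       ## plain flip
            plane = op(plane)
    return plane

board = np.arange(9).reshape(3, 3)
print(apply_transform(board, dh_group[7]))  ## fliplr, then one 90 degree rotation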

main.py

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ def main(folder, version):
    train_proc = pool.apply_async(train, args=(current_time, version,))

    ## Comment one line or the other to get the stack trace
+    ## Must add a loooooong timer otherwise signals are not caught
    self_play_proc.get(60000000)
    train_proc.get(60000000)
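
Note: the added comment echoes a common multiprocessing workaround: calling AsyncResult.get() with a very large timeout instead of no timeout so that Ctrl-C can still interrupt the main process while it waits on the two long-running workers. A minimal sketch of that pattern (the worker is a placeholder, the timeout value mirrors the diff):

from multiprocessing import Pool
import time

def worker():
    ## Stand-in for the self_play / train processes
    time.sleep(2)
    return "done"

if __name__ == "__main__":
    pool = Pool(processes=1)
    proc = pool.apply_async(worker)
    ## A huge timeout instead of a bare get() keeps KeyboardInterrupt deliverable
    print(proc.get(60000000))
    pool.close()
    pool.join()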

models/mcts.py

Lines changed: 52 additions & 37 deletions
@@ -8,8 +8,7 @@
from lib.utils import _prepare_state, sample_rotation


- class Node():
-
+ class Node:
    def __init__(self, parent=None, proba=None, move=None):
        """
        p : probability of reaching that node, given by the policy net
@@ -29,19 +28,14 @@ def __init__(self, parent=None, proba=None, move=None):
    def update(self, v):
        """ Update the node statistics after a playout """

-        self.w = self.w + float(v)
-        if self.n > 0:
-            self.q = self.w / self.n
-        else:
-            self.q = 0
+        self.w = self.w + v
+        self.q = self.w / self.n if self.n > 0 else 0


    def is_leaf(self):
        """ Check whether node is a leaf or not """

-        if self.childrens and len(self.childrens) > 0:
-            return False
-        return True
+        return len(self.childrens) == 0


    def expand(self, probas):
@@ -86,55 +80,63 @@ def _opt_select(nodes, c_puct):
    return equals[0]


-
class EvaluatorThread(threading.Thread):
-    def __init__(self, player, eval_queue, condition_search, condition_eval):
+    def __init__(self, player, eval_queue, result_queue, condition_search, condition_eval, condition_mix):
        threading.Thread.__init__(self)
        self.eval_queue = eval_queue
+        self.result_queue = result_queue
        self.player = player
        self.condition_search = condition_search
        self.condition_eval = condition_eval
+        self.condition_mix = condition_mix

    def run(self):
        total_sim = MCTS_SIM
        while total_sim > 0:
-            self.condition_search.acquire()
-            while (len(self.eval_queue.values()) != BATCH_SIZE_EVAL or \
-                    (len(self.eval_queue.values()) < BATCH_SIZE_EVAL and \
-                    len(self.eval_queue.values()) != total_sim)) or \
-                    (len(self.eval_queue.values()) == BATCH_SIZE_EVAL and \
-                    not all(isinstance(state, np.ndarray) for state in self.eval_queue.values())):
-                self.condition_search.wait()
-
-            self.condition_search.release()
            self.condition_eval.acquire()
+            while (len(self.eval_queue.values()) < BATCH_SIZE_EVAL or \
+                    len(self.result_queue.values()) > 0):
+                print("notifying in evaluator, current len: %d" % len(self.eval_queue.values()))
+                self.condition_eval.wait()
+            self.condition_eval.release()
+
            states = []
            for idx, state in self.eval_queue.items():
                states.append(sample_rotation(state, num=1))
+            print('states len: %d' % len(states))
            states = _prepare_state(states)
            feature_maps = self.player.extractor(states[0])

            ## Policy and value prediction
            probas = self.player.policy_net(feature_maps)
            v = self.player.value_net(feature_maps)
-            for idx in range(BATCH_SIZE_EVAL):
-                self.eval_queue[idx] = (probas[idx].cpu().data.numpy(), v[idx])
-            self.condition_eval.notifyAll()
-            self.condition_eval.release()
+            keys = list(self.eval_queue.keys())
+            idx = 0
+            for key in keys:
+                self.result_queue[key] = (probas[idx].cpu().data.numpy(), float(v[idx]))
+                del self.eval_queue[key]
+                idx += 1
+            del probas, v, feature_maps
+            self.condition_mix.acquire()
+            self.condition_mix.notifyAll()
+            self.condition_mix.release()
            total_sim -= BATCH_SIZE_EVAL



class SearchThread(threading.Thread):
-    def __init__(self, mcts, game, eval_queue, thread_id, lock, condition_search, condition_eval):
+    def __init__(self, mcts, game, eval_queue, result_queue, thread_id, \
+                    lock, condition_search, condition_eval, condition_mix):
        threading.Thread.__init__(self)
+        self.result_queue = result_queue
        self.eval_queue = eval_queue
        self.mcts = mcts
        self.game = game
        self.lock = lock
        self.thread_id = thread_id
        self.condition_eval = condition_eval
        self.condition_search = condition_search
+        self.condition_mix = condition_mix


    def run(self):
@@ -153,18 +155,25 @@ def run(self):
            ## Predict the probas
            if not done:
                self.condition_search.acquire()
-                self.eval_queue[self.thread_id] = state
-                self.condition_search.notify()
-
+                while len(self.eval_queue.values()) < BATCH_SIZE_EVAL and \
+                        len(self.result_queue.values()) == 0:
+                    print("trying to release in thread id %d" % self.thread_id)
+                    self.condition_search.wait()
+                print("added move in thread %d" % self.thread_id)
                self.condition_search.release()
+
+                self.eval_queue[self.thread_id] = state
                self.condition_eval.acquire()
-                self.condition_eval.wait()
-
-                res = self.eval_queue[self.thread_id]
-                probas = np.array(res[0])
-                v = float(res[1])
+                self.condition_eval.notify()
                self.condition_eval.release()

+                self.condition_mix.acquire()
+                self.condition_mix.wait()
+                self.condition_mix.release()
+
+                probas = np.array(self.result_queue[self.thread_id][0], copy=True)
+                v = float(self.result_queue[self.thread_id][1])
+                del self.result_queue[self.thread_id]

                ## Add noise in the root node
                if not current_node.parent:
@@ -186,6 +195,7 @@ def run(self):
                current_node.update(v)
                current_node = current_node.parent
            self.lock.release()
+            print("done in thread id %d" % self.thread_id)


class MCTS:
@@ -228,18 +238,22 @@ def search(self, current_game, player, competitive=False):
        threads = []
        condition_eval = threading.Condition()
        condition_search = threading.Condition()
+        condition_mix = threading.Condition()
        lock = threading.Lock()
        eval_queue = {}
-        evaluator = EvaluatorThread(player, eval_queue, condition_search, condition_eval)
+        result_queue = {}
+        evaluator = EvaluatorThread(player, eval_queue, result_queue, condition_search, \
+                                    condition_eval, condition_mix)
        evaluator.start()
        for sim in range(MCTS_SIM // MCTS_PARALLEL):
            eval_queue.clear()
            for idx in range(MCTS_PARALLEL):
-                threads.append(SearchThread(self, current_game, eval_queue, idx,
-                               lock, condition_search, condition_eval))
+                threads.append(SearchThread(self, current_game, eval_queue, result_queue, idx,
+                               lock, condition_search, condition_eval, condition_mix))
                threads[-1].start()
            for thread in threads:
                thread.join()
+        evaluator.join()

        action_scores = np.zeros((current_game.board_size ** 2 + 1,))
        for node in self.root.childrens:
@@ -251,6 +265,7 @@
                break

        self.root = self.root.childrens[idx]
+        print(final_probas, final_move)
        return final_probas, final_move
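
Note: the new scheme is a producer/consumer rendezvous between the search threads and a single evaluator. Search threads publish leaf states into eval_queue keyed by thread id, the evaluator waits until BATCH_SIZE_EVAL states are queued, runs one batched forward pass, writes (probas, value) pairs into result_queue, and wakes the searchers. A stripped-down, deadlock-free sketch of that hand-off with a dummy network (queue and condition names follow the diff, but this collapses condition_search and condition_eval into one condition and replaces the three networks with a placeholder function):

import threading

BATCH_SIZE_EVAL = 4

eval_queue = {}      ## thread_id -> state waiting for evaluation
result_queue = {}    ## thread_id -> (probas, value) once evaluated
condition_eval = threading.Condition()
condition_mix = threading.Condition()

def fake_network(states):
    ## Placeholder for the batched extractor / policy_net / value_net pass
    return {tid: ([0.5], 0.0) for tid in states}

def evaluator(total_sim):
    done = 0
    while done < total_sim:
        with condition_eval:
            ## Wait until a full batch of states has been queued by the searchers
            while len(eval_queue) < BATCH_SIZE_EVAL:
                condition_eval.wait()
            batch = dict(eval_queue)
            eval_queue.clear()
        results = fake_network(batch)
        with condition_mix:
            result_queue.update(results)
            condition_mix.notify_all()   ## wake every search thread waiting on a result
        done += len(batch)

def searcher(thread_id):
    ## Publish one leaf state, then block until its evaluation comes back
    with condition_eval:
        eval_queue[thread_id] = "state-%d" % thread_id
        condition_eval.notify()
    with condition_mix:
        while thread_id not in result_queue:
            condition_mix.wait()
        probas, v = result_queue.pop(thread_id)
    print("thread %d got" % thread_id, probas, v)

searchers = [threading.Thread(target=searcher, args=(i,)) for i in range(BATCH_SIZE_EVAL)]
ev = threading.Thread(target=evaluator, args=(BATCH_SIZE_EVAL,))
ev.start()
for t in searchers:
    t.start()
for t in searchers:
    t.join()
ev.join()

Compared to the diff, this sketch guards each dictionary with the condition that announces changes to it and always rechecks the predicate before waiting, which is the usual way to avoid the lost wake-ups the commit message is warning about.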
