Skip to content

Commit f882e0a

Browse files
committed
[ALL] Cleaning code + changed multithreading to accept non multiples between BATCH_SIZE_EVAL and MCTS_PARALLEL
1 parent b051110 commit f882e0a

File tree

5 files changed

+30
-27
lines changed

5 files changed

+30
-27
lines changed

const.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
## Number of evaluation parallel games
1313
PARALLEL_EVAL = 3
1414
## MCTS parallel
15-
MCTS_PARALLEL = 12
15+
MCTS_PARALLEL = 4
1616

1717

1818
##### GLOBAL
@@ -28,7 +28,7 @@
2828
## Learning rate
2929
LR = 0.01
3030
## Number of MCTS simulation
31-
MCTS_SIM = 128
31+
MCTS_SIM = 64
3232
## Exploration constant
3333
C_PUCT = 0.2
3434
## L2 Regularization

lib/game.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def __call__(self):
106106
moves = 0
107107
comp = False
108108

109-
# if self.id % 10 == 0:
110-
print("Starting game number %d" % self.id)
109+
if self.id % 10 == 0:
110+
print("Starting game number %d" % self.id)
111111

112112
while not done:
113113

lib/go.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,11 @@ def _act(self, action, history):
110110

111111

112112
def test_move(self, action):
113-
""" Test if a specific valid action should be played,
114-
depending on the current score """
113+
"""
114+
Test if a specific valid action should be played,
115+
depending on the current score. This is used to stop
116+
the agent from passing if it makes him loose
117+
"""
115118

116119
board_clone = self.board.clone()
117120
current_score = board_clone.fast_score + self.komi
@@ -167,7 +170,6 @@ def step(self, action):
167170

168171
# Reward: if nonterminal, then the reward is -1
169172
if not self.board.is_terminal:
170-
self.done = False
171173
return _format_state(self.history, self.player_color, self.board_size), \
172174
-1, False
173175

models/mcts.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212

1313
@jit
14-
def _opt_select(nodes, c_puct):
15-
""" Optimized version of the selection """
14+
def _opt_select(nodes, c_puct=C_PUCT):
15+
""" Optimized version of the selection based of the PUCT formula """
1616

1717
total_count = 0
1818
for i in range(nodes.shape[0]):
@@ -38,16 +38,6 @@ def dirichlet_noise(probas):
3838
return new_probas
3939

4040

41-
def _select(nodes, c_puct=C_PUCT):
42-
"""
43-
Select the move that maximises the mean value of the next state +
44-
the result of the PUCT function
45-
"""
46-
47-
return nodes[_opt_select(np.array([[node.q, node.n, node.p] \
48-
for node in nodes]), c_puct)]
49-
50-
5141
class Node:
5242

5343
def __init__(self, parent=None, proba=None, move=None):
@@ -105,19 +95,23 @@ def run(self):
10595

10696
## Wait for the eval_queue to be filled by new positions to evaluate
10797
self.condition_search.acquire()
108-
while len(self.eval_queue) < BATCH_SIZE_EVAL:
98+
while len(self.eval_queue) < BATCH_SIZE_EVAL and \
99+
len(self.eval_queue) != MCTS_PARALLEL - BATCH_SIZE_EVAL or \
100+
len(self.eval_queue) == 0:
109101
self.condition_search.wait()
110102
self.condition_search.release()
111103

112104
self.condition_eval.acquire()
105+
keys = list(self.eval_queue.keys())
106+
113107
## Predict the feature_maps, policy and value
114-
states = torch.tensor(np.array(list(self.eval_queue.values()))[0:BATCH_SIZE_EVAL],
115-
dtype=torch.float, device=DEVICE)
108+
states = torch.tensor(np.array(list(self.eval_queue.values()))[0:len(keys)],
109+
dtype=torch.float, device=DEVICE)
116110
v, probas = self.player.predict(states)
117111

118112
## Replace the state with the result in the eval_queue
119113
## and notify all the threads that the result are available
120-
for idx, i in zip(list(self.eval_queue.keys()), range(BATCH_SIZE_EVAL)):
114+
for idx, i in zip(keys, range(len(keys))):
121115
del self.eval_queue[idx]
122116
self.result_queue[idx] = (probas[i].cpu().data.numpy(), v[i])
123117

@@ -150,7 +144,10 @@ def run(self):
150144

151145
## Traverse the tree until leaf
152146
while not current_node.is_leaf() and not done:
153-
current_node = _select(current_node.childrens)
147+
## Select the action that maximizes the PUCT algorithm
148+
current_node = current_node.childrens[_opt_select( \
149+
np.array([[node.q, node.n, node.p] \
150+
for node in current_node.childrens]))]
154151

155152
## Virtual loss since multithreading
156153
self.lock.acquire()
@@ -161,7 +158,8 @@ def run(self):
161158

162159
if not done:
163160

164-
## Add current leaf state to the evaluation queue
161+
## Add current leaf state with random dihedral transformation
162+
## to the evaluation queue
165163
self.condition_search.acquire()
166164
self.eval_queue[self.thread_id] = sample_rotation(state, num=1)
167165
self.condition_search.notify()
@@ -243,7 +241,6 @@ def search(self, current_game, player, competitive=False):
243241
Search the best moves through the game tree with
244242
the policy and value network to update node statistics
245243
"""
246-
threads = []
247244

248245
## Locking for thread synchronization
249246
condition_eval = threading.Condition()
@@ -256,6 +253,7 @@ def search(self, current_game, player, competitive=False):
256253
evaluator = EvaluatorThread(player, eval_queue, result_queue, condition_search, condition_eval)
257254
evaluator.start()
258255

256+
threads = []
259257
## Do exactly the required number of simulation per thread
260258
for sim in range(MCTS_SIM // MCTS_PARALLEL):
261259
for idx in range(MCTS_PARALLEL):

stats.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,20 @@ def stats_report():
6767
old_values = do_sims(player, old_values, mcts_parallel=6, mcts_sim=64, batch_size_eval=2)
6868
old_values = do_sims(player, old_values, mcts_parallel=8, mcts_sim=64, batch_size_eval=2)
6969
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=64, batch_size_eval=2)
70+
old_values = do_sims(player, old_values, mcts_parallel=4, mcts_sim=64, batch_size_eval=4)
71+
old_values = do_sims(player, old_values, mcts_parallel=6, mcts_sim=64, batch_size_eval=2)
7072
old_values = do_sims(player, old_values, mcts_parallel=8, mcts_sim=64, batch_size_eval=4)
7173
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=64, batch_size_eval=4)
7274
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=64, batch_size_eval=6)
7375

7476

7577
## 128 simulations
76-
old_values = do_sims(player, old_values, mcts_parallel=4, mcts_sim=128, batch_size_eval=2)
78+
old_values = do_sims(player, old_values, mcts_parallel=2, mcts_sim=128, batch_size_eval=2)
7779
old_values = do_sims(player, old_values, mcts_parallel=4, mcts_sim=128, batch_size_eval=2)
7880
old_values = do_sims(player, old_values, mcts_parallel=6, mcts_sim=128, batch_size_eval=2)
7981
old_values = do_sims(player, old_values, mcts_parallel=8, mcts_sim=128, batch_size_eval=4)
8082
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=128, batch_size_eval=2)
83+
old_values = do_sims(player, old_values, mcts_parallel=4, mcts_sim=128, batch_size_eval=4)
8184
old_values = do_sims(player, old_values, mcts_parallel=8, mcts_sim=128, batch_size_eval=4)
8285
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=128, batch_size_eval=4)
8386
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=128, batch_size_eval=6)

0 commit comments

Comments
 (0)