
Commit 3860bbe

[ALL] Migrated to pytorch 0.5, minor fixes and reverted to old threading on MCTS

1 parent: e335eb2

10 files changed: +60 -83 lines

README.md
Lines changed: 5 additions & 5 deletions

@@ -64,14 +64,14 @@ Ongoing project.
 
 # Statistics
 
-## For a 10 layers deep Resnet evaluated on 5 games
+## For a 10 layers deep Resnet evaluated on 50 games 64 simulations
 
 ### 9x9 board
 
-* 0.1947162s / move - 0.003894324s / simulation with 2 threads and 2 batch_size_eval with 50 simulations
-* 0.1360865s / move - 0.00272173s / simulation 4 threads 4 batch_size_eval 50 simulations
-* 0.1222489s / move - 0.002444978s / simulation 8 threads 8 batch_size_eval 50 simulations
-* 0.1372498 / move - 0.00214452812s / simulations 16 threads 16 batch_size_eval 64 simulations
+* 0.2377991s / move - 0.00371561093s / simulation 2 threads 2 batch_size_eval
+* 0.1624937s / move - 0.00253896406s / simulation 4 threads 4 batch_size_eval
+* 0.1465123s / move - 0.00228925468s / simulation 8 threads 8 batch_size_eval
+* 0.1401098s / move - 0.00218921563s / simulation 16 threads 16 batch_size_eval
 
 ### 19x19 board
 
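(For reference, the per-simulation figures are simply the per-move times divided by the 64 simulations, e.g. 0.2377991 s / 64 ≈ 0.00372 s.)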

const.py
Lines changed: 6 additions & 7 deletions

@@ -6,14 +6,13 @@
 ## CUDA variable from Torch
 CUDA = torch.cuda.is_available()
 ## Dtype of the tensors depending on CUDA
-DTYPE_FLOAT = torch.cuda.FloatTensor if CUDA else torch.FloatTensor
-DTYPE_LONG = torch.cuda.LongTensor if CUDA else torch.LongTensor
+DEVICE = torch.device("cuda") if CUDA else torch.device("cpu")
 ## Number of self-play parallel games
-PARALLEL_SELF_PLAY = 2
+PARALLEL_SELF_PLAY = 3
 ## Number of evaluation parallel games
 PARALLEL_EVAL = 2
 ## MCTS parallel
-MCTS_PARALLEL = 16
+MCTS_PARALLEL = 4
 
 
 ##### GLOBAL
@@ -61,13 +60,13 @@
 ## Number of residual blocks
 BLOCKS = 10
 ## Number of training step before evaluating
-TRAIN_STEPS = 7 * BATCH_SIZE
+TRAIN_STEPS = 6 * BATCH_SIZE
 ## Optimizer
 ADAM = False
 ## Learning rate annealing factor
 LR_DECAY = 0.1
 ## Learning rate annealing interval
-LR_DECAY_ITE = 50 * TRAIN_STEPS
+LR_DECAY_ITE = 100 * TRAIN_STEPS
 ## Print the loss
 LOSS_TICK = BATCH_SIZE // 4
 ## Refresh the dataset
@@ -78,6 +77,6 @@
 
 ## Number of matches against its old version to evaluate
 ## the newly trained network
-EVAL_MATCHS = 21
+EVAL_MATCHS = 20
 ## Threshold to keep the new neural net
 EVAL_THRESH = 0.55
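For context on the DTYPE_* → DEVICE switch: since PyTorch 0.4 the idiomatic pattern is to create tensors directly on a torch.device and move modules with .to(). A minimal sketch of that pattern (illustrative only, not code from this repo):

    import torch

    CUDA = torch.cuda.is_available()
    DEVICE = torch.device("cuda") if CUDA else torch.device("cpu")

    # Old style: Variable(torch.zeros(2, 3)).type(torch.cuda.FloatTensor if CUDA else torch.FloatTensor)
    # New style: build the tensor on the right device from the start
    x = torch.zeros(2, 3, dtype=torch.float, device=DEVICE)
    # net = SomeNetwork().to(DEVICE)   # hypothetical module, moved the same way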

lib/evaluate.py
Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@ def evaluate(player, new_player):
     black_wins = 0
     white_wins = 0
     for result in results:
-        if result[0] == 0:
+        if result[0] == 1:
             white_wins += 1
         else:
             black_wins += 1
@@ -20,4 +20,4 @@ def evaluate(player, new_player):
             % (black_wins, white_wins))
     if black_wins >= EVAL_THRESH * len(results):
         return True
-    return False
+    return False
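Read together with the updated print in lib/play.py below ("black" if reward == 0 else "white"), the fixed comparison means a reward of 1 counts as a white win and anything else as a black win. Roughly (a sketch, not the repo's code):

    EVAL_THRESH = 0.55

    def keep_new_network(results):
        # results holds one unpickled [reward] per evaluation match
        white_wins = sum(1 for r in results if r[0] == 1)
        black_wins = len(results) - white_wins
        # The new network is kept only if black (apparently the candidate side)
        # wins at least EVAL_THRESH of the matches.
        return black_wins >= EVAL_THRESH * len(results)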

lib/gtp.py
Lines changed: 0 additions & 3 deletions

@@ -129,11 +129,8 @@ def __init__(self, game, komi=7.5, board_size=19, version="0.2", name="AlphaGo")
     def send(self, message):
         message_id, command, arguments = parse_message(message)
         if command in self.known_commands:
-            # try:
             return format_success(
                 message_id, getattr(self, "cmd_" + command)(arguments))
-            # except ValueError as exception:
-            #     return format_error(message_id, exception.args[0])
         else:
             return format_error(message_id, "unknown command")
 

lib/play.py
Lines changed: 2 additions & 5 deletions

@@ -153,7 +153,7 @@ def run(self):
                 self.game_queue.task_done()
                 self.result_queue.put(answer)
             except Exception as e:
-                print("xd")
+                print("Game has thrown an error")
 
 
 
@@ -195,8 +195,6 @@ def _get_move(self, board, probas):
         """ Select a move without MCTS """
 
        player_move = None
-        valid_move = False
-        can_pass = False
        legal_moves = board.get_legal_moves()
 
        while player_move not in legal_moves and len(legal_moves) > 0:
@@ -208,7 +206,6 @@ def _get_move(self, board, probas):
 
         return player_move
 
-    # @profile
     def _play(self, state, player, other_pass, competitive=False):
         """ Choose a move depending on MCTS or not """
 
@@ -286,7 +283,7 @@ def __call__(self):
 
         ## Pickle the result because multiprocessing
         if self.opponent:
-            print("[EVALUATION] Match %d done in eval" % self.id)
+            print("[EVALUATION] Match %d done in eval, winner %s" % (self.id, "black" if reward == 0 else "white"))
             self.opponent.passed = False
         return pickle.dumps([reward])
 

lib/train.py
Lines changed: 8 additions & 7 deletions

@@ -29,6 +29,7 @@ def __init__(self):
     def forward(self, winner, self_play_winner, probas, self_play_probas):
         value_error = F.mse_loss(winner, self_play_winner)
         policy_error = F.kl_div(probas, self_play_probas)
+        # policy_error = torch.sum(-probas * torch.log(self_play_probas))
         return value_error + policy_error
 
 
@@ -104,9 +105,9 @@ def collate_fn(example):
         state.extend(ex[0])
         probas.extend(ex[1])
         winner.extend(ex[2])
-    state = torch.tensor(state).type(DTYPE_FLOAT)
-    probas = torch.tensor(probas).type(DTYPE_FLOAT)
-    winner = torch.tensor(winner).type(DTYPE_FLOAT)
+    state = torch.tensor(state, dtype=torch.float, device=DEVICE)
+    probas = torch.tensor(probas, dtype=torch.float, device=DEVICE)
+    winner = torch.tensor(winner, dtype=torch.float, device=DEVICE)
     return state, probas, winner
 
 
@@ -151,7 +152,7 @@ def train(current_time, loaded_version):
         total_ite = checkpoint['total_ite']
         lr = checkpoint['lr']
         version = checkpoint['version']
-        last_id = collection.find().count() - 120
+        last_id = collection.find().count() - (MOVES // MOVE_LIMIT) * 2
     else:
         player = Player()
         optimizer = create_optimizer(player, lr)
@@ -209,9 +210,9 @@ def new_agent(result):
             pool.terminate()
 
         example = {
-            'state': Variable(state).type(DTYPE_FLOAT),
-            'winner': Variable(winner).type(DTYPE_FLOAT),
-            'move' : Variable(move).type(DTYPE_FLOAT)
+            'state': state,
+            'winner': winner,
+            'move' : move
        }
        loss = train_epoch(player, optimizer, example, criterion)
        running_loss.append(loss)
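As background on the loss hunk: F.kl_div expects log-probabilities as its first argument, and the newly added comment keeps a plain cross-entropy variant of the policy term around. The AlphaZero-style objective being computed is roughly the following (illustrative sketch, names are not from the repo):

    import torch
    import torch.nn.functional as F

    def alphazero_loss(pred_value, target_value, log_probas, target_probas):
        # Value head: squared error against the self-play outcome
        value_error = F.mse_loss(pred_value, target_value)
        # Policy head: cross entropy against the MCTS visit distribution
        policy_error = -(target_probas * log_probas).sum(dim=1).mean()
        return value_error + policy_error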

lib/utils.py
Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ def _prepare_state(state):
     """
 
     x = torch.from_numpy(np.array([state]))
-    x = Variable(x).type(DTYPE_FLOAT)
+    x = torch.tensor(x, dtype=torch.float, device=DEVICE)
     return x
 
 
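A side note on the _prepare_state change (not part of the commit): torch.from_numpy() already returns a tensor, so calling torch.tensor() on it makes an extra copy and, in later PyTorch releases, emits a copy-construct warning. An equivalent 0.4+ form would be:

    import numpy as np
    import torch

    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def _prepare_state(state):
        # Convert dtype and device in one step instead of re-wrapping the tensor
        return torch.from_numpy(np.array([state])).to(DEVICE, dtype=torch.float)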

main.py
Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ def main(folder, version):
 
         ## Comment one line or the other to get the stack trace
         ## Must add a loooooong timer otherwise signals are not caught
-        self_play_proc.get(60000000)
+        # self_play_proc.get(60000000)
         train_proc.get(60000000)
 
     except KeyboardInterrupt:
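(Background on the huge timeout: multiprocessing's AsyncResult.get() with no timeout can block in a way that never delivers KeyboardInterrupt to the main process, so passing an absurdly large timeout is a common workaround; that is what the "loooooong timer" comment refers to.)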

models/mcts.py
Lines changed: 27 additions & 47 deletions

@@ -81,62 +81,53 @@ def _opt_select(nodes, c_puct):
 
 
 class EvaluatorThread(threading.Thread):
-    def __init__(self, player, eval_queue, result_queue, condition_search, condition_eval, condition_mix):
+    def __init__(self, player, eval_queue, condition_search, condition_eval):
         threading.Thread.__init__(self)
         self.eval_queue = eval_queue
-        self.result_queue = result_queue
         self.player = player
         self.condition_search = condition_search
         self.condition_eval = condition_eval
-        self.condition_mix = condition_mix
 
     def run(self):
         total_sim = MCTS_SIM
         while total_sim > 0:
-            self.condition_eval.acquire()
-            while (len(self.eval_queue.values()) < BATCH_SIZE_EVAL or \
-                    len(self.result_queue.values()) > 0):
-                print("notifying in evaluator, current len: %d" % len(self.eval_queue.values()))
-                self.condition_eval.wait()
-            self.condition_eval.release()
+            self.condition_search.acquire()
+            while (len(self.eval_queue.values()) != BATCH_SIZE_EVAL or \
+                    (len(self.eval_queue.values()) < BATCH_SIZE_EVAL and \
+                    len(self.eval_queue.values()) != total_sim)) or \
+                    (len(self.eval_queue.values()) == BATCH_SIZE_EVAL and \
+                    not all(isinstance(state, np.ndarray) for state in self.eval_queue.values())):
+                self.condition_search.wait()
 
+            self.condition_search.release()
+            self.condition_eval.acquire()
             states = []
             for idx, state in self.eval_queue.items():
                 states.append(sample_rotation(state, num=1))
-            print('states len: %d' % len(states))
             states = _prepare_state(states)
             feature_maps = self.player.extractor(states[0])
 
             ## Policy and value prediction
             probas = self.player.policy_net(feature_maps)
             v = self.player.value_net(feature_maps)
-            keys = list(self.eval_queue.keys())
-            idx = 0
-            for key in keys:
-                self.result_queue[key] = (probas[idx].cpu().data.numpy(), float(v[idx]))
-                del self.eval_queue[key]
-                idx += 1
-            del probas, v, feature_maps
-            self.condition_mix.acquire()
-            self.condition_mix.notifyAll()
-            self.condition_mix.release()
+            for idx in range(BATCH_SIZE_EVAL):
+                self.eval_queue[idx] = (probas[idx].cpu().data.numpy(), v[idx])
+            self.condition_eval.notifyAll()
+            self.condition_eval.release()
             total_sim -= BATCH_SIZE_EVAL
 
 
 
 class SearchThread(threading.Thread):
-    def __init__(self, mcts, game, eval_queue, result_queue, thread_id, \
-                    lock, condition_search, condition_eval, condition_mix):
+    def __init__(self, mcts, game, eval_queue, thread_id, lock, condition_search, condition_eval):
         threading.Thread.__init__(self)
-        self.result_queue = result_queue
         self.eval_queue = eval_queue
         self.mcts = mcts
         self.game = game
         self.lock = lock
         self.thread_id = thread_id
         self.condition_eval = condition_eval
         self.condition_search = condition_search
-        self.condition_mix = condition_mix
 
 
     def run(self):
@@ -155,25 +146,18 @@ def run(self):
             ## Predict the probas
             if not done:
                 self.condition_search.acquire()
-                while len(self.eval_queue.values()) < BATCH_SIZE_EVAL and \
-                        len(self.result_queue.values()) == 0:
-                    print("trying to release in thread id %d" % self.thread_id)
-                    self.condition_search.wait()
-                print("added move in thread %d" % self.thread_id)
-                self.condition_search.release()
-
                 self.eval_queue[self.thread_id] = state
+                self.condition_search.notify()
+
+                self.condition_search.release()
                 self.condition_eval.acquire()
-                self.condition_eval.notify()
-                self.condition_eval.release()
+                self.condition_eval.wait()
 
-                self.condition_mix.acquire()
-                self.condition_mix.wait()
-                self.condition_mix.release()
+                res = self.eval_queue[self.thread_id]
+                probas = np.array(res[0])
+                v = float(res[1])
+                self.condition_eval.release()
 
-                probas = np.array(self.result_queue[self.thread_id][0], copy=True)
-                v = float(self.result_queue[self.thread_id][1])
-                del self.result_queue[self.thread_id]
 
             ## Add noise in the root node
             if not current_node.parent:
@@ -195,7 +179,7 @@ def run(self):
                 current_node.update(v)
                 current_node = current_node.parent
             self.lock.release()
-            print("done in thread id %d" % self.thread_id)
+
 
 
 class MCTS:
@@ -238,18 +222,15 @@ def search(self, current_game, player, competitive=False):
         threads = []
         condition_eval = threading.Condition()
         condition_search = threading.Condition()
-        condition_mix = threading.Condition()
         lock = threading.Lock()
         eval_queue = {}
-        result_queue = {}
-        evaluator = EvaluatorThread(player, eval_queue, result_queue, condition_search, \
-                        condition_eval, condition_mix)
+        evaluator = EvaluatorThread(player, eval_queue, condition_search, condition_eval)
         evaluator.start()
         for sim in range(MCTS_SIM // MCTS_PARALLEL):
             eval_queue.clear()
             for idx in range(MCTS_PARALLEL):
-                threads.append(SearchThread(self, current_game, eval_queue, result_queue, idx,
-                                lock, condition_search, condition_eval, condition_mix))
+                threads.append(SearchThread(self, current_game, eval_queue, idx,
+                                lock, condition_search, condition_eval))
                 threads[-1].start()
             for thread in threads:
                 thread.join()
@@ -265,7 +246,6 @@ def search(self, current_game, player, competitive=False):
                 break
 
         self.root = self.root.childrens[idx]
-        print(final_probas, final_move)
         return final_probas, final_move
 
 
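To summarize the threading revert: SearchThread workers expand the tree, drop their leaf state into the shared eval_queue and block on condition_eval, while the single EvaluatorThread waits until a full BATCH_SIZE_EVAL batch is present, runs one batched forward pass, writes the (probas, value) pairs back into the same dict and wakes every worker. A minimal sketch of that producer/consumer pattern, using queue.Queue instead of the Condition objects above (all names illustrative, not from the repo):

    import queue
    import threading

    BATCH_SIZE_EVAL = 4  # illustrative value

    def fake_network(states):
        # Stand-in for the batched policy/value forward pass (assumption)
        return [("probas_for_" + s, 0.5) for s in states]

    def evaluator(requests, responses):
        # Block until a full batch of leaf states is queued, evaluate once, dispatch
        batch = [requests.get() for _ in range(BATCH_SIZE_EVAL)]
        thread_ids, states = zip(*batch)
        for tid, res in zip(thread_ids, fake_network(states)):
            responses[tid].put(res)

    def search_worker(tid, requests, responses):
        # One simulation step: submit the leaf state, wait for its evaluation
        requests.put((tid, "state-%d" % tid))
        probas, value = responses[tid].get()
        print(tid, probas, value)

    requests = queue.Queue()
    responses = {i: queue.Queue() for i in range(BATCH_SIZE_EVAL)}
    threads = [threading.Thread(target=search_worker, args=(i, requests, responses))
               for i in range(BATCH_SIZE_EVAL)]
    threads.append(threading.Thread(target=evaluator, args=(requests, responses)))
    for t in threads:
        t.start()
    for t in threads:
        t.join()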

viewer.py
Lines changed: 8 additions & 5 deletions

@@ -9,12 +9,12 @@
 from pymongo import MongoClient
 
 
-def game_to_gtp(game, game_id, collection_name):
+def game_to_gtp(game, game_id, collection_name, color):
     """ Take a game from the database and convert it to send GTP instructions """
 
     board_size = int(np.sqrt(len(game[0][0][1]) - 1))
     moves = np.array(game[0])[:,3]
-    move_count = 0
+    move_count = 0 if color == 0 else 1
     game_winner = game[1]
 
     ## Wait for input
@@ -32,20 +32,23 @@ def game_to_gtp(game, game_id, collection_name):
             else:
                 print(format_success(None, response="{}{}".format("ABCDEFGHJKLMNOPQRSTYVWYZ"\
                         [int(move % board_size)], int(board_size - move // board_size))))
-                move_count += 1
+                move_count += 2
         else:
             print('?name %s ???\n\n' % (command))
     elif "name" in command:
         print(format_success(None, response="folder {}, game id: {}, winner: {}"\
                 .format(collection_name, game_id, game_winner)))
+    elif "play" in command:
+        print(format_success(message_id, ""))
     else:
        print('?name %s ???\n\n' % (command))
 
 
 @click.command()
 @click.option("--folder", default=-1)
 @click.option("--game_id", default=-1)
-def main(folder, game_id):
+@click.option("--color", default=1)
+def main(folder, game_id, color):
     ## Init Mongo
     client = MongoClient()
     db = client.superGo
@@ -71,7 +74,7 @@ def main(folder, game_id):
     else:
         for game in last_game:
             final_game = pickle.loads(game['game'])
-            game_to_gtp(final_game, game['id'], collection)
+            game_to_gtp(final_game, game['id'], collection, color)
             break
 
 if __name__ == "__main__":
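With the new --color option a stored game can be replayed from either side: move_count starts at 0 for color 0 and at 1 for color 1, then advances by 2 per genmove. An illustrative invocation (placeholder values, not from the repo):

    python viewer.py --folder <collection_timestamp> --game_id <id> --color 0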
