Skip to content

Commit b051110

Browse files
committed
[STATS] Done, just need to run the script
1 parent 4897579 commit b051110

File tree

5 files changed

+45
-37
lines changed

5 files changed

+45
-37
lines changed

README.md

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,21 +59,16 @@ Ongoing project.
5959
* [Monte Carlo tree search explaination](https://int8.io/monte-carlo-tree-search-beginners-guide/)
6060
* [Nice tree search implementation](https://github.com/blanyal/alpha-zero/blob/master/mcts.py)
6161

62-
# Statistics
62+
# Statistics, check branch stats
6363

64-
## For a 10 layers deep Resnet evaluated on 50 games 64 simulations
64+
## For a 10 layers deep Resnet
6565

6666
### 9x9 board
6767

68-
* 0.2377991s / move - 0.00371561093s / simulation 2 threads 2 batch_size_eval
69-
* 0.1624937s / move - 0.00253896406s / simulation 4 threads 4 batch_size_eval
70-
* 0.1465123s / move - 0.00228925468s / simulation 8 threads 8 batch_size_eval
71-
* 0.1401098s / move - 0.00218921563s / simulation 16 threads 16 batch_size_eval
68+
soon
7269

7370
### 19x19 board
7471

75-
* 0.6306054s / move - 0.012612108s / simulation with 2 threads and 2 batch_size_eval with 50 simulations
76-
7772
# Differences with the official paper
7873

7974
* No resignation

const.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
## Dtype of the tensors depending on CUDA
99
DEVICE = torch.device("cuda") if CUDA else torch.device("cpu")
1010
## Number of self-play parallel games
11-
PARALLEL_SELF_PLAY = 6
11+
PARALLEL_SELF_PLAY = 2
1212
## Number of evaluation parallel games
1313
PARALLEL_EVAL = 3
1414
## MCTS parallel
15-
MCTS_PARALLEL = 6
15+
MCTS_PARALLEL = 12
1616

1717

1818
##### GLOBAL
@@ -28,7 +28,7 @@
2828
## Learning rate
2929
LR = 0.01
3030
## Number of MCTS simulation
31-
MCTS_SIM = 32
31+
MCTS_SIM = 128
3232
## Exploration constant
3333
C_PUCT = 0.2
3434
## L2 Regularization

lib/game.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,12 @@ def __call__(self):
102102
done = False
103103
state = self.board.reset()
104104
dataset = []
105+
move_times = []
105106
moves = 0
106107
comp = False
108+
109+
# if self.id % 10 == 0:
110+
print("Starting game number %d" % self.id)
107111

108112
while not done:
109113

@@ -112,7 +116,7 @@ def __call__(self):
112116
reward = self.board.get_winner()
113117
if self.opponent:
114118
final_time = timeit.default_timer() - start_time
115-
return pickle.dumps([reward, moves, final_time])
119+
return pickle.dumps([reward, moves, move_times, final_time])
116120
return pickle.dumps((dataset, reward))
117121

118122
## Adaptative temperature to stop exploration
@@ -121,10 +125,17 @@ def __call__(self):
121125

122126
## For evaluation
123127
if self.opponent:
128+
play_time = timeit.default_timer()
124129
state, reward, done, _, action = self._play(_prepare_state(state), \
125130
self.player, self.opponent.passed, competitive=True)
131+
final_play_time = timeit.default_timer() - play_time
132+
move_times.append(final_play_time)
133+
134+
play_time = timeit.default_timer()
126135
state, reward, done, _, action = self._play(_prepare_state(state), \
127136
self.opponent, self.player.passed, competitive=True)
137+
final_play_time = timeit.default_timer() - play_time
138+
move_times.append(final_play_time)
128139
moves += 2
129140

130141
## For self-play
@@ -141,7 +152,7 @@ def __call__(self):
141152
## Pickle the result because multiprocessing
142153
if self.opponent:
143154
final_time = timeit.default_timer() - start_time
144-
return pickle.dumps([reward, moves, final_time])
155+
return pickle.dumps([reward, moves, move_times, final_time])
145156

146157
return pickle.dumps((dataset, reward))
147158

models/mcts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import torch
23
import threading
34
import time
45
import random

stats.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@
99
from subprocess import call
1010

1111

12-
MCTS_PARALLELS = [2, 4, 6, 8, 16]
13-
MCTS_SIMS = [32, 64, 128, 160]
14-
BATCH_SIZE_EVALS = [2, 4, 6, 8]
15-
SAMPLE_NUM = 30
16-
12+
SAMPLE_NUM = 50
1713

1814
def overwrite_file(old_values, new_values):
1915
for idx, new_value in new_values.items():
@@ -29,6 +25,8 @@ def do_sims(player, old_values, mcts_parallel=2, mcts_sim=8, batch_size_eval=2):
2925
"BATCH_SIZE_EVAL": batch_size_eval
3026
}
3127
overwrite_file(old_values, new_values)
28+
print("-- STARTING FOR %d GAMES WITH MCTS PARALLEL %d SIMS %d BATCH_SIZE %d --"\
29+
% (SAMPLE_NUM, mcts_parallel, mcts_sim, batch_size_eval))
3230
queue, results = create_matches(player, cores=PARALLEL_SELF_PLAY,
3331
opponent=player, match_number=SAMPLE_NUM)
3432
moves = []
@@ -38,50 +36,54 @@ def do_sims(player, old_values, mcts_parallel=2, mcts_sim=8, batch_size_eval=2):
3836
for _ in range(SAMPLE_NUM):
3937
result = pickle.loads(results.get())
4038
moves.append(result[1])
41-
times.append(result[2])
39+
move_times = result[2]
40+
times.append(result[3])
4241
finally:
4342
queue.close()
4443
results.close()
45-
46-
print("-- FINAL RESULTS FOR %d GAMES WITH MCTS PARALLEL %d SIMS %d BATCH_SIZE %d --"\
47-
% (SAMPLE_NUM, mcts_parallel, mcts_sim, batch_size_eval))
48-
print("total game duration: %d seconds, total game move count: %d" \
49-
% (sum(times) / PARALLEL_SELF_PLAY, sum(moves)))
44+
print("-- RESULTS --")
45+
print("real total game duration: %.3f seconds, total game move count: %d" \
46+
% (sum(times), sum(moves)))
5047
print("average game duration: %.5f seconds, average game move count: %.1f" \
5148
% (np.mean(times), np.mean(moves)))
5249
print("average move duration: %.5f seconds, average sim duration: %.8f seconds" \
53-
% (sum(times) / sum(moves), sum(times) / (sum(moves) * mcts_sim)))
50+
% (np.mean(move_times), np.mean(move_times) / mcts_sim))
5451
print("-- DONE --\n")
5552
return new_values
5653

5754

5855
def stats_report():
5956
multiprocessing.set_start_method("spawn")
6057
player = Player()
61-
old_values = {
58+
first_values = {
6259
"MCTS_PARALLEL": MCTS_PARALLEL,
6360
"MCTS_SIM": MCTS_SIM,
6461
"BATCH_SIZE_EVAL": BATCH_SIZE_EVAL
6562
}
6663

67-
old_values = do_sims(player, old_values, mcts_parallel=2, mcts_sim=32, batch_size_eval=2)
68-
old_values = do_sims(player, old_values, mcts_parallel=4, mcts_sim=32, batch_size_eval=2)
69-
old_values = do_sims(player, old_values, mcts_parallel=6, mcts_sim=32, batch_size_eval=2)
70-
old_values = do_sims(player, old_values, mcts_parallel=8, mcts_sim=32, batch_size_eval=2)
71-
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=32, batch_size_eval=2)
72-
old_values = do_sims(player, old_values, mcts_parallel=8, mcts_sim=32, batch_size_eval=4)
73-
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=32, batch_size_eval=4)
74-
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=32, batch_size_eval=6)
64+
## 64 simulations
65+
old_values = do_sims(player, first_values, mcts_parallel=2, mcts_sim=64, batch_size_eval=2)
66+
old_values = do_sims(player, old_values, mcts_parallel=4, mcts_sim=64, batch_size_eval=2)
67+
old_values = do_sims(player, old_values, mcts_parallel=6, mcts_sim=64, batch_size_eval=2)
68+
old_values = do_sims(player, old_values, mcts_parallel=8, mcts_sim=64, batch_size_eval=2)
69+
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=64, batch_size_eval=2)
70+
old_values = do_sims(player, old_values, mcts_parallel=8, mcts_sim=64, batch_size_eval=4)
71+
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=64, batch_size_eval=4)
72+
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=64, batch_size_eval=6)
7573

76-
old_values = do_sims(player, old_values, mcts_parallel=2, mcts_sim=128, batch_size_eval=2)
74+
75+
## 128 simulations
76+
old_values = do_sims(player, old_values, mcts_parallel=4, mcts_sim=128, batch_size_eval=2)
7777
old_values = do_sims(player, old_values, mcts_parallel=4, mcts_sim=128, batch_size_eval=2)
7878
old_values = do_sims(player, old_values, mcts_parallel=6, mcts_sim=128, batch_size_eval=2)
79-
old_values = do_sims(player, old_values, mcts_parallel=8, mcts_sim=128, batch_size_eval=2)
79+
old_values = do_sims(player, old_values, mcts_parallel=8, mcts_sim=128, batch_size_eval=4)
8080
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=128, batch_size_eval=2)
8181
old_values = do_sims(player, old_values, mcts_parallel=8, mcts_sim=128, batch_size_eval=4)
8282
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=128, batch_size_eval=4)
8383
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=128, batch_size_eval=6)
8484

85+
86+
## 160 simulations
8587
old_values = do_sims(player, old_values, mcts_parallel=2, mcts_sim=160, batch_size_eval=2)
8688
old_values = do_sims(player, old_values, mcts_parallel=4, mcts_sim=160, batch_size_eval=2)
8789
old_values = do_sims(player, old_values, mcts_parallel=6, mcts_sim=160, batch_size_eval=2)
@@ -91,7 +93,6 @@ def stats_report():
9193
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=160, batch_size_eval=4)
9294
old_values = do_sims(player, old_values, mcts_parallel=12, mcts_sim=160, batch_size_eval=6)
9395

94-
9596

9697

9798
if __name__ == "__main__":

0 commit comments

Comments
 (0)