Skip to content

Commit 8da47a7

Browse files
committed
migrating towards openAI gym's API
1 parent ac6d0a9 commit 8da47a7

File tree

6 files changed

+20
-389
lines changed

6 files changed

+20
-389
lines changed

source/AlphaCompile.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,18 @@
33

44
from PyTorchRL.agents.AlphaZero import AlphaZero
55
from source.config import FLAGS
6+
from source.programs import Programs
67

78
STEPS = 10
89

910

1011
class AlphaCompile(AlphaZero):
1112
def __init__(self, input_dim: int = 5, output_dim: int = len(FLAGS) + 1):
12-
hidden = 5
13+
self.hidden = 5
1314
body = nn.Sequential(
14-
nn.Linear(input_dim, hidden),
15-
nn.Dropout(0.5),
16-
nn.Linear(hidden, hidden),
15+
nn.Linear(input_dim, self.hidden),
16+
nn.Dropout(0.4),
17+
nn.Linear(self.hidden, self.hidden),
1718
)
1819
super().__init__(input_dim, output_dim, body)
1920

@@ -30,4 +31,6 @@ def train(programs):
3031

3132
if __name__ == "__main__":
3233
log_param("steps", STEPS)
33-
train([])
34+
programs = Programs()
35+
programs = programs.filter(programs[0])
36+

source/_ck.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
import ck
3+
4+
results = ck.compile

source/config.py

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,7 @@
11
""" Settings file.
22
"""
3-
import os
4-
import sqlite3 as sql
5-
from datetime import datetime
6-
from decimal import Decimal
73
from enum import Enum
84

9-
sql.register_adapter(Decimal, lambda d: str(d))
10-
sql.register_converter("DEC", lambda s: Decimal(s.decode('utf-8')))
11-
125

136
class Features(Enum):
147
HYBRID = 0
@@ -17,10 +10,6 @@ class Features(Enum):
1710

1811

1912
EPOCHS = 100
20-
THREADED_RUNTIMES = False
21-
USE_RUNTIMES = True
22-
NOW = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
23-
RUN_DIR = os.path.abspath('./runs/run_{}'.format(NOW))
2413

2514
# noinspection SpellCheckingInspection
2615
FLAGS = [
@@ -34,44 +23,3 @@ class Features(Enum):
3423
]
3524
N_FLAGS = len(FLAGS)
3625

37-
ACTIONS = [(
38-
'-O3',
39-
*(f for f, s in zip(FLAGS, list(format(a, '0%sb' % N_FLAGS))) if s == '1')
40-
) for a in range(2 ** N_FLAGS)]
41-
42-
N_ACTIONS = len(ACTIONS)
43-
44-
# noinspection SpellCheckingInspection
45-
LOG_CONFIG = {
46-
'version': 1,
47-
'disable_existing_loggers': False,
48-
'formatters': {
49-
'simple': {
50-
'format': '%(levelname)s: %(filename)s: %(message)s',
51-
},
52-
'detailed': {
53-
'class': 'logging.Formatter',
54-
'format': '%(asctime)s, %(levelname)-6s, %(filename)-6s, %(funcName)s, %(message)s',
55-
}
56-
},
57-
'handlers': {
58-
'console': {
59-
'class': 'logging.StreamHandler',
60-
'level': 'INFO',
61-
'formatter': 'simple',
62-
},
63-
'event_file': {
64-
'class': 'logging.FileHandler',
65-
'level': 'INFO',
66-
'formatter': 'detailed',
67-
'filename': '{}/events.log'.format(RUN_DIR),
68-
},
69-
},
70-
'loggers': {
71-
'': {
72-
'handlers': ['console', 'event_file'],
73-
'level': 'DEBUG',
74-
'propogate': False,
75-
}
76-
}
77-
}

source/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import source.config as c
88
import source.utils as u
9-
from source.programs import Program
9+
from Benchmarks.program import Program
1010

1111

1212
class Metric:

0 commit comments

Comments
 (0)