Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 882763d

Browse files
author
Hao-Ting Wang
authoredSep 15, 2024
summer sprint ENLSP 2024 (#14)
* expose more model parameters * FIX torchinfo save out * FIX torchinfo save out * update key module * architecture scaling * update default * separate script to create data * use proportion of sample instead of number of sample * update the main library * combined holdout set creation and label creation * Remove range limit for testing * increase default runtime for creating labels * Update training script to fit the new data * update module * EHN resource management - ADD batch size finding script - modify parameters accordingly * update tools for resource benchmark * full script for number of worker estimate * ENLSP 2024
1 parent 4e8285e commit 882763d

31 files changed

+1758
-515
lines changed
 

‎config/base.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
---
22
defaults:
33
- _self_
4-
- hydra: default
4+
- hydra: make_data
55

66
verbose: 2
7-
random_state: 42
7+
random_state: 1
88
return_type: float

‎config/data/default.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
---
22
standardize: false
33
n_embed: 197
4+
time_stride: 1
5+
lag: 1
6+
seq_length: 16
47
atlas_desc: atlas-MIST_desc-${data.n_embed}
58
hold_out_set: 0.20
69
validation_set: 0.25
7-
n_sample: -1
10+
proportion_sample: 1.0
811
class_balance_confounds:
912
- site
1013
- sex

‎config/extract.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ defaults:
33
- _self_
44
- hydra: extract
55

6-
horizon: 1
7-
convlayer_index: -1
6+
random_state: 435
7+
horizon: 6
8+
convlayer_index: -99
89
# passing model path is necessary for evaluation
910
model_path: ???

‎config/hydra/hyperparameters.yaml

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,15 @@ sweeper:
2121
# default parametrization of the search space
2222
params:
2323
model:
24-
nb_epochs: uniform(16, 24, discrete=True)
25-
FK: choices(["\'128,32,128,32,128,32,128,32,128,32,128,32\'", "\'8,6,8,6,8,6,8,6,8,6,8,6\'", "\'8,3,8,3,8,3\'"])
26-
M: choices(["\'32,16,8,1\'", "\'16,8,1\'"])
27-
lr: uniform(1e-4, 0.3)
28-
lr_thres: uniform(1e-6, 1)
29-
dropout: uniform(1e-4, 0.3)
30-
batch_size: loguniform(128, 256, discrete=True)
31-
seq_length: uniform(12, 32, discrete=True)
24+
lr: loguniform(1e-5, 1e-2)
25+
weight_decay: uniform(1e-6, 1e-4)
26+
lr_thres: loguniform(1e-5, 1e-3)
27+
lr_patience: choices([4, 5, 6])
28+
dropout: uniform(0, 0.5)
29+
bn_momentum: uniform(0, 0.99)
30+
GCL: choices([3,6,12])
31+
F: choices([8,16,32,64])
32+
K: choices([3,6,9])
3233

3334
experiment:
3435
name: experiment
@@ -55,7 +56,7 @@ sweeper:
5556

5657
worker:
5758
n_workers: -1
58-
max_broken: 10
59+
max_broken: 20
5960
max_trials: 100
6061

6162
storage:
@@ -67,11 +68,11 @@ sweeper:
6768

6869
launcher:
6970
submitit_folder: ${hydra.sweep.dir}/.submitit/%j
70-
timeout_min: 720
71-
cpus_per_task: 4
71+
timeout_min: 180
72+
cpus_per_task: 5
7273
gpus_per_node: 1
7374
tasks_per_node: 1
74-
mem_gb: 4
75+
mem_gb: 8
7576
nodes: 1
7677
name: ${hydra.job.name}
7778
stderr_to_stdout: false
@@ -91,4 +92,5 @@ launcher:
9192
max_num_timeout: 0
9293
additional_parameters: {mail-user: '${oc.env:SLACK_EMAIL_BOT}', mail-type: ALL}
9394
array_parallelism: 256
94-
setup:
95+
setup: [export HYDRA_FULL_ERROR=1, export NCCL_DEBUG=INFO, 'rsync -tv --info=progress2 /lustre03/project/6003287/hwang1/rs-autoregression-prediction/outputs/sample_for_pretraining/seed-${random_state}/./sample_seed-${random_state}_data-train.h5
96+
$SLURM_TMPDIR/data_$SLURM_JOB_ID.h5']

‎config/hydra/scaling.yaml renamed to ‎config/hydra/make_data.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ run:
88
dir: ${oc.env:SCRATCH}/autoreg/${hydra.job.name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
99
sweep:
1010
dir: ${oc.env:SCRATCH}/autoreg/${hydra.job.name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
11-
subdir: seed-${random_state}_n-${data.n_sample}
11+
subdir: seed-${random_state}
1212

1313
job_logging:
1414
handlers:
@@ -19,10 +19,10 @@ job_logging:
1919
launcher:
2020
submitit_folder: ${hydra.sweep.dir}/.submitit/%j
2121
timeout_min: 600
22-
cpus_per_task: 4
23-
gpus_per_node: 1
22+
cpus_per_task: 1
23+
gpus_per_node:
2424
tasks_per_node: 1
25-
mem_gb: 4
25+
mem_gb: 2
2626
nodes: 1
2727
name: ${hydra.job.name}
2828
stderr_to_stdout: false

‎config/hydra/scaling_cpu.yaml

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
---
2+
defaults:
3+
- _self_
4+
- override launcher: submitit_slurm
5+
6+
# output directory, generated dynamically on each run
7+
run:
8+
dir: ${oc.env:SCRATCH}/autoreg/${hydra.job.name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
9+
sweep:
10+
dir: ${oc.env:SCRATCH}/autoreg/${hydra.job.name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
11+
subdir: ${hydra.job.override_dirname}
12+
13+
job_logging:
14+
handlers:
15+
file:
16+
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
17+
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
18+
19+
launcher:
20+
submitit_folder: ${hydra.sweep.dir}/.submitit/%j
21+
timeout_min: 60
22+
cpus_per_task: 10
23+
gpus_per_node:
24+
tasks_per_node: 1
25+
mem_gb: 16
26+
nodes: 1
27+
name: ${hydra.job.name}
28+
stderr_to_stdout: false
29+
_target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher
30+
partition:
31+
qos:
32+
comment:
33+
constraint:
34+
exclude:
35+
gres:
36+
cpus_per_gpu:
37+
gpus_per_task:
38+
mem_per_gpu:
39+
mem_per_cpu:
40+
account: ${oc.env:SLURM_COMPUTE_ACCOUNT}
41+
signal_delay_s: 120
42+
max_num_timeout: 0
43+
additional_parameters: {mail-user: '${oc.env:SLACK_EMAIL_BOT}', mail-type: ALL}
44+
array_parallelism: 256
45+
setup: []

‎config/hydra/scaling_gpu.yaml

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
---
2+
defaults:
3+
- _self_
4+
- override launcher: submitit_slurm
5+
6+
# output directory, generated dynamically on each run
7+
run:
8+
dir: ${oc.env:SCRATCH}/autoreg/${hydra.job.name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
9+
sweep:
10+
dir: ${oc.env:SCRATCH}/autoreg/${hydra.job.name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
11+
subdir: ${hydra.job.override_dirname}
12+
13+
job_logging:
14+
handlers:
15+
file:
16+
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
17+
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
18+
19+
launcher:
20+
submitit_folder: ${hydra.sweep.dir}/.submitit/%j
21+
timeout_min: 180
22+
cpus_per_task: 5
23+
gpus_per_node: 1
24+
tasks_per_node: 1
25+
mem_gb: 8
26+
nodes: 1
27+
name: ${hydra.job.name}
28+
stderr_to_stdout: false
29+
_target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher
30+
partition:
31+
qos:
32+
comment:
33+
constraint:
34+
exclude:
35+
gres:
36+
cpus_per_gpu:
37+
gpus_per_task:
38+
mem_per_gpu:
39+
mem_per_cpu:
40+
account: ${oc.env:SLURM_COMPUTE_ACCOUNT}
41+
signal_delay_s: 120
42+
max_num_timeout: 0
43+
additional_parameters: {mail-user: '${oc.env:SLACK_EMAIL_BOT}', mail-type: ALL}
44+
array_parallelism: 256
45+
setup: [export HYDRA_FULL_ERROR=1, export NCCL_DEBUG=INFO, 'rsync -tv --info=progress2 /lustre03/project/6003287/hwang1/rs-autoregression-prediction/outputs/sample_for_pretraining/seed-${random_state}/./sample_seed-${random_state}_data-train.h5
46+
$SLURM_TMPDIR/data_$SLURM_JOB_ID.h5']

‎config/model/linearchebnet.yaml renamed to ‎config/model/basic_model.yaml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22
defaults:
33
- _self_
44
- experiment
5-
6-
model: LinearChebnet
5+
6+
# these defaults are from FP's paper
7+
model: Chebnet
78
FC_type: nonshared_uni
8-
FK: 8,3,8,3,8,3
9-
M: '1'
10-
use_bn: true
9+
GCL: 3
10+
F: 8
11+
K: 3
12+
FCL: 1
13+
M: 8
14+
aggrs: add
1115
dropout: 0
16+
use_bn: true
1217
bn_momentum: 0.1
13-

‎config/model/chebnet_detailed.yaml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
---
2+
defaults:
3+
- _self_
4+
- experiment
5+
6+
model: Chebnet
7+
FC_type: nonshared_uni
8+
use_bn: true
9+
dropout: 0
10+
bn_momentum: 0.1
11+
layers:
12+
- {F: 8, K: 3, aggr: add}
13+
- {F: 8, K: 3, aggr: add}
14+
- {F: 8, K: 3, aggr: add}
15+
- {M: 1}

‎config/model/experiment.yaml

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
---
2-
nb_epochs: 20
3-
lr: 0.01
4-
lr_patience: 4
5-
lr_thres: 0.001
6-
weight_decay: 0
7-
batch_size: 100
8-
num_workers: 4
9-
time_stride: 1
10-
lag: 1
11-
seq_length: 16
2+
# these defaults are from FP's paper
3+
# https://doi.org/10.1162/imag_a_00228
4+
nb_epochs: 20 # best was 100 in the paper, use 20 as default for faster iteration for scaling
5+
batch_size: 512
126
edge_index_thres: 0.9
7+
lr: 1e-2 # default 1e-2, Common ranges include 1e-3 to 1e-1.
8+
weight_decay: 0 # this has to be really low, like 0 - 0.0001 range
9+
lr_patience: 4 # default 4
10+
lr_thres: 1e-3 # default 1e-3
11+
early_stopping:
12+
min_delta: 1e-03
13+
tolerance: 3

‎config/predict.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
---
22
defaults:
33
- _self_
4-
- extract
4+
- hydra: scaling_cpu
55

66
predict_variable: sex
77
predict_variable_type: binary
8-
phenotype_file: inputs/connectomes/sourcedata/ukbb/ukbb_pheno.tsv
98
percentage_sample: 100
10-
random_state: 42
119
# passing extracted feature path is necessary for evaluation
1210
feature_path: ???

‎config/train.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
---
22
defaults:
33
- _self_
4-
- model: chebnet
4+
- model: basic_model
55
- data: ukbb
66
- hydra: default
77

88
verbose: 2
9-
random_state: 42
9+
random_state: 1
1010
return_type: float
11-
data_split: ???
11+
# checkpoints: 0,2,4,6,8,10,12,14,16,18,20
12+
num_workers: 4

‎env/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ seaborn==0.13.2
1111
hydra-core==1.3.2
1212
hpbandster==0.7.4
1313
configspace==0.7.1
14+
torchinfo==1.8.0

‎pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ exclude = [
2222
"**/tests/*",
2323
"*build/",
2424
"code/fmri-autoreg",
25+
"src/utils/plot_*.py"
2526
]
2627

2728
ignore = [

‎src/create_holdout_sample.py

Lines changed: 147 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
1+
"""
2+
python src/create_holdout_sample.py --multirun \
3+
+data=ukbb ++random_state=1,2,3,5,7,10,42,435,764,9999
4+
This is script will create all the input/labels for the full dataset.
5+
"""
16
import json
27
import logging
8+
import os
39
from pathlib import Path
410

11+
import h5py
512
import hydra
613
import matplotlib.pyplot as plt
14+
import numpy as np
715
import pandas as pd
816
import seaborn as sns
17+
from fmri_autoreg.data.load_data import get_edge_index, load_data, make_seq
918
from omegaconf import DictConfig
19+
from sklearn.model_selection import train_test_split
20+
from tqdm import tqdm
1021

1122
log = logging.getLogger(__name__)
1223

@@ -17,7 +28,12 @@ def main(params: DictConfig) -> None:
1728

1829
output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
1930
output_dir = Path(output_dir)
31+
rng = np.random.default_rng(params["random_state"])
32+
log.info(f"Current working directory : {os.getcwd()}")
33+
log.info(f"Output directory : {output_dir}")
34+
log.info(f"Random seed {params['random_state']}")
2035

36+
# create hold out sample using the full dataset and save things
2137
sample = create_hold_out_sample(
2238
phenotype_path=params["data"]["phenotype_file"],
2339
phenotype_meta=params["data"]["phenotype_json"],
@@ -35,6 +51,9 @@ def main(params: DictConfig) -> None:
3551
json.dump(sample, f, indent=2)
3652

3753
# plot the distribution of confounds of downstreams balanced samples
54+
log.info("Holdout sample created")
55+
report_dir = output_dir / "report"
56+
report_dir.mkdir(exist_ok=True)
3857
demographics = {}
3958
for d in sample["test_downstreams"].keys():
4059
d_subjects = sample["test_downstreams"][d]
@@ -50,7 +69,7 @@ def main(params: DictConfig) -> None:
5069
)
5170
for ax, c in zip(axes, params["data"]["class_balance_confounds"]):
5271
sns.histplot(x=c, data=df, hue=d, kde=True, ax=ax)
53-
fig.savefig(output_dir / f"{d}.png")
72+
fig.savefig(report_dir / f"{d}.png")
5473
demographics[d] = {
5574
"patient": {
5675
"condition": d,
@@ -81,14 +100,14 @@ def main(params: DictConfig) -> None:
81100
"proportion_kept_sd": df[df[d] == 0]["proportion_kept"].std(),
82101
},
83102
}
84-
103+
# save the summary
85104
demographics_summary = pd.DataFrame()
86105
for d in demographics.keys():
87106
df = pd.DataFrame.from_dict(demographics[d], orient="index")
88107
df.set_index([df.index, "condition"], inplace=True)
89108
demographics_summary = pd.concat([demographics_summary, df])
90109
demographics_summary.round(decimals=2).to_csv(
91-
output_dir / "demographics_summary.tsv", sep="\t"
110+
report_dir / "demographics_summary.tsv", sep="\t"
92111
)
93112

94113
for key in sample.keys():
@@ -104,7 +123,131 @@ def main(params: DictConfig) -> None:
104123
fig.suptitle(f"{key} sample (N={len(d_subjects)})")
105124
for ax, c in zip(axes, params["data"]["class_balance_confounds"]):
106125
sns.histplot(x=c, data=df, kde=True, ax=ax)
107-
fig.savefig(output_dir / f"{key}.png")
126+
fig.savefig(report_dir / f"{key}.png")
127+
128+
log.info("Sample report created")
129+
130+
full_train_sample = [f"sub-{s}" for s in sample["train"]]
131+
test_participant_ids = [f"sub-{s}" for s in sample["hold_out"]]
132+
rng.shuffle(full_train_sample)
133+
134+
# pre generate labels for training samples
135+
prefix = f"sample_seed-{params['random_state']}"
136+
data_h5 = Path(output_dir) / f"{prefix}_data-train.h5"
137+
original_reference = Path(output_dir) / f"{prefix}_split.json"
138+
data_reference = {}
139+
140+
# further split the training sample into training and validation
141+
142+
log.info(
143+
f"Create dataset of {len(full_train_sample)} subjects "
144+
"for pretrain model. "
145+
)
146+
train_participant_ids, val_participant_ids = train_test_split(
147+
full_train_sample,
148+
test_size=params["data"]["validation_set"],
149+
shuffle=False,
150+
random_state=params["random_state"],
151+
)
152+
data_ids = (
153+
train_participant_ids,
154+
val_participant_ids,
155+
test_participant_ids,
156+
)
157+
# save reference to the h5 path in the original data file
158+
data_reference = create_reference(params, data_ids)
159+
160+
# generate labels for the autoregressive model
161+
with h5py.File(data_h5, "a") as f:
162+
for n_embed in data_reference.keys():
163+
base = f"n_embed-{n_embed}"
164+
log.info(f"Creating dataset for n_embed-{n_embed}")
165+
f.create_group(base)
166+
for split in ["train", "val"]:
167+
cur_group = f.create_group(f"{base}/{split}")
168+
169+
if split == "train":
170+
# use the training set (exclude validation set)
171+
# to create the connectome
172+
edges = get_edge_index(
173+
data_file=params["data"]["data_file"],
174+
dset_paths=data_reference[n_embed]["train"],
175+
)
176+
f[f"n_embed-{n_embed}"]["train"].create_dataset(
177+
"connectome", data=edges
178+
)
179+
180+
for dset in tqdm(
181+
data_reference[n_embed][split],
182+
desc=f"Creating {split} set",
183+
):
184+
data = load_data(
185+
path=params["data"]["data_file"],
186+
h5dset_path=dset,
187+
standardize=False,
188+
dtype="data",
189+
)
190+
x, y = make_seq(
191+
data,
192+
params["data"]["seq_length"],
193+
params["data"]["time_stride"],
194+
params["data"]["lag"],
195+
)
196+
if x.shape[0] == 0 or x is None:
197+
log.warning(
198+
f"Skipping {dset} as label couldn't be created."
199+
)
200+
continue
201+
if cur_group.get("input") is None:
202+
cur_group.create_dataset(
203+
name="input",
204+
data=x,
205+
dtype=np.float32,
206+
maxshape=(
207+
None,
208+
n_embed,
209+
params["data"]["seq_length"],
210+
),
211+
chunks=(
212+
x.shape[0],
213+
n_embed,
214+
params["data"]["seq_length"],
215+
),
216+
)
217+
cur_group.create_dataset(
218+
name="label",
219+
data=y,
220+
dtype=np.float32,
221+
maxshape=(None, n_embed),
222+
chunks=(y.shape[0], n_embed),
223+
)
224+
else:
225+
cur_group["input"].resize(
226+
(cur_group["input"].shape[0] + x.shape[0]), axis=0
227+
)
228+
cur_group["input"][-x.shape[0] :] = x
229+
230+
cur_group["label"].resize(
231+
(cur_group["label"].shape[0] + y.shape[0]), axis=0
232+
)
233+
cur_group["label"][-y.shape[0] :] = y
234+
with open(original_reference, "a") as f:
235+
json.dump(data_reference, f, indent=2)
236+
237+
238+
def create_reference(params, data_ids):
239+
data_reference = {}
240+
from src.data.load_data import load_ukbb_dset_path
241+
242+
for n_embed in [64, 197, 444]:
243+
data_reference[n_embed] = {}
244+
for d in zip(["train", "val", "test"], data_ids):
245+
data_reference[n_embed][d[0]] = load_ukbb_dset_path(
246+
participant_id=d[1],
247+
atlas_desc=f"atlas-MIST_desc-{n_embed}",
248+
segment=params["data"]["segment"],
249+
)
250+
return data_reference
108251

109252

110253
if __name__ == "__main__":

‎src/data/load_data.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,6 @@ def get_model_data(
316316
phenotype_file: Union[Path, str],
317317
measure: str = "connectome",
318318
label: str = "sex",
319-
pooling_target: str = "max",
320319
log: logging = logging,
321320
) -> Dict[str, np.ndarray]:
322321
"""Get the data from pretrained model for the downstrean task.
@@ -374,11 +373,10 @@ def get_model_data(
374373
if subject in participant_id:
375374
df_phenotype.loc[subject, "path"] = p
376375
selected_path = df_phenotype.loc[participant_id, "path"].values.tolist()
377-
log.info(len(selected_path))
378376
data = load_data(data_file, selected_path, dtype="data")
379377

380378
if "r2" in measure:
381-
data = np.concatenate(data).squeeze()
379+
data = np.array(data)[:, 0, :]
382380
if measure == "avgr2":
383381
data = data.mean(axis=1).reshape(-1, 1)
384382
data = StandardScaler().fit_transform(data)

‎src/extract.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
import hydra
1818
import torch
1919
from fmri_autoreg.models.predict_model import predict_horizon
20-
from fmri_autoreg.tools import load_model
20+
from fmri_autoreg.tools import chebnet_argument_resolver, load_model
2121
from omegaconf import DictConfig, OmegaConf
2222
from tqdm import tqdm
2323

2424
log = logging.getLogger(__name__)
25+
LABEL_DIR = Path(__file__).parents[1] / "outputs" / "sample_for_pretraining"
2526

2627

2728
@hydra.main(version_base="1.3", config_path="../config", config_name="extract")
@@ -50,16 +51,18 @@ def main(params: DictConfig) -> None:
5051
log.info(f"predicting horizon: {horizons}")
5152

5253
# load test set subject path from the training
53-
with open(model_path.parent / "train_test_split.json", "r") as f:
54+
with open(
55+
LABEL_DIR
56+
/ f"seed-{params['random_state']}"
57+
/ f"sample_seed-{params['random_state']}_split.json",
58+
"r",
59+
) as f:
5460
subj = json.load(f)
5561

56-
subj_list = subj["test"]
57-
58-
# save test data path to a text file for easy future reference
59-
with open(output_dir / "test_set_connectome.txt", "w") as f:
60-
for item in subj_list:
61-
f.write("%s\n" % item)
62-
62+
subj_list = subj[str(params["data"]["n_embed"])]["test"]
63+
model_params = chebnet_argument_resolver(
64+
OmegaConf.to_container(params["model"])
65+
)
6366
log.info("Load model")
6467
model = load_model(model_path)
6568
if isinstance(model, torch.nn.Module):
@@ -79,13 +82,12 @@ def main(params: DictConfig) -> None:
7982
# get the prediction of t+1
8083
r2, Z, Y = predict_horizon(
8184
model=model,
82-
seq_length=params["model"]["seq_length"],
85+
seq_length=params["data"]["seq_length"],
8386
horizon=horizon,
8487
data_file=params["data"]["data_file"],
8588
dset_path=h5_dset_path,
86-
batch_size=params["model"]["batch_size"],
87-
stride=params["model"]["time_stride"],
88-
standardize=False, # the ts is already standardized
89+
batch_size=None,
90+
stride=params["data"]["time_stride"],
8991
)
9092
# save the original output to a h5 file
9193
with h5py.File(output_horizon_path, "a") as f:
@@ -108,17 +110,17 @@ def main(params: DictConfig) -> None:
108110
data_file=params["data"]["data_file"],
109111
h5_dset_path=h5_dset_path,
110112
model=model,
111-
seq_length=params["model"]["seq_length"],
112-
time_stride=params["model"]["time_stride"],
113-
lag=params["model"]["lag"],
113+
seq_length=params["data"]["seq_length"],
114+
time_stride=params["data"]["time_stride"],
115+
lag=params["data"]["lag"],
114116
)
115117
# save the original output to a h5 file
116118
with h5py.File(output_conv_path, "a") as f:
117119
new_ds_path = h5_dset_path.replace("timeseries", "convlayers")
118120
f[new_ds_path] = convlayers.numpy()
119121
convlayers_F = [
120122
int(F)
121-
for i, F in enumerate(params["model"]["FK"].split(","))
123+
for i, F in enumerate(model_params["FK"].split(","))
122124
if i % 2 == 0
123125
]
124126
# get the pooling features of the assigned layer
@@ -138,6 +140,7 @@ def main(params: DictConfig) -> None:
138140
# save the original output to a h5 file
139141
with h5py.File(output_conv_path, "a") as f:
140142
f.attrs["convolution_layers_F"] = convlayers_F
143+
log.info("Extraction completed.")
141144

142145

143146
if __name__ == "__main__":

‎src/model/extract_features.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,15 @@ def pooling_convlayers(
7474
convlayers: torch.tensor,
7575
pooling_methods: str = "average",
7676
pooling_target: str = "parcel",
77-
layer_index: int = -1,
77+
layer_index: int = -99,
7878
layer_structure: Tuple[int] = None,
7979
) -> np.array:
8080
"""Pooling the conv layers.
8181
8282
Args:
8383
convlayers (torch.tensor) : shape
8484
(time series, parcel, stack layer feature F)
85-
layer_index (int) : the index of the layer to be pooled, -1
85+
layer_index (int) : the index of the layer to be pooled, -99
8686
means pooling all layers.
8787
pooling_methods (str) : "average", "max", "std"
8888
pooling_target (str) : keep "parcel" or "timeseries" and parcels
@@ -96,14 +96,14 @@ def pooling_convlayers(
9696
raise ValueError(f"Pooling method {pooling_methods} is not supported.")
9797
if pooling_target not in ["parcel", "timeseries"]:
9898
raise ValueError(f"Pooling target {pooling_target} is not supported.")
99-
if layer_index > len(layer_structure):
99+
if layer_structure and layer_index > len(layer_structure):
100100
raise ValueError(
101101
"The layer index should be smaller than the length of the "
102102
f"layer structure. layer index is {layer_index} but there "
103103
f"are {len(layer_structure)} layers."
104104
)
105105

106-
if layer_index != -1: # select the layer to be pooled
106+
if layer_index != -99: # select the layer to be pooled
107107
if sum(layer_structure) != convlayers.shape[-1]:
108108
raise ValueError(
109109
"The sum of layer structure should be equal to the "

‎src/predict.py

Lines changed: 98 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,9 @@
11
"""
22
Execute at the root of the repo, not in the code directory.
33
4-
To execute the code,
5-
you need to create a directory structure as follows:
6-
```
7-
.
8-
└── <name of your analysis>/
9-
├── extract -> symlink to the output of the `extract` script
10-
└── model -> symlink to the output of a fitted model
114
```
125
python src/predict.py --multirun \
136
feature_path=/path/to/<name of your analysis>/extract \
14-
++phenotype_file=/path/to/phenotype.tsv
157
```
168
179
Currently the script hard coded to predict sex or age.
@@ -31,7 +23,11 @@
3123
Ridge,
3224
RidgeClassifier,
3325
)
34-
from sklearn.model_selection import ShuffleSplit, StratifiedKFold
26+
from sklearn.model_selection import (
27+
ShuffleSplit,
28+
StratifiedKFold,
29+
StratifiedShuffleSplit,
30+
)
3531
from sklearn.neural_network import MLPClassifier, MLPRegressor
3632
from sklearn.svm import LinearSVC, LinearSVR
3733

@@ -41,6 +37,16 @@
4137
"data_file_pattern": None,
4238
"plot_label": "Connectome",
4339
},
40+
"avgr2": {
41+
"data_file": None,
42+
"data_file_pattern": "r2map",
43+
"plot_label": "t+1\n average R2",
44+
},
45+
"r2map": {
46+
"data_file": None,
47+
"data_file_pattern": "r2map",
48+
"plot_label": "t+1\nR2 map",
49+
},
4450
"conv_avg": {
4551
"data_file": None,
4652
"data_file_pattern": "average",
@@ -61,16 +67,6 @@
6167
"data_file_pattern": "1dconv",
6268
"plot_label": "Conv layers \n 1D convolution",
6369
},
64-
"avgr2": {
65-
"data_file": None,
66-
"data_file_pattern": "r2map",
67-
"plot_label": "t+1\n average R2",
68-
},
69-
"r2map": {
70-
"data_file": None,
71-
"data_file_pattern": "r2map",
72-
"plot_label": "t+1\nR2 map",
73-
},
7470
}
7571

7672
log = logging.getLogger(__name__)
@@ -85,40 +81,98 @@ def train(dataset, tng, tst, clf, clf_name):
8581
}
8682

8783

84+
LABEL_DIR = Path(__file__).parents[1] / "outputs" / "sample_for_pretraining"
85+
86+
8887
@hydra.main(version_base="1.3", config_path="../config", config_name="predict")
8988
def main(params: DictConfig) -> None:
90-
from src.data.load_data import get_model_data, load_h5_data_path
89+
from src.data.load_data import (
90+
get_model_data,
91+
load_h5_data_path,
92+
load_ukbb_dset_path,
93+
)
9194

9295
# parse parameters
9396
output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
9497
output_dir = Path(output_dir)
9598
log.info(f"Output data {output_dir}")
9699
feature_path = Path(params["feature_path"])
97-
phenotype_file = Path(params["phenotype_file"])
98-
convlayers_path = feature_path / "feature_convlayers.h5"
99-
feature_t1_file = feature_path / f"feature_horizon-{params['horizon']}.h5"
100-
test_subjects = feature_path / "test_set_connectome.txt"
100+
extract_config = OmegaConf.load(feature_path / ".hydra/config.yaml")
101101
model_config = OmegaConf.load(
102-
feature_path.parent / "model/.hydra/config.yaml"
102+
Path(extract_config["model_path"]).parent / ".hydra/config.yaml"
103+
)
104+
105+
phenotype_file = Path(model_config["data"]["phenotype_file"])
106+
convlayers_path = feature_path / "feature_convlayers.h5"
107+
feature_t1_file = (
108+
feature_path / f"feature_horizon-{extract_config['horizon']}.h5"
103109
)
104110
params = OmegaConf.merge(model_config, params)
105111
log.info(params)
106-
107-
# load test set subject path from the training
108-
with open(test_subjects, "r") as f:
109-
hold_outs = f.read().splitlines()
110112
percentage_sample = params["percentage_sample"]
111-
if percentage_sample != 100:
112-
proportion = percentage_sample / 100
113-
sample_select = ShuffleSplit(
114-
n_splits=1,
115-
train_size=proportion,
116-
random_state=params["random_state"],
117-
)
118-
sample_index, _ = next(sample_select.split(hold_outs))
119-
subj = [hold_outs[i] for i in sample_index]
113+
114+
if params["predict_variable"] in ["age", "sex"]:
115+
sample_file = list(
116+
(LABEL_DIR / f"seed-{model_config['random_state']}").glob(
117+
"sample*split.json"
118+
)
119+
)[0]
120+
121+
# load test set subject path from the training
122+
with open(sample_file, "r") as f:
123+
hold_outs = json.load(f)[f"{model_config['data']['n_embed']}"][
124+
"test"
125+
]
126+
127+
if percentage_sample != 100:
128+
proportion = percentage_sample / 100
129+
sample_select = ShuffleSplit(
130+
n_splits=1,
131+
train_size=proportion,
132+
random_state=params["random_state"],
133+
)
134+
sample_index, _ = next(sample_select.split(hold_outs))
135+
subj = [hold_outs[i] for i in sample_index]
136+
else:
137+
subj = hold_outs.copy()
120138
else:
121-
subj = hold_outs.copy()
139+
sample_file = (
140+
LABEL_DIR
141+
/ f"seed-{model_config['random_state']}"
142+
/ "downstream_sample.json"
143+
)
144+
with open(sample_file, "r") as f:
145+
hold_outs = json.load(f)["test_downstreams"][
146+
params["predict_variable"]
147+
] # these are subject ids
148+
diagnosis_data = (
149+
pd.read_csv(phenotype_file, sep="\t")
150+
.set_index("participant_id")
151+
.loc[hold_outs, :]
152+
)
153+
154+
percentage_sample = params["percentage_sample"]
155+
if percentage_sample != 100:
156+
proportion = percentage_sample / 100
157+
sample_select = StratifiedShuffleSplit(
158+
n_splits=1,
159+
train_size=proportion,
160+
random_state=params["random_state"],
161+
)
162+
sample_index, _ = next(
163+
sample_select.split(
164+
diagnosis_data.index,
165+
diagnosis_data[params["predict_variable"]],
166+
)
167+
)
168+
subj = [diagnosis_data.index[i] for i in sample_index]
169+
170+
else:
171+
subj = hold_outs.copy()
172+
subj = [f"sub-{s}" for s in subj]
173+
subj = load_ukbb_dset_path(
174+
subj, params["data"]["atlas_desc"], params["data"]["segment"]
175+
)
122176

123177
log.info(
124178
f"Downstream prediction on {len(subj)}, "
@@ -143,31 +197,29 @@ def main(params: DictConfig) -> None:
143197
C=100,
144198
penalty="l2",
145199
class_weight="balanced",
146-
max_iter=1000000,
200+
max_iter=10000,
147201
random_state=params["random_state"],
148202
),
149203
"LogisticR": LogisticRegression(
150204
penalty="l2",
151205
class_weight="balanced",
152-
max_iter=100000,
206+
max_iter=1000,
153207
random_state=params["random_state"],
154208
n_jobs=-1,
155209
),
156210
"Ridge": RidgeClassifier(
157211
class_weight="balanced",
158212
random_state=params["random_state"],
159-
max_iter=100000,
213+
max_iter=1000,
160214
),
161215
}
162216
elif params["predict_variable_type"] == "numerical": # need to fix this
163217
clf_options = {
164218
"SVM": LinearSVR(
165-
C=100, max_iter=1000000, random_state=params["random_state"]
219+
C=100, max_iter=10000, random_state=params["random_state"]
166220
),
167221
"LinearR": LinearRegression(n_jobs=-1),
168-
"Ridge": Ridge(
169-
random_state=params["random_state"], max_iter=100000
170-
),
222+
"Ridge": Ridge(random_state=params["random_state"], max_iter=1000),
171223
}
172224
else:
173225
raise ValueError(
@@ -180,11 +232,6 @@ def main(params: DictConfig) -> None:
180232
log.info(f"Load data {baseline_details[measure]['data_file']}")
181233
if measure == "connectome":
182234
dset_path = baseline_details[measure]["data_file_pattern"]
183-
elif percentage_sample == 100:
184-
dset_path = load_h5_data_path(
185-
baseline_details[measure]["data_file"],
186-
baseline_details[measure]["data_file_pattern"],
187-
)
188235
else:
189236
dset_path = []
190237
for connectome_path in subj:

‎src/train.py

Lines changed: 84 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
model training
99
```
1010
python src/train.py --multirun hydra=scaling \
11-
++data.n_sample=100,200,300,-1
11+
++data.proportion_sample=1,0.5,0.25,0.1
1212
```
1313
"""
14-
import json
1514
import logging
1615
import os
1716
import pickle as pk
@@ -23,50 +22,41 @@
2322
import numpy as np
2423
import pandas as pd
2524
import torch
26-
from fmri_autoreg.data.load_data import make_input_labels
25+
from fmri_autoreg.data.load_data import get_edge_index_threshold
2726
from fmri_autoreg.models.train_model import train
28-
from omegaconf import DictConfig
27+
from fmri_autoreg.tools import chebnet_argument_resolver
28+
from omegaconf import DictConfig, OmegaConf
2929
from seaborn import lineplot
30-
from sklearn.model_selection import train_test_split
30+
from torchinfo import summary
3131

32-
33-
def convert_bytes(num):
34-
for x in ["bytes", "KB", "MB", "GB", "TB"]:
35-
if num < 1024.0:
36-
return f"{num:.1f} {x}"
37-
num /= 1024.0
38-
39-
40-
log = logging.getLogger(__name__)
32+
LABEL_DIR = Path(__file__).parents[1] / "outputs" / "sample_for_pretraining"
4133

4234

4335
@hydra.main(version_base="1.3", config_path="../config", config_name="train")
4436
def main(params: DictConfig) -> None:
4537
"""Train model using parameters dict and save results."""
4638
# import local library here because sumbitit and hydra being weird
4739
# if not interactive session of slurm, import submit it
48-
from src.data.load_data import load_ukbb_dset_path
49-
50-
rng = np.random.default_rng(params["random_state"])
51-
5240
if (
5341
"SLURM_JOB_ID" in os.environ
5442
and os.environ["SLURM_JOB_NAME"] != "interactive"
5543
):
56-
# import submitit
57-
# env = submitit.JobEnvironment()
5844
pid = os.getpid()
5945
# A logger for this file
6046
log = logging.getLogger(f"Process ID {pid}")
6147
log.info(f"Process ID {pid}")
62-
# use SLURM_TMPDIR for data_dir
63-
data_dir = Path(os.environ["SLURM_TMPDIR"]) / f"pid_{pid}"
64-
data_dir.mkdir()
48+
tng_data_h5 = (
49+
Path(os.environ["SLURM_TMPDIR"])
50+
/ f"data_{os.environ['SLURM_JOB_ID']}.h5"
51+
)
6552
else:
6653
log = logging.getLogger(__name__)
67-
data_dir = Path(
68-
hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
69-
)
54+
tng_data_h5 = list(
55+
(LABEL_DIR / f"seed-{params['random_state']}").glob("*train.h5")
56+
)[
57+
0
58+
] # will be shuffled after loading
59+
7060
output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
7161
log.info(f"Current working directory : {os.getcwd()}")
7262
log.info(f"Output directory : {output_dir}")
@@ -80,111 +70,98 @@ def main(params: DictConfig) -> None:
8070
# flatten the parameters
8171
device = "cuda:0" if torch.cuda.is_available() else "cpu"
8272
train_param = {**params["model"], **params["data"]}
73+
train_param["num_workers"] = params["num_workers"]
8374
train_param["torch_device"] = device
8475
train_param["random_state"] = params["random_state"]
76+
if "checkpoints" in params:
77+
train_param["checkpoints"] = params["checkpoints"]
8578
log.info(f"Working on {device}.")
8679

87-
# load data path
88-
n_sample = params["data"]["n_sample"]
89-
90-
data_split_json = params["data_split"]
91-
92-
with open(data_split_json, "r") as f:
93-
train_subject = json.load(f)["train"]
94-
test_subject = json.load(f)["holdout"]
95-
96-
rng.shuffle(train_subject)
97-
98-
if n_sample > 0:
99-
train_subject = train_subject[:n_sample]
100-
101-
train_subject = [f"sub-{s}" for s in train_subject]
102-
test_subject = [f"sub-{s}" for s in test_subject]
103-
104-
train_participant_ids, val_participant_ids = train_test_split(
105-
train_subject,
106-
test_size=params["data"]["validation_set"],
107-
shuffle=True,
108-
random_state=params["random_state"],
109-
)
110-
111-
data_reference = {}
112-
data_reference["train"] = load_ukbb_dset_path(
113-
participant_id=train_participant_ids,
114-
atlas_desc=params["data"]["atlas_desc"],
115-
segment=params["data"]["segment"],
80+
# get path data
81+
try:
82+
with h5py.File(tng_data_h5, "r") as h5file:
83+
connectome = h5file[f"n_embed-{train_param['n_embed']}"]["train"][
84+
"connectome"
85+
][:]
86+
except OSError:
87+
log.error(f"File {tng_data_h5} corrupted.")
88+
return 1
89+
90+
# get edge index
91+
edge_index = get_edge_index_threshold(
92+
connectome, train_param["edge_index_thres"]
11693
)
117-
data_reference["val"] = load_ukbb_dset_path(
118-
participant_id=val_participant_ids,
119-
atlas_desc=params["data"]["atlas_desc"],
120-
segment=params["data"]["segment"],
121-
)
122-
data_reference["test"] = load_ukbb_dset_path(
123-
participant_id=test_subject,
124-
atlas_desc=params["data"]["atlas_desc"],
125-
segment=params["data"]["segment"],
126-
)
127-
with open(Path(output_dir) / "train_test_split.json", "w") as f:
128-
json.dump(data_reference, f, indent=2)
129-
n_sample_pretrain = len(data_reference["train"]) + len(
130-
data_reference["val"]
131-
)
132-
log.info(
133-
f"Experiment on {n_sample_pretrain} subjects for pretrain model. "
134-
)
135-
136-
tng_data_h5 = data_dir / "data_train.h5"
137-
val_data_h5 = data_dir / "data_val.h5"
138-
tng_data_h5, edge_index = make_input_labels(
139-
data_file=params["data"]["data_file"],
140-
dset_paths=data_reference["train"],
141-
params=train_param,
142-
output_file_path=tng_data_h5,
143-
compute_edge_index=compute_edge_index,
144-
log=log,
145-
)
146-
val_data_h5, _ = make_input_labels(
147-
data_file=params["data"]["data_file"],
148-
dset_paths=data_reference["val"],
149-
params=train_param,
150-
output_file_path=val_data_h5,
151-
compute_edge_index=False,
152-
log=log,
153-
)
154-
if params["verbose"] > 1:
155-
log.info(
156-
f"Training data: {convert_bytes(os.path.getsize(tng_data_h5))}"
157-
)
158-
log.info(
159-
f"Validation data: {convert_bytes(os.path.getsize(val_data_h5))}"
160-
)
161-
162-
train_data = (tng_data_h5, val_data_h5, edge_index)
94+
log.info("Loaded connectome.")
95+
train_data = (tng_data_h5, edge_index)
16396
del edge_index
16497

16598
with h5py.File(tng_data_h5, "r") as h5file:
166-
n_seq = h5file["input"].shape[0]
167-
if n_seq < train_param["batch_size"]:
99+
n_tng_inputs = h5file[f"n_embed-{train_param['n_embed']}"]["train"][
100+
"input"
101+
].shape[0]
102+
n_tng_inputs *= train_param["proportion_sample"]
103+
104+
if n_tng_inputs < train_param["batch_size"]:
168105
log.info(
169106
"Batch size is greater than the number of sequences. "
170107
"Setting batch size to number of sequences. "
171-
f"New batch size: {n_seq}. "
108+
f"New batch size: {n_tng_inputs}. "
172109
f"Old batch size: {train_param['batch_size']}."
173110
)
174-
train_param["batch_size"] = n_seq
111+
train_param["batch_size"] = n_tng_inputs
112+
if compute_edge_index: # chebnet
113+
train_param = chebnet_argument_resolver(train_param)
114+
# save train_param
115+
with open(os.path.join(output_dir, "train_param.yaml"), "w") as f:
116+
OmegaConf.save(config=train_param, f=f)
175117

176118
log.info("Start training.")
177119
(
178120
model,
179121
mean_r2_tng,
180122
mean_r2_val,
181123
losses,
182-
_,
124+
checkpoints,
183125
) = train(train_param, train_data, verbose=params["verbose"])
126+
184127
# save training results
185128
np.save(os.path.join(output_dir, "mean_r2_tng.npy"), mean_r2_tng)
186129
np.save(os.path.join(output_dir, "mean_r2_val.npy"), mean_r2_val)
187130
np.save(os.path.join(output_dir, "training_losses.npy"), losses)
131+
if "checkpoints" in params:
132+
# save a list of dictionaries as pd dataframe
133+
checkpoints = pd.DataFrame(checkpoints)
134+
checkpoints.to_csv(
135+
os.path.join(output_dir, "checkpoints.tsv"), sep="\t"
136+
)
137+
if params["verbose"] > 3:
138+
# get model info
139+
with open(os.path.join(output_dir, "model_info.txt"), "w") as f:
140+
model_stats = summary(model)
141+
summary_str = str(model_stats)
142+
f.write(summary_str)
143+
144+
# get model info
145+
with open(
146+
os.path.join(output_dir, "model_info_with_input.txt"), "w"
147+
) as f:
148+
model_stats = summary(
149+
model,
150+
input_size=(
151+
train_param["batch_size"],
152+
train_param["n_embed"],
153+
train_param["seq_length"],
154+
),
155+
col_names=[
156+
"input_size",
157+
"output_size",
158+
"num_params",
159+
"kernel_size",
160+
],
161+
)
162+
summary_str = str(model_stats)
163+
f.write(summary_str)
164+
188165
log.info(f"Mean r2 tng: {mean_r2_tng}")
189166
log.info(f"Mean r2 val: {mean_r2_val}")
190167

@@ -196,7 +173,7 @@ def main(params: DictConfig) -> None:
196173
training_losses = pd.DataFrame(losses)
197174
plt.figure()
198175
g = lineplot(data=training_losses)
199-
g.set_title(f"Training Losses (N={n_sample})")
176+
g.set_title(f"Training Losses (number of inputs={n_tng_inputs})")
200177
g.set_xlabel("Epoc")
201178
g.set_ylabel("Loss (MSE)")
202179
plt.savefig(Path(output_dir) / "training_losses.png")

‎src/utils/explore_hyperparameters.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import itertools
2+
import re
3+
from pathlib import Path
4+
5+
import numpy as np
6+
import pandas as pd
7+
8+
output_dirs = Path(
9+
"outputs/autoreg/train/multiruns/nembed-197_hyperparameters"
10+
).glob("**/train.log")
11+
data = []
12+
13+
14+
def peek(iterable):
15+
try:
16+
first = next(iterable)
17+
except StopIteration:
18+
return None
19+
return itertools.chain([first], iterable)
20+
21+
22+
for p in output_dirs:
23+
experiment = {
24+
f.groups()[0]: float(f.groups()[1])
25+
for f in re.finditer(r"model\.([a-z_]*)=([\d\.e?-]*)", p.parent.name)
26+
}
27+
experiment["mean_r2_val"] = np.nan
28+
experiment["runtime"] = np.nan
29+
if (p.parent / "model.pkl").exists():
30+
with open(p, "r") as log:
31+
report = log.read()
32+
mean_r2_val = re.search(r"Mean r2 val: ([\-\.\d]*)", report).groups()[
33+
0
34+
]
35+
starttime = re.search(r"\[([\d\-\s:,]*)\].*Process ID", report).group(
36+
1
37+
)
38+
endtime = re.search(r"\[([\d\-\s:,]*)\].*model trained", report).group(
39+
1
40+
)
41+
starttime = pd.to_datetime(starttime)
42+
endtime = pd.to_datetime(endtime)
43+
runtime = endtime - starttime
44+
experiment["mean_r2_val"] = mean_r2_val
45+
experiment["runtime"] = runtime.total_seconds() / 60
46+
data.append(experiment)
47+
48+
data = pd.DataFrame(data)
49+
data = data.sort_values("mean_r2_val", ascending=False)
50+
data.to_csv("_explore_hyperparameters.tsv", sep="\t", index=False)

‎src/utils/plot_diagnosis.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import re
2+
from pathlib import Path
3+
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
import pandas as pd
7+
import seaborn as sns
8+
from matplotlib.lines import Line2D
9+
10+
DIAGNOSIS_PATH = "outputs/neuroips-workshop_2024/downstreams_last-layer/data/data.proportion_sample_1.0" # noqa: E501
11+
12+
feature_fullname = {
13+
"connectome": "Connectome\n(baseline)",
14+
"avgr2": "t+1\naverage R2",
15+
"r2map": "t+1\nR2 map",
16+
"conv_avg": "Conv layers\navg pooling",
17+
"conv_std": "Conv layers\nstd pooling",
18+
"conv_max": "Conv layers\nmax pooling",
19+
"conv_conv1d": "Conv layers\n1D convolution",
20+
}
21+
22+
diagnosis_fullname = {
23+
"sex": "Sex",
24+
"DEP": "Depressive\ndisorder",
25+
"ALCO": "Alcohol Abuse",
26+
"EPIL": "Epilepsy",
27+
"MS": "Multiple\nsclerosis",
28+
"PARK": "Parkinson",
29+
"BIPOLAR": "Bipolar",
30+
"ADD": "Alzheimer -\nDementia",
31+
"SCZ": "Schizophrenia",
32+
}
33+
34+
35+
def main():
36+
diagnosis_path = Path(DIAGNOSIS_PATH)
37+
diagnosis_files = diagnosis_path.glob("**/*.tsv")
38+
sns.set_theme(style="whitegrid")
39+
sns.set_context("paper", font_scale=1.5)
40+
fig, axs = plt.subplots(1, 3, figsize=(13, 6), sharey=True)
41+
# fig, ax = plt.subplots(1, 1, figsize=(6, 6))
42+
n_subjects_diagnosis = {}
43+
for ax, classifier in zip(axs, ["SVM", "LogisticR", "Ridge"]):
44+
# classifier = "LogisticR"
45+
data_clf = []
46+
diagnosis_files = diagnosis_path.glob("**/*.tsv")
47+
for p in diagnosis_files:
48+
filename = p.name
49+
diagnosis = filename.split("_")[-1].split(".")[0]
50+
df = pd.read_csv(p, sep="\t")
51+
df = df.loc[df.classifier == classifier, :]
52+
df = df.groupby("feature")["score"].agg("mean").reset_index()
53+
if diagnosis != "sex":
54+
with open(p.parent / "predict.log", "r") as f:
55+
log = f.read()
56+
n_subjects = (
57+
int(
58+
re.search(
59+
r"Downstream prediction on ([\d]*),", log
60+
).group(1)
61+
)
62+
/ 2
63+
)
64+
n_subjects_diagnosis[diagnosis] = int(n_subjects)
65+
else:
66+
with open(p.parent / "predict.log", "r") as f:
67+
log = f.read()
68+
n_holdout = int(
69+
re.search(
70+
r"Downstream prediction on ([\d]*),", log
71+
).group(1)
72+
)
73+
n_subjects_diagnosis[
74+
diagnosis
75+
] = 3341 # number of male subjects
76+
df["diagnosis"] = diagnosis_fullname[diagnosis]
77+
data_clf.append(df)
78+
data_clf = pd.concat(data_clf)
79+
data_clf = data_clf.reset_index(drop=True)
80+
81+
# for each diagnosis, get index of results better than connectome
82+
idx_better = []
83+
for _, diagnosis in enumerate(diagnosis_fullname.values()):
84+
baseline = data_clf.loc[
85+
(data_clf.feature == "connectome")
86+
& (data_clf.diagnosis == diagnosis),
87+
"score",
88+
].values[0]
89+
baseline_idx = data_clf.loc[
90+
(data_clf.feature == "connectome")
91+
& (data_clf.diagnosis == diagnosis),
92+
"score",
93+
].index[0]
94+
better = (
95+
data_clf.loc[
96+
(data_clf.feature != "connectome")
97+
& (data_clf.diagnosis == diagnosis),
98+
"score",
99+
]
100+
>= baseline
101+
)
102+
better = (
103+
data_clf.loc[
104+
(data_clf.feature != "connectome")
105+
& (data_clf.diagnosis == diagnosis),
106+
"score",
107+
]
108+
.index[better]
109+
.tolist()
110+
)
111+
better.append(baseline_idx)
112+
idx_better += better
113+
idx_better.sort()
114+
# get index that is the opposite of idx_better
115+
idx_better = np.array(idx_better)
116+
idx = np.zeros(data_clf.shape[0], dtype=bool)
117+
idx[idx_better] = True
118+
sns.stripplot(
119+
x="diagnosis",
120+
y="score",
121+
hue="feature",
122+
data=data_clf.iloc[~idx, :],
123+
ax=ax,
124+
legend=False,
125+
order=diagnosis_fullname.values(),
126+
hue_order=feature_fullname.keys(),
127+
marker="$\circ$",
128+
size=10,
129+
jitter=0.2,
130+
)
131+
sns.stripplot(
132+
x="diagnosis",
133+
y="score",
134+
hue="feature",
135+
data=data_clf.iloc[idx, :],
136+
ax=ax,
137+
legend=classifier == "Ridge",
138+
order=diagnosis_fullname.values(),
139+
hue_order=feature_fullname.keys(),
140+
size=8,
141+
jitter=0.25,
142+
)
143+
144+
ax.hlines(y=0.5, xmin=0.5, xmax=8.5, color="k", linestyle="--")
145+
ax.hlines(
146+
y=n_subjects_diagnosis["sex"] / n_holdout,
147+
xmin=-0.5,
148+
xmax=0.5,
149+
color="k",
150+
linestyle="--",
151+
)
152+
ax.set_title(f"{classifier}")
153+
ax.set_ylim(0.4, 1)
154+
ax.set_ylabel("Accuracy score")
155+
tick_lables = []
156+
for tl in diagnosis_fullname:
157+
if tl == "sex":
158+
tick_lables.append(
159+
tl.upper()
160+
+ " ($N_{male}=$"
161+
+ f"${n_subjects_diagnosis[tl]}$)"
162+
)
163+
else:
164+
tick_lables.append(tl + f" ($N={n_subjects_diagnosis[tl]}$)")
165+
166+
ax.set_xticklabels(tick_lables, rotation=90)
167+
chance = Line2D([0], [0], color="black", label="Chance", ls="--")
168+
# get legend handles and labels
169+
han, lab = axs[-1].get_legend_handles_labels()
170+
han.append(chance)
171+
legend_labels = [feature_fullname[i] for i in lab]
172+
legend_labels.append("Chance")
173+
# append cahnce line to the legend
174+
axs[-1].legend(handles=han, labels=legend_labels)
175+
sns.move_legend(axs[-1], "upper left", bbox_to_anchor=(1, 1))
176+
fig.suptitle(
177+
"Downstream prediction (Training set proportion = "
178+
f"{DIAGNOSIS_PATH.split('_')[-1].replace('-', ' ')})"
179+
)
180+
# fig.suptitle(f"Downstream prediction on Full hold out sample)")
181+
plt.tight_layout()
182+
plt.savefig(
183+
Path(DIAGNOSIS_PATH).parents[1]
184+
/ "reports"
185+
/ f"{Path(DIAGNOSIS_PATH).name}_overview_LR.png"
186+
)
187+
188+
189+
if __name__ == "__main__":
190+
main()

‎src/utils/plot_orion.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,26 @@
44
"type": "legacy",
55
"database": {
66
"type": "pickleddb",
7-
"host": "outputs/autoreg/train/multiruns/2024-04-11_07-38-37/database.pkl",
7+
"host": "outputs/autoreg/train/multiruns/2024-08-21_11-09-24/database.pkl",
88
},
99
}
1010

1111
experiment = get_experiment("experiment", storage=storage)
1212

1313
fig = experiment.plot.regret()
1414
fig.write_html(
15-
"outputs/autoreg/train/multiruns/2024-04-11_07-38-37/regret.html"
15+
"outputs/autoreg/train/multiruns/2024-08-21_11-09-24/regret.html"
1616
)
1717

1818
fig = experiment.plot.parallel_coordinates()
1919
fig.write_html(
20-
"outputs/autoreg/train/multiruns/2024-04-11_07-38-37/parallel_coordinates.html"
20+
"outputs/autoreg/train/multiruns/2024-08-21_11-09-24/parallel_coordinates.html"
2121
)
2222

2323
fig = experiment.plot.lpi()
24-
fig.write_html("outputs/autoreg/train/multiruns/2024-04-11_07-38-37/lpi.html")
24+
fig.write_html("outputs/autoreg/train/multiruns/2024-08-21_11-09-24/lpi.html")
2525

2626
fig = experiment.plot.partial_dependencies()
2727
fig.write_html(
28-
"outputs/autoreg/train/multiruns/2024-04-11_07-38-37/partial_dependencies.html"
28+
"outputs/autoreg/train/multiruns/2024-08-21_11-09-24/partial_dependencies.html"
2929
)

‎src/utils/plot_scaling.py

Lines changed: 0 additions & 210 deletions
This file was deleted.
Lines changed: 380 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,380 @@
1+
"""
2+
look through the `outputs/` directory, find instance of completed
3+
training, and get the number of subjects used, mean R2 of test set,
4+
plot the number of subjects (y-axis) against R2 (x axis)
5+
"""
6+
import itertools
7+
import json
8+
import re
9+
from pathlib import Path
10+
11+
import matplotlib.pyplot as plt
12+
import numpy as np
13+
import pandas as pd
14+
import seaborn as sns
15+
import yaml
16+
17+
sns.set_theme(style="whitegrid")
18+
19+
20+
def peek(iterable):
21+
try:
22+
first = next(iterable)
23+
except StopIteration:
24+
return None
25+
return itertools.chain([first], iterable)
26+
27+
28+
def main():
29+
path_success_job = Path(
30+
"outputs/neuroips-workshop_2024/scale-by-architecture"
31+
).glob("scale-*/**/training_losses.npy")
32+
path_success_job = peek(path_success_job)
33+
34+
scaling_stats = pd.DataFrame()
35+
for p in path_success_job:
36+
# parse the path and get number of subjects
37+
log_file = p.parent / "train.log"
38+
with open(log_file, "r") as f:
39+
log_text = f.read()
40+
# parse the path and get number of subjects
41+
n_sample = int(
42+
re.search(r"Using ([\d]*) samples for training", log_text).group(1)
43+
)
44+
# get random seed
45+
random_seed = int(re.search(r"Random seed ([\d]*)", log_text).group(1))
46+
# load r2_val.npy get mean r2
47+
mean_r2_val = np.load(p.parent / "mean_r2_val.npy").tolist()
48+
mean_r2_tng = np.load(p.parent / "mean_r2_tng.npy").tolist()
49+
# get runtime from log file text
50+
starttime = re.search(
51+
r"\[([\d\-\s:,]*)\].*Process ID", log_text
52+
).group(1)
53+
endtime = re.search(
54+
r"\[([\d\-\s:,]*)\].*model trained", log_text
55+
).group(1)
56+
starttime = pd.to_datetime(starttime)
57+
endtime = pd.to_datetime(endtime)
58+
runtime = endtime - starttime
59+
60+
# convert to log scale
61+
runtime = runtime.total_seconds() / 60
62+
runtime_log = np.log10(runtime)
63+
64+
# read trian_param.json (which is an ymal...)
65+
if (p.parent / "train_param.json").exists():
66+
with open(p.parent / "train_param.json") as f:
67+
train_param = yaml.safe_load(f)
68+
else:
69+
with open(p.parent / "train_param.ymal") as f:
70+
train_param = yaml.safe_load(f)
71+
72+
train_param["M"] = int(train_param["M"].split(",")[0])
73+
# total number of parameters
74+
model_info_file = p.parent / "model_info_with_input.txt"
75+
with open(model_info_file, "r") as f:
76+
model_info = f.read()
77+
78+
total_parameters = int(
79+
re.search(r"Total params: ([\d,]*)", model_info)
80+
.group(1)
81+
.replace(",", "")
82+
)
83+
total_mult = float(
84+
re.search(r"Total mult-adds \(M\): ([\d.]*)", model_info).group(1)
85+
)
86+
total_size = float(
87+
re.search(
88+
r"Estimated Total Size \(MB\): ([\d.]*)", model_info
89+
).group(1)
90+
)
91+
92+
# # load connectome accuracy
93+
# prediction = pd.read_csv(
94+
# p.parent / "simple_classifiers_sex.tsv", sep="\t", index_col=0
95+
# )
96+
# prediction = prediction.loc[
97+
# prediction["classifier"] == "SVM", ["feature", "score"]
98+
# ]
99+
# prediction = prediction.set_index("feature")
100+
# prediction = prediction.T.reset_index(drop=True)
101+
102+
df = pd.DataFrame(
103+
[
104+
n_sample,
105+
train_param["GCL"],
106+
train_param["F"],
107+
train_param["K"],
108+
train_param["FCL"],
109+
train_param["M"],
110+
random_seed,
111+
mean_r2_val,
112+
mean_r2_tng,
113+
runtime,
114+
runtime_log,
115+
total_parameters,
116+
total_mult,
117+
total_size,
118+
],
119+
index=[
120+
"n_sample_train",
121+
"GCL",
122+
"F",
123+
"K",
124+
"FCL",
125+
"M",
126+
"random_seed",
127+
"mean_r2_val",
128+
"mean_r2_tng",
129+
"runtime",
130+
"runtime_log",
131+
"total_parameters",
132+
"total_mult",
133+
"total_size",
134+
],
135+
).T
136+
# df = pd.concat([df, prediction], axis=1)
137+
scaling_stats = pd.concat([scaling_stats, df], axis=0)
138+
139+
# sort by n_sample
140+
scaling_stats = scaling_stats.sort_values(by="n_sample_train")
141+
# for each n_sample, sort by random seed
142+
scaling_stats = scaling_stats.groupby("n_sample_train").apply(
143+
lambda x: x.sort_values(by="random_seed")
144+
)
145+
scaling_stats = scaling_stats.reset_index(drop=True)
146+
147+
scaling_stats.to_csv(
148+
"outputs/neuroips-workshop_2024/scale-by-architecture/reports/scaling_data.tsv",
149+
"\t",
150+
)
151+
152+
mask_compare_FCL = (
153+
(scaling_stats["GCL"] == 3)
154+
& (scaling_stats["F"] == 8)
155+
& (scaling_stats["K"] == 3)
156+
)
157+
158+
# fig, axs = plt.subplots(1, 2, figsize=(12, 6))
159+
fig = plt.figure()
160+
ax1 = fig.add_subplot(121)
161+
plot_compare_FCL = sns.heatmap(
162+
scaling_stats[mask_compare_FCL].pivot_table(
163+
index="M", columns="FCL", values="mean_r2_val"
164+
),
165+
cmap="coolwarm",
166+
square=True,
167+
linewidth=0.5,
168+
vmax=0.185,
169+
vmin=0.16,
170+
annot=True,
171+
fmt=".3f",
172+
cbar_kws={"label": "Mean R2 of validation set"},
173+
ax=ax1,
174+
)
175+
plot_compare_FCL.set_title(
176+
"Testing different parameters of MLP\nGCN architecture fixed; batch size ~8k"
177+
)
178+
plot_compare_FCL.set_xlabel("Number of fully connected layer")
179+
plot_compare_FCL.set_ylabel("Number of neurons per layer")
180+
# plot_compare_FCL.figure.savefig("outputs/neuroips-workshop_2024/scale-by-architecture/reports/compare_FCL.png")
181+
# plt.close()
182+
183+
mask_compare_GCL = (scaling_stats["M"] == 8) & (scaling_stats["FCL"] == 1)
184+
185+
# 3d scatter plot of F, GCL, K
186+
# fig = plt.figure()
187+
ax = fig.add_subplot(122, projection="3d")
188+
im = ax.scatter(
189+
scaling_stats[mask_compare_GCL]["F"],
190+
scaling_stats[mask_compare_GCL]["K"],
191+
scaling_stats[mask_compare_GCL]["GCL"],
192+
c=scaling_stats[mask_compare_GCL]["mean_r2_val"],
193+
cmap="coolwarm",
194+
s=100,
195+
vmin=0.16,
196+
vmax=0.185,
197+
)
198+
ax.set_xlabel("F")
199+
ax.set_xticks([8, 16, 32])
200+
ax.set_ylabel("K")
201+
ax.set_yticks([3, 5, 10])
202+
ax.set_zlabel("Number of layers")
203+
ax.set_zticks([3, 6, 9, 12])
204+
ax.set_title(
205+
"Testing different parameters of chebnet\nMLP architecture fixed; batch size ~8k"
206+
)
207+
# fig.colorbar(im, ax=ax, label="Mean R2 of validation set")
208+
fig.savefig(
209+
"outputs/neuroips-workshop_2024/scale-by-architecture/reports/compare_F-GCL-K.png"
210+
)
211+
plt.close()
212+
213+
for g in [3, 6, 9, 12]:
214+
cur_df = mask_compare_GCL & (scaling_stats["GCL"] == g)
215+
plot_compare_FCL = sns.heatmap(
216+
scaling_stats[cur_df].pivot_table(
217+
index="F", columns="K", values="mean_r2_val"
218+
),
219+
square=True,
220+
linewidth=0.5,
221+
vmax=0.185,
222+
vmin=0.16,
223+
annot=True,
224+
fmt=".3f",
225+
cmap="coolwarm",
226+
)
227+
plot_compare_FCL.set_title("Mean R2 of validation set")
228+
plot_compare_FCL.set_xlabel("K")
229+
plot_compare_FCL.set_ylabel("F")
230+
plot_compare_FCL.figure.savefig(
231+
f"outputs/neuroips-workshop_2024/scale-by-architecture/reports/compare_GCL-{g}.png"
232+
)
233+
plt.close()
234+
235+
for f in [8, 16, 32]:
236+
cur_df = mask_compare_GCL & (scaling_stats["F"] == f)
237+
plot_compare_FCL = sns.heatmap(
238+
scaling_stats[cur_df].pivot_table(
239+
index="GCL", columns="K", values="mean_r2_val"
240+
),
241+
square=True,
242+
linewidth=0.5,
243+
vmax=0.185,
244+
vmin=0.16,
245+
annot=True,
246+
fmt=".3f",
247+
cmap="coolwarm",
248+
)
249+
plot_compare_FCL.set_title("Mean R2 of validation set")
250+
plot_compare_FCL.set_xlabel("K")
251+
plot_compare_FCL.set_ylabel("Number of convolution layer")
252+
plot_compare_FCL.figure.savefig(
253+
f"outputs/neuroips-workshop_2024/scale-by-architecture/reports/compare_F-{f}.png"
254+
)
255+
plt.close()
256+
257+
for k in [3, 5, 10]:
258+
cur_df = mask_compare_GCL & (scaling_stats["K"] == k)
259+
plot_compare_FCL = sns.heatmap(
260+
scaling_stats[cur_df].pivot_table(
261+
index="GCL", columns="F", values="mean_r2_val"
262+
),
263+
square=True,
264+
linewidth=0.5,
265+
vmax=0.185,
266+
vmin=0.16,
267+
annot=True,
268+
fmt=".3f",
269+
cmap="coolwarm",
270+
)
271+
plot_compare_FCL.set_title("Mean R2 of validation set")
272+
plot_compare_FCL.set_xlabel("F")
273+
plot_compare_FCL.set_ylabel("Number of convolution layer")
274+
plot_compare_FCL.figure.savefig(
275+
f"outputs/neuroips-workshop_2024/scale-by-architecture/reports/compare_K-{k}.png"
276+
)
277+
plt.close()
278+
279+
# # stats[name] = scaling_stats
280+
# # alternative data to show missing experiment
281+
# # random seed as column and runtime as value
282+
# scaling_overview = scaling_stats.pivot(
283+
# index="n_sample", columns="random_seed", values="mean_r2_val"
284+
# )
285+
286+
# # give a summary of the random seed and n_sample pair
287+
# # with no runtime. this is because the experiment failed
288+
# incomplete_n_sample = scaling_overview.isna().sum(axis=1)
289+
# incomplete_n_sample = incomplete_n_sample[incomplete_n_sample > 0]
290+
# # make sure all possible n_sample are included
291+
# for n_sample in scaling_overview.index:
292+
# if n_sample not in incomplete_n_sample.index:
293+
# incomplete_n_sample[n_sample] = 0
294+
# incomplete_n_sample = incomplete_n_sample.sort_index()
295+
# missing_experiment = {}
296+
# for n_sample in incomplete_n_sample.index:
297+
# missing_experiment[n_sample] = scaling_overview.columns[
298+
# scaling_overview.loc[n_sample].isna()
299+
# ].tolist()
300+
# # save to json
301+
# with open(
302+
# "outputs/ccn2024/scaling_missing_experiment.json",
303+
# "w",
304+
# ) as f:
305+
# json.dump(missing_experiment, f, indent=2)
306+
307+
# plt.figure(figsize=(7, 4.5))
308+
# # plot
309+
# sns.lineplot(
310+
# data=scaling_stats,
311+
# x="n_sample_train",
312+
# y="mean_r2_tng",
313+
# marker="o",
314+
# label="Traing set",
315+
# )
316+
# sns.lineplot(
317+
# data=scaling_stats,
318+
# x="n_sample_train",
319+
# y="mean_r2_val",
320+
# marker="o",
321+
# label="Validation set",
322+
# )
323+
# plt.ylim(0.10, 0.19)
324+
# plt.xlabel("Number of subject in model training")
325+
# plt.ylabel("R-squared")
326+
# plt.legend()
327+
# plt.title("R-squared of t+1 prediction")
328+
# plt.savefig("outputs/ccn2024/scaling_r2_tng_plot.png")
329+
# plt.close()
330+
331+
# plt.figure(figsize=(7, 4.5))
332+
# sns.lineplot(
333+
# data=scaling_stats,
334+
# x="n_sample_train",
335+
# y="runtime_log",
336+
# marker="o",
337+
# )
338+
# plt.xlabel("Number of subject in model training")
339+
# plt.ylabel("log10(runtime) (minutes)")
340+
# plt.title("Runtime of training a group model")
341+
# plt.savefig("outputs/ccn2024/scaling_runtime_plot.png")
342+
# plt.close()
343+
344+
# plt.figure(figsize=(7, 4.5))
345+
# # plot
346+
# features = prediction.columns.tolist()
347+
# for y, label in zip(
348+
# features,
349+
# [
350+
# "connectomes",
351+
# "average pooling",
352+
# "standard deviation pooling",
353+
# "max pooling",
354+
# "1D convolution",
355+
# "average R-squared",
356+
# "R-squared map",
357+
# ],
358+
# ):
359+
# if label in [
360+
# "connectomes",
361+
# "standard deviation pooling",
362+
# "R-squared map",
363+
# ]:
364+
# sns.lineplot(
365+
# data=scaling_stats,
366+
# x="n_sample_downstream",
367+
# y=y,
368+
# marker="o",
369+
# label=label,
370+
# )
371+
# plt.xlabel("Number of subject in prediction task")
372+
# plt.ylabel("Accuracy")
373+
# plt.legend()
374+
# plt.title("Sex prediction accuracy with SVM")
375+
# plt.savefig("outputs/ccn2024/_scaling_connectome.png")
376+
# plt.close()
377+
378+
379+
if __name__ == "__main__":
380+
main()

‎src/utils/plot_scaling_downstream.py

Lines changed: 119 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
plot the number of subjects (y-axis) against R2 (x axis)
55
"""
66
import itertools
7-
import json
87
import re
98
from pathlib import Path
109

@@ -22,80 +21,131 @@ def peek(iterable):
2221
return itertools.chain([first], iterable)
2322

2423

25-
def main():
26-
path_success_job = Path("outputs/autoreg/predict/downstream").glob(
27-
"**/simple_classifiers_sex.tsv"
28-
)
29-
# path_success_job = peek(path_success_job)
24+
feature_fullname = {
25+
"connectome": "Connectome\n(baseline)",
26+
"avgr2": "t+1\naverage R2",
27+
"r2map": "t+1\nR2 map",
28+
"conv_avg": "Conv layers\navg pooling",
29+
"conv_std": "Conv layers\nstd pooling",
30+
"conv_max": "Conv layers\nmax pooling",
31+
"conv_conv1d": "Conv layers\n1D convolution",
32+
}
33+
34+
diagnosis_fullname = {
35+
"sex": "Sex",
36+
"DEP": "Depressive\ndisorder",
37+
"ALCO": "Alcohol Abuse",
38+
"EPIL": "Epilepsy",
39+
"MS": "Multiple sclerosis",
40+
"PARK": "Parkinson",
41+
"BIPOLAR": "Bipolar",
42+
"ADD": "Alzheimer - Dementia",
43+
"SCZ": "Schizophrenia",
44+
}
45+
PREDICTION_DATA = Path(
46+
"outputs/neuroips-workshop_2024/downstreams_last-layer/data"
47+
)
3048

31-
scaling_stats = pd.DataFrame()
32-
for p in path_success_job:
33-
log_file = p.parent / "predict.log"
34-
with open(log_file, "r") as f:
35-
log_text = f.read()
36-
# parse the path and get number of subjects
37-
n_sample = re.search(
38-
r"Subjects with phenotype data: ([\d]*)", log_text
39-
).group(1)
40-
n_sample = int(n_sample)
41-
percent_sample = re.search(
42-
r"([\d]*)% of the full sample", log_text
43-
).group(1)
44-
# get random seed
45-
random_seed = re.search(r"'random\_state': ([\d]+)", log_text).group(1)
46-
# load connectome accuracy
47-
prediction = pd.read_csv(p, sep="\t", index_col=0)
48-
prediction["percent_sample"] = int(percent_sample)
49-
prediction["n_sample"] = n_sample
50-
prediction["random_seed"] = random_seed
5149

52-
scaling_stats = pd.concat([scaling_stats, prediction], axis=0)
50+
def main():
51+
sns.set_theme(style="whitegrid")
52+
sns.set_context("paper", font_scale=1.5)
53+
pal = sns.color_palette()
54+
for d in diagnosis_fullname.keys():
55+
path_success_job = PREDICTION_DATA.glob(
56+
f"data.proportion_sample_*/**/simple_classifiers_{d}.tsv"
57+
)
58+
scaling_stats = pd.DataFrame()
59+
for p in path_success_job:
60+
log_file = p.parent / "predict.log"
61+
with open(log_file, "r") as f:
62+
log_text = f.read()
63+
n_holdout = int(
64+
re.search(
65+
r"Downstream prediction on ([\d]*),", log_text
66+
).group(1)
67+
)
68+
# parse the path and get number of subjects
69+
percent_training_sample = p.parents[0].name.split("_")[-1]
70+
percent_training_sample = float(percent_training_sample) * 100
71+
percent_holdout_sample = re.search(
72+
r"([\d]*)% of the full sample", log_text
73+
).group(1)
74+
# get random seed
75+
random_seed = re.search(
76+
r"'random\_state': ([\d]+)", log_text
77+
).group(1)
78+
# load connectome accuracy
79+
prediction = pd.read_csv(p, sep="\t", index_col=0)
80+
prediction["percent_holdout_sample"] = int(percent_holdout_sample)
81+
prediction["percent_training_sample"] = percent_training_sample
82+
prediction["random_seed"] = random_seed
5383

54-
# sort by n_sample
55-
scaling_stats = scaling_stats.sort_values(by="n_sample")
56-
# for each n_sample, sort by random seed
57-
scaling_stats = scaling_stats.groupby("n_sample").apply(
58-
lambda x: x.sort_values(by="random_seed")
59-
)
60-
scaling_stats = scaling_stats.reset_index(drop=True)
84+
scaling_stats = pd.concat([scaling_stats, prediction], axis=0)
6185

62-
scaling_stats.to_csv(
63-
"outputs/autoreg/predict/downstream/downstream_scaling_data.csv"
64-
)
86+
# sort by n_sample
87+
scaling_stats = scaling_stats.sort_values(by="percent_training_sample")
88+
# # for each n_sample, sort by random seed
89+
# scaling_stats = scaling_stats.groupby("percent_training_sample").apply(
90+
# lambda x: x.sort_values(by="random_seed")
91+
# )
92+
scaling_stats = scaling_stats.reset_index(drop=True)
6593

66-
mask = scaling_stats["classifier"] == "SVM"
67-
plt.figure(figsize=(7, 4.5))
68-
# plot
69-
features = prediction["feature"].unique().tolist()
70-
for y, label in zip(
71-
features,
72-
[
73-
"connectomes",
74-
"average pooling",
75-
"standard deviation pooling",
76-
"max pooling",
77-
"1D convolution",
78-
"average R-squared",
79-
"R-squared map",
80-
],
81-
):
82-
feat_mask = scaling_stats["feature"] == y
83-
cur_mask = mask & feat_mask
84-
sns.lineplot(
85-
data=scaling_stats[cur_mask],
86-
x="percent_sample",
87-
y="score",
88-
marker="o",
89-
label=label,
94+
scaling_stats.to_csv(
95+
PREDICTION_DATA.parent / f"reports/downstream_scaling_{d}.tsv",
96+
sep="\t",
97+
)
98+
# replace feature name
99+
scaling_stats["percent_training_sample"] = np.log10(
100+
scaling_stats["percent_training_sample"]
101+
)
102+
scaling_stats["feature"] = scaling_stats["feature"].replace(
103+
feature_fullname
90104
)
91-
plt.xlabel("Percent of subject in the downstream prediction.")
92-
plt.ylabel("Accuracy")
93-
plt.legend()
94-
plt.title(
95-
"Sex prediction accuracy with SVM with saturated pretrained model."
96-
)
97-
plt.savefig("outputs/autoreg/predict/downstream/downstream_scaling.png")
98-
plt.close()
105+
for clf in scaling_stats["classifier"].unique():
106+
mask = scaling_stats["classifier"] == clf
107+
no_connectome = (
108+
scaling_stats["feature"] != "Connectome\n(baseline)"
109+
)
110+
is_connectome = ~no_connectome
111+
plt.figure(figsize=(7, 4.5))
112+
benchmark = scaling_stats[mask & is_connectome]["score"].mean()
113+
plt.axhline(
114+
y=benchmark, color=pal[0], linestyle="-.", label="Connectome"
115+
)
116+
# plot
117+
features = prediction["feature"].unique().tolist()
118+
sns.lineplot(
119+
data=scaling_stats[mask & no_connectome],
120+
x="percent_training_sample",
121+
y="score",
122+
hue="feature",
123+
marker="o",
124+
errorbar="ci",
125+
palette=pal[1 : len(features)],
126+
)
127+
if d != "sex":
128+
plt.axhline(y=0.5, color="k", linestyle="--", label="Chance")
129+
else:
130+
plt.axhline(
131+
y=3341 / n_holdout,
132+
color="k",
133+
linestyle="--",
134+
label="Chance",
135+
)
136+
plt.xlabel("Percent of subject in the pretrained model")
137+
plt.xticks(scaling_stats["percent_training_sample"].unique())
138+
plt.ylabel("Accuracy")
139+
plt.legend(bbox_to_anchor=(1, 1))
140+
plt.title(
141+
f"{diagnosis_fullname[d]} prediction accuracy with {clf}"
142+
)
143+
plt.tight_layout()
144+
plt.savefig(
145+
PREDICTION_DATA.parent
146+
/ f"reports/downstream_scaling_{d}_{clf}.png"
147+
)
148+
plt.close()
99149

100150

101151
if __name__ == "__main__":
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""
2+
look through the `outputs/` directory, find instance of completed
3+
training, and get the number of subjects used, mean R2 of test set,
4+
plot the number of subjects (y-axis) against R2 (x axis)
5+
"""
6+
import itertools
7+
import re
8+
from pathlib import Path
9+
10+
import matplotlib.pyplot as plt
11+
import numpy as np
12+
import pandas as pd
13+
import seaborn as sns
14+
15+
16+
def peek(iterable):
17+
try:
18+
first = next(iterable)
19+
except StopIteration:
20+
return None
21+
return itertools.chain([first], iterable)
22+
23+
24+
feature_fullname = {
25+
"connectome": "Connectome\n(baseline)",
26+
"avgr2": "t+1\naverage R2",
27+
"r2map": "t+1\nR2 map",
28+
"conv_avg": "Conv layers\navg pooling",
29+
"conv_std": "Conv layers\nstd pooling",
30+
"conv_max": "Conv layers\nmax pooling",
31+
"conv_conv1d": "Conv layers\n1D convolution",
32+
}
33+
34+
diagnosis_fullname = {
35+
"sex": "Sex",
36+
"DEP": "Depressive\ndisorder",
37+
"ALCO": "Alcohol Abuse",
38+
"EPIL": "Epilepsy",
39+
"MS": "Multiple sclerosis",
40+
"PARK": "Parkinson",
41+
"BIPOLAR": "Bipolar",
42+
"ADD": "Alzheimer - Dementia",
43+
"SCZ": "Schizophrenia",
44+
}
45+
PREDICTION_DATA = Path(
46+
"outputs/neuroips-workshop_2024/downstreams_fewshot/data"
47+
)
48+
49+
50+
def main():
51+
sns.set_theme(style="whitegrid")
52+
sns.set_context("paper", font_scale=1.5)
53+
pal = sns.color_palette()
54+
for d in diagnosis_fullname.keys():
55+
path_success_job = PREDICTION_DATA.glob(
56+
f"**/simple_classifiers_{d}.tsv"
57+
)
58+
scaling_stats = pd.DataFrame()
59+
for p in path_success_job:
60+
log_file = p.parent / "predict.log"
61+
with open(log_file, "r") as f:
62+
log_text = f.read()
63+
n_holdout = int(
64+
re.search(
65+
r"Downstream prediction on ([\d]*),", log_text
66+
).group(1)
67+
)
68+
# parse the path and get number of subjects
69+
percent_holdout_sample = re.search(
70+
r"([\d]*)% of the full sample", log_text
71+
).group(1)
72+
# get random seed
73+
print(percent_holdout_sample)
74+
random_seed = re.search(
75+
r"'random\_state': ([\d]+)", log_text
76+
).group(1)
77+
# load connectome accuracy
78+
prediction = pd.read_csv(p, sep="\t", index_col=0)
79+
prediction["percent_holdout_sample"] = int(percent_holdout_sample)
80+
prediction["random_seed"] = random_seed
81+
82+
scaling_stats = pd.concat([scaling_stats, prediction], axis=0)
83+
84+
# sort by n_sample
85+
scaling_stats = scaling_stats.sort_values(by="percent_holdout_sample")
86+
# # for each n_sample, sort by random seed
87+
# scaling_stats = scaling_stats.groupby("percent_training_sample").apply(
88+
# lambda x: x.sort_values(by="random_seed")
89+
# )
90+
scaling_stats = scaling_stats.reset_index(drop=True)
91+
92+
scaling_stats.to_csv(
93+
PREDICTION_DATA.parent / f"reports/downstream_fewshot_{d}.tsv",
94+
sep="\t",
95+
)
96+
# replace feature name
97+
scaling_stats["feature"] = scaling_stats["feature"].replace(
98+
feature_fullname
99+
)
100+
for clf in scaling_stats["classifier"].unique():
101+
mask = scaling_stats["classifier"] == clf
102+
plt.figure(figsize=(7, 4.5))
103+
# plot
104+
features = prediction["feature"].unique().tolist()
105+
sns.lineplot(
106+
data=scaling_stats[mask],
107+
x="percent_holdout_sample",
108+
y="score",
109+
hue="feature",
110+
hue_order=feature_fullname.values(),
111+
marker="o",
112+
errorbar=("ci", 95),
113+
)
114+
if d != "sex":
115+
plt.axhline(y=0.5, color="k", linestyle="--", label="Chance")
116+
plt.xlabel("Percent of subject in the downstream task")
117+
plt.xticks(scaling_stats["percent_holdout_sample"].unique())
118+
plt.ylabel("Accuracy")
119+
plt.legend(bbox_to_anchor=(1, 1))
120+
plt.title(
121+
f"{diagnosis_fullname[d]} prediction accuracy with {clf}"
122+
)
123+
plt.tight_layout()
124+
plt.savefig(
125+
PREDICTION_DATA.parent
126+
/ f"reports/downstream_fewshot_{d}_{clf}.png"
127+
)
128+
plt.close()
129+
130+
131+
if __name__ == "__main__":
132+
main()

‎src/utils/plot_scaling_sample.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
"""
2+
look through the `outputs/` directory, find instance of completed
3+
training, and get the number of subjects used, mean R2 of test set,
4+
plot the number of subjects (y-axis) against R2 (x axis)
5+
"""
6+
import itertools
7+
import json
8+
import re
9+
from pathlib import Path
10+
11+
import matplotlib.pyplot as plt
12+
import numpy as np
13+
import pandas as pd
14+
import seaborn as sns
15+
import yaml
16+
17+
sns.set_theme(style="whitegrid")
18+
sns.set_context("paper", font_scale=1.5)
19+
20+
21+
def peek(iterable):
22+
try:
23+
first = next(iterable)
24+
except StopIteration:
25+
return None
26+
return itertools.chain([first], iterable)
27+
28+
29+
# BASE_PATH = "outputs/neuroips-workshop_2024/scale-sample_bestmodel_different-num_workers" # noqa
30+
BASE_PATH = "outputs/neuroips-workshop_2024/scale-sample_bestmodel"
31+
32+
33+
def main():
34+
path_success_job = Path(BASE_PATH).glob("data/**/training_losses.npy")
35+
path_success_job = peek(path_success_job)
36+
37+
scaling_stats = pd.DataFrame()
38+
for p in path_success_job:
39+
# parse the path and get number of subjects
40+
log_file = p.parent / "train.log"
41+
with open(log_file, "r") as f:
42+
log_text = f.read()
43+
# parse the path and get number of subjects
44+
n_sample = int(
45+
re.search(r"Using ([\d]*) samples for training", log_text).group(1)
46+
)
47+
# get random seed
48+
random_seed = int(re.search(r"Random seed ([\d]*)", log_text).group(1))
49+
# load r2_val.npy get mean r2
50+
mean_r2_val = np.load(p.parent / "mean_r2_val.npy").tolist()
51+
mean_r2_tng = np.load(p.parent / "mean_r2_tng.npy").tolist()
52+
# get runtime from log file text
53+
starttime = re.search(
54+
r"\[([\d\-\s:,]*)\].*Process ID", log_text
55+
).group(1)
56+
endtime = re.search(
57+
r"\[([\d\-\s:,]*)\].*model trained", log_text
58+
).group(1)
59+
starttime = pd.to_datetime(starttime)
60+
endtime = pd.to_datetime(endtime)
61+
runtime = endtime - starttime
62+
63+
# convert to log scale
64+
runtime = runtime.total_seconds() / 60
65+
runtime_log = np.log10(runtime)
66+
# total number of parameters
67+
model_info_file = p.parent / "model_info_with_input.txt"
68+
if model_info_file.exists():
69+
with open(model_info_file, "r") as f:
70+
model_info = f.read()
71+
total_parameters = int(
72+
re.search(r"Total params: ([\d,]*)", model_info)
73+
.group(1)
74+
.replace(",", "")
75+
)
76+
total_mult = float(
77+
re.search(
78+
r"Total mult-adds \(M\): ([\d.]*)", model_info
79+
).group(1)
80+
)
81+
total_size = float(
82+
re.search(
83+
r"Estimated Total Size \(MB\): ([\d.]*)", model_info
84+
).group(1)
85+
)
86+
else:
87+
total_parameters = np.nan
88+
total_mult = np.nan
89+
total_size = np.nan
90+
# # load connectome accuracy
91+
# prediction = pd.read_csv(
92+
# p.parent / "simple_classifiers_sex.tsv", sep="\t", index_col=0 # noqa
93+
# )
94+
# prediction = prediction.loc[
95+
# prediction["classifier"] == "SVM", ["feature", "score"]
96+
# ]
97+
# prediction = prediction.set_index("feature")
98+
# prediction = prediction.T.reset_index(drop=True)
99+
100+
df = pd.DataFrame(
101+
[
102+
n_sample,
103+
random_seed,
104+
mean_r2_val,
105+
mean_r2_tng,
106+
runtime,
107+
runtime_log,
108+
total_parameters,
109+
total_mult,
110+
total_size,
111+
],
112+
index=[
113+
"n_sample_train",
114+
"random_seed",
115+
"mean_r2_val",
116+
"mean_r2_tng",
117+
"runtime",
118+
"runtime_log",
119+
"total_parameters",
120+
"total_mult",
121+
"total_size",
122+
],
123+
).T
124+
# df = pd.concat([df, prediction], axis=1)
125+
scaling_stats = pd.concat([scaling_stats, df], axis=0)
126+
127+
# sort by n_sample
128+
scaling_stats = scaling_stats.sort_values(by="n_sample_train")
129+
# for each n_sample, sort by random seed
130+
scaling_stats = scaling_stats.groupby("n_sample_train").apply(
131+
lambda x: x.sort_values(by="random_seed")
132+
)
133+
scaling_stats = scaling_stats.reset_index(drop=True)
134+
scaling_stats["percent_sample"] = scaling_stats["n_sample_train"] / 2328583
135+
scaling_stats["percent_sample"] = (
136+
scaling_stats["percent_sample"].round(3) * 100
137+
)
138+
139+
scaling_stats.to_csv(Path(BASE_PATH) / "reports/scaling_data.tsv", "\t")
140+
141+
# alternative data to show missing experiment
142+
# random seed as column and runtime as value
143+
scaling_overview = scaling_stats.pivot(
144+
index="percent_sample", columns="random_seed", values="mean_r2_val"
145+
)
146+
147+
# give a summary of the random seed and n_sample pair
148+
# with no runtime. this is because the experiment failed
149+
incomplete_n_sample = scaling_overview.isna().sum(axis=1)
150+
incomplete_n_sample = incomplete_n_sample[incomplete_n_sample > 0]
151+
# make sure all possible n_sample are included
152+
for n_sample in scaling_overview.index:
153+
if n_sample not in incomplete_n_sample.index:
154+
incomplete_n_sample[n_sample] = 0
155+
incomplete_n_sample = incomplete_n_sample.sort_index()
156+
missing_experiment = {}
157+
for n_sample in incomplete_n_sample.index:
158+
missing_experiment[n_sample] = scaling_overview.columns[
159+
scaling_overview.loc[n_sample].isna()
160+
].tolist()
161+
# save to json
162+
with open(
163+
Path(BASE_PATH) / "reports/scaling_missing_experiment.json",
164+
"w",
165+
) as f:
166+
json.dump(missing_experiment, f, indent=2)
167+
168+
plt.figure(figsize=(5, 5))
169+
# plot
170+
sns.lineplot(
171+
data=scaling_stats,
172+
x="percent_sample",
173+
y="mean_r2_tng",
174+
marker="o",
175+
label="Traing set",
176+
)
177+
sns.lineplot(
178+
data=scaling_stats,
179+
x="percent_sample",
180+
y="mean_r2_val",
181+
marker="o",
182+
label="Validation set",
183+
)
184+
plt.ylim(0.145, 0.185)
185+
plt.xticks([0, 5, 10, 25, 50, 100])
186+
plt.xlabel("Percentage of training sample")
187+
plt.ylabel("R-squared")
188+
plt.legend()
189+
plt.title("R-squared of t+1 prediction")
190+
plt.tight_layout()
191+
plt.savefig(Path(BASE_PATH) / "reports/scaling_r2_tng_plot.png")
192+
plt.close()
193+
194+
plt.figure(figsize=(5, 5))
195+
sns.lineplot(
196+
data=scaling_stats,
197+
x="percent_sample",
198+
y="runtime_log",
199+
marker="o",
200+
)
201+
plt.xticks([0, 5, 10, 25, 50, 100])
202+
plt.xlabel("Percentage of training sample")
203+
plt.ylabel("log10(runtime) (minutes)")
204+
plt.title("Runtime of training")
205+
plt.tight_layout()
206+
plt.savefig(Path(BASE_PATH) / "reports/scaling_runtime_plot.png")
207+
plt.close()
208+
209+
210+
if __name__ == "__main__":
211+
main()

‎tools/find_batch_size.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""
2+
Resource:
3+
salloc --time=2:00:00 --mem=16G --cpus-per-task=16 --gpus-per-node=1
4+
Aim:
5+
- Fill up the GPU memory with the largest batch size possible
6+
"""
7+
8+
import typing as t
9+
10+
import numpy as np
11+
import torch
12+
import torch.nn as nn
13+
import torch.nn.functional as F
14+
from fmri_autoreg.models.models import Chebnet
15+
16+
DATASET_SIZE = 2328583 # number of data point in training set
17+
SEQ = 16
18+
19+
20+
# make a random correlation matrix
21+
def get_edges(n_emb):
22+
ts = np.random.rand(n_emb, 117)
23+
corr = np.corrcoef(ts)
24+
thres_index = int(corr.shape[0] * corr.shape[1] * 0.9)
25+
thres_value = np.sort(corr.flatten())[thres_index]
26+
adj_mat = corr * (corr >= thres_value)
27+
edge_index = np.nonzero(adj_mat)
28+
return edge_index
29+
30+
31+
def get_batch_size(
32+
model: nn.Module,
33+
device: torch.device,
34+
input_shape: t.Tuple[int, int, int],
35+
output_shape: t.Tuple[int],
36+
dataset_size: int,
37+
max_batch_size: int = None,
38+
num_iterations: int = 5,
39+
) -> int:
40+
model.to(device)
41+
model.train(True)
42+
optimizer = torch.optim.Adam(model.parameters())
43+
44+
batch_size = 2
45+
while True:
46+
if max_batch_size is not None and batch_size >= max_batch_size:
47+
batch_size = max_batch_size
48+
break
49+
if batch_size >= dataset_size:
50+
batch_size = batch_size // 2
51+
break
52+
try:
53+
for _ in range(num_iterations):
54+
# dummy inputs and targets
55+
inputs = torch.rand(*(batch_size, *input_shape), device=device)
56+
targets = torch.rand(
57+
*(batch_size, *output_shape), device=device
58+
)
59+
outputs = model(inputs)
60+
loss = F.mse_loss(targets, outputs)
61+
loss.backward()
62+
optimizer.step()
63+
optimizer.zero_grad()
64+
batch_size *= 2
65+
except RuntimeError:
66+
batch_size //= 2
67+
break
68+
del model, optimizer
69+
torch.cuda.empty_cache()
70+
return batch_size
71+
72+
73+
if __name__ == "__main__":
74+
for n_emb in [64, 197, 444]:
75+
edge_index = get_edges(n_emb)
76+
print("our hypothetical biggest model")
77+
model = Chebnet(
78+
n_emb=n_emb,
79+
seq_len=16,
80+
edge_index=edge_index,
81+
FK="16,3,16,3,16,3,16,3,16,3,16,3",
82+
M="8,1",
83+
FC_type="nonshared_uni",
84+
aggrs="add",
85+
dropout=0.1,
86+
bn_momentum=0.1,
87+
use_bn=True,
88+
)
89+
batch_size = get_batch_size(
90+
model,
91+
torch.device("cuda"),
92+
(n_emb, SEQ),
93+
(n_emb,),
94+
DATASET_SIZE,
95+
num_iterations=10,
96+
)
97+
print(f"atlas {n_emb}, input length {SEQ}, batch size {batch_size}")
98+
del model
99+
del edge_index

‎tools/number_of_workers.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from time import time
2+
3+
import h5py
4+
from fmri_autoreg.data.load_data import Dataset
5+
from torch.utils.data import DataLoader, Subset
6+
from tqdm import tqdm
7+
8+
proportion_sample = 1
9+
tng_data_h5 = (
10+
"outputs/sample_for_pretraining/seed-42/sample_seed-42_data-train.h5"
11+
)
12+
IS_GPU = False
13+
N_EMBED = [64, 197, 444]
14+
BATCHSIZE = [512]
15+
16+
with open("outputs/performance_info/cpu_number_of_workers.tsv", "w") as f:
17+
f.write("batch_size\tn_embed\tnum_workers\tepoch_second\n")
18+
19+
for n_embed in N_EMBED:
20+
if proportion_sample != 1:
21+
with h5py.File(tng_data_h5, "r") as f:
22+
tng_length = f[f"n_embed-{n_embed}"]["train"]["input"].shape[0]
23+
tng_index = list(range(int(tng_length * proportion_sample)))
24+
tng_dataset = Subset(
25+
Dataset(
26+
tng_data_h5, n_embed=f"n_embed-{n_embed}", set_type="train"
27+
),
28+
tng_index,
29+
)
30+
else:
31+
tng_dataset = Dataset(
32+
tng_data_h5, n_embed=f"n_embed-{n_embed}", set_type="train"
33+
)
34+
for batch_size in [512]:
35+
for num_workers in range(8, 34, 2):
36+
train_loader = DataLoader(
37+
tng_dataset,
38+
shuffle=True,
39+
num_workers=num_workers,
40+
batch_size=batch_size,
41+
pin_memory=IS_GPU,
42+
)
43+
start = time()
44+
for _ in tqdm(
45+
range(1, 3),
46+
desc=f"batch_size={batch_size}; n_embed={n_embed}; Number of workers: {num_workers}",
47+
):
48+
for _, _ in enumerate(train_loader, 0):
49+
pass
50+
end = time()
51+
taken = (end - start) / 2
52+
with open(
53+
"outputs/performance_info/cpu_number_of_workers.tsv", "a"
54+
) as f:
55+
f.write(f"{batch_size}\t{n_embed}\t{num_workers}\t{taken}\n")

0 commit comments

Comments
 (0)
Please sign in to comment.