Commit 4e8285e

Author: Hao-Ting Wang

Balanced sample (#13)

* Initial version of creating a hold-out set; need to modify the training script later; need to output the cohort demographic info
* Save demographic info summary
* Hydra-fy
* ENH adapt the training script to fit the new input

1 parent 6c00f53 · commit 4e8285e

File tree

7 files changed (+270 −58 lines)


config/base.yaml

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+---
+defaults:
+  - _self_
+  - hydra: default
+
+verbose: 2
+random_state: 42
+return_type: float

config/data/default.yaml

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+---
+standardize: false
+n_embed: 197
+atlas_desc: atlas-MIST_desc-${data.n_embed}
+hold_out_set: 0.20
+validation_set: 0.25
+n_sample: -1
+class_balance_confounds:
+  - site
+  - sex
+  - age
+  - mean_fd_raw
+  - proportion_kept

config/data/ukbb.yaml

Lines changed: 8 additions & 13 deletions

@@ -1,14 +1,9 @@
-data_file: inputs/connectomes/ukbb.h5
-standardize: false
-n_embed: 197
-n_sample: -1
+---
+defaults:
+  - _self_
+  - default
 
-split:  # training and evaluation
-  _target_: src.data.load_data.load_ukbb_dset_path
-  path: ${data.data_file}
-  atlas_desc: atlas-MIST_desc-${data.n_embed}
-  n_sample: ${data.n_sample}
-  val_set: 0.20
-  test_set: 0.20
-  segment: 1
-  random_state: ${random_state}
+data_file: inputs/connectomes/ukbb_libral_scrub_20240716_connectome.h5
+phenotype_file: inputs/connectomes/ukbb_libral_scrub_20240716_phenotype.tsv
+phenotype_json: inputs/connectomes/ukbb_libral_scrub_20240716_phenotype.json
+segment: 1
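
Since Hydra composes these YAML files at runtime, the merged `data` group can be inspected with the compose API. A minimal sketch, assuming it runs from the repository root; the `+data=ukbb` override is an assumption about how the group is selected, not part of the commit:

from hydra import compose, initialize

# Compose config/base.yaml and pull in the data group; the defaults
# list in ukbb.yaml merges config/data/default.yaml underneath it.
with initialize(version_base="1.3", config_path="config"):
    cfg = compose(config_name="base", overrides=["+data=ukbb"])

print(cfg.data.segment)     # 1, from ukbb.yaml
print(cfg.data.n_embed)     # 197, inherited from default.yaml
print(cfg.data.atlas_desc)  # atlas-MIST_desc-197, after interpolation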

config/train.yaml

Lines changed: 1 addition & 0 deletions

@@ -8,3 +8,4 @@ defaults:
 verbose: 2
 random_state: 42
 return_type: float
+data_split: ???
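
In OmegaConf, which Hydra uses for these configs, `???` marks a mandatory value: composition succeeds, but the key must be supplied (for example via a command-line override) before it is read. A minimal sketch of that behaviour, independent of this repository:

from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"data_split": "???"})
print(OmegaConf.is_missing(cfg, "data_split"))  # True

try:
    _ = cfg.data_split  # reading an unset mandatory value raises
except MissingMandatoryValue:
    print("data_split must be provided at runtime")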

src/create_holdout_sample.py

Lines changed: 111 additions & 0 deletions

@@ -0,0 +1,111 @@
+import json
+import logging
+from pathlib import Path
+
+import hydra
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sns
+from omegaconf import DictConfig
+
+log = logging.getLogger(__name__)
+
+
+@hydra.main(version_base="1.3", config_path="../config", config_name="base")
+def main(params: DictConfig) -> None:
+    from src.data.load_data import create_hold_out_sample
+
+    output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
+    output_dir = Path(output_dir)
+
+    sample = create_hold_out_sample(
+        phenotype_path=params["data"]["phenotype_file"],
+        phenotype_meta=params["data"]["phenotype_json"],
+        class_balance_confounds=params["data"]["class_balance_confounds"],
+        hold_out_set=params["data"]["hold_out_set"],
+        random_state=params["random_state"],
+    )
+
+    data = pd.read_csv(params["data"]["phenotype_file"], sep="\t", index_col=0)
+
+    with open(params["data"]["phenotype_json"], "r") as f:
+        meta = json.load(f)
+
+    with open(output_dir / "downstream_sample.json", "w") as f:
+        json.dump(sample, f, indent=2)
+
+    # plot the distribution of confounds of the downstream balanced samples
+    demographics = {}
+    for d in sample["test_downstreams"].keys():
+        d_subjects = sample["test_downstreams"][d]
+        df = data.loc[d_subjects, :]
+        fig, axes = plt.subplots(
+            1,
+            len(params["data"]["class_balance_confounds"]),
+            figsize=(20, len(params["data"]["class_balance_confounds"]) + 1),
+        )
+        fig.suptitle(
+            f"Confound balanced sample (N={len(d_subjects)}): "
+            f"{meta[d]['instance']['1']['description']}"
+        )
+        for ax, c in zip(axes, params["data"]["class_balance_confounds"]):
+            sns.histplot(x=c, data=df, hue=d, kde=True, ax=ax)
+        fig.savefig(output_dir / f"{d}.png")
+        demographics[d] = {
+            "patient": {
+                "condition": d,
+                "total": df[df[d] == 1].shape[0],
+                "n_female": df[df[d] == 1].shape[0]
+                - df[df[d] == 1]["sex"].sum(),
+                "age_mean": df[df[d] == 1]["age"].mean(),
+                "age_sd": df[df[d] == 1]["age"].std(),
+                "mean_fd_mean": df[df[d] == 1]["mean_fd_raw"].mean(),
+                "mean_fd_sd": df[df[d] == 1]["mean_fd_raw"].std(),
+                "proportion_kept_mean": df[df[d] == 1][
+                    "proportion_kept"
+                ].mean(),
+                "proportion_kept_sd": df[df[d] == 1]["proportion_kept"].std(),
+            },
+            "control": {
+                "condition": d,
+                "total": df[df[d] == 0].shape[0],
+                "n_female": df[df[d] == 0].shape[0]
+                - df[df[d] == 0]["sex"].sum(),
+                "age_mean": df[df[d] == 0]["age"].mean(),
+                "age_sd": df[df[d] == 0]["age"].std(),
+                "mean_fd_mean": df[df[d] == 0]["mean_fd_raw"].mean(),
+                "mean_fd_sd": df[df[d] == 0]["mean_fd_raw"].std(),
+                "proportion_kept_mean": df[df[d] == 0][
+                    "proportion_kept"
+                ].mean(),
+                "proportion_kept_sd": df[df[d] == 0]["proportion_kept"].std(),
+            },
+        }
+
+    demographics_summary = pd.DataFrame()
+    for d in demographics.keys():
+        df = pd.DataFrame.from_dict(demographics[d], orient="index")
+        df.set_index([df.index, "condition"], inplace=True)
+        demographics_summary = pd.concat([demographics_summary, df])
+    demographics_summary.round(decimals=2).to_csv(
+        output_dir / "demographics_summary.tsv", sep="\t"
+    )
+
+    for key in sample.keys():
+        if key == "test_downstreams":
+            continue
+        d_subjects = sample[key]
+        df = data.loc[d_subjects, :]
+        fig, axes = plt.subplots(
+            1,
+            len(params["data"]["class_balance_confounds"]),
+            figsize=(20, len(params["data"]["class_balance_confounds"]) + 1),
+        )
+        fig.suptitle(f"{key} sample (N={len(d_subjects)})")
+        for ax, c in zip(axes, params["data"]["class_balance_confounds"]):
+            sns.histplot(x=c, data=df, kde=True, ax=ax)
+        fig.savefig(output_dir / f"{key}.png")
+
+
+if __name__ == "__main__":
+    main()
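
For context, a minimal sketch of how the saved split might be consumed afterwards; the run directory is a hypothetical placeholder, and passing the IDs to load_ukbb_dset_path is an assumed usage, not code from this commit:

import json

from src.data.load_data import load_ukbb_dset_path

# Hypothetical Hydra output directory; the real path depends on the run.
with open("outputs/<run>/downstream_sample.json", "r") as f:
    sample = json.load(f)

# Build h5 dataset paths for the training subjects only.
h5_paths = load_ukbb_dset_path(
    participant_id=sample["train"],
    atlas_desc="atlas-MIST_desc-197",
    segment=1,
)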

src/data/load_data.py

Lines changed: 81 additions & 38 deletions

@@ -8,6 +8,7 @@
 import h5py
 import numpy as np
 import pandas as pd
+from general_class_balancer import general_class_balancer as gcb
 from nilearn.connectome import ConnectivityMeasure
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import StandardScaler
@@ -93,32 +94,96 @@ def split_data_by_site(
     return tng_data, test_data
 
 
+def create_hold_out_sample(
+    phenotype_path: Union[Path, str],
+    phenotype_meta: Union[Path, str],
+    class_balance_confounds: List[str],
+    hold_out_set: float = 0.25,
+    random_state: int = 42,
+) -> Dict:
+    """Create experiment sample with patients in the hold-out set.
+
+    Args:
+        phenotype_path (Union[Path, str]): Path to the tsv file.
+            Column index 0 must be participant_id.
+        phenotype_meta (Union[Path, str]): Path to the json file.
+        class_balance_confounds (List[str]): List of confounds to
+            use for class balancing.
+        hold_out_set (float, optional): Proportion of the hold-out set
+            size in relation to the full sample. Defaults to 0.25.
+        random_state (int, optional): Random state for reproducibility.
+    Returns:
+        Dict: Dictionary with lists of participant IDs for the training
+            and hold-out sets, and the downstream task samples.
+    """
+    with open(phenotype_meta, "r") as f:
+        meta = json.load(f)
+
+    data = pd.read_csv(phenotype_path, sep="\t", index_col=0)
+
+    diagnosis_groups = list(meta["diagnosis"]["labels"].keys())
+    diagnosis_groups.remove("HC")
+
+    n_sample = data.shape[0]
+
+    # create a hold-out set for downstream analysis including all
+    # the patients
+    any_patients = data[diagnosis_groups].sum(axis=1) > 0
+    patients = list(data[any_patients].index)
+    controls = list(data[~any_patients].index)
+
+    n_patients = len(patients)
+    n_control = n_sample - n_patients
+    n_control_in_hold_out_set = int(n_sample * hold_out_set - n_patients)
+
+    corrected_hold_out_set = n_control_in_hold_out_set / n_control
+    controls_site = list(data[~any_patients]["site"])
+    train, hold_out = train_test_split(
+        controls,
+        test_size=corrected_hold_out_set,
+        random_state=random_state,
+        stratify=controls_site,
+    )
+    hold_out += patients
+
+    # get controls that match the patients' confounds
+    data_hold_out = data.loc[hold_out]
+    downstreams = {}
+    for d in diagnosis_groups:
+        select_sample = gcb.class_balance(
+            classes=data_hold_out[d].values.astype(int),
+            confounds=data_hold_out[class_balance_confounds].values.T,
+            plim=0.05,
+            random_seed=random_state,  # fix random seed for reproducibility
+        )
+        selected = data_hold_out.index[select_sample].tolist()
+        selected.sort()
+        downstreams[d] = selected
+    train.sort()
+    hold_out.sort()
+    return {
+        "train": train,
+        "hold_out": hold_out,
+        "test_downstreams": downstreams,
+    }
+
+
 def load_ukbb_dset_path(
-    path: Union[Path, str],
+    participant_id: List[str],
     atlas_desc: str,
-    n_sample: int = 50,
-    val_set: float = 0.25,
-    test_set: float = 0.25,
     segment: Union[int, List[int]] = -1,
-    random_state: int = 42,
 ) -> Dict:
-    """Load time series of UK Biobank.
+    """Load time series paths in the h5 file of UK Biobank.
 
     We segmented the time series per subject as independent samples,
     hence it's important to make sure the same subject is not in both
     the training and test sets.
 
     Args:
-        path (Union[Path, str]): Path to the hdf5 file.
+        participant_id (List[str]): List of participant IDs.
         atlas_desc (str): Regex pattern to look for suitable data,
            such as the right `desc` field for atlas,
            e.g., "atlas-MIST_desc-197".
-        n_sample (int, optional): number of subjects to use.
-            Defaults to 50, and -1 would take the full sample.
-        val_set (float, optional): proportion of the validation set
-            size in relation to the full sample. Defaults to 0.25.
-        test_set (float, optional): proportion of the test set size
-            in relation to the full sample. Defaults to 0.25.
        segment (Union[int, List[int]], optional): segments of the
            time series to use. 0 for the full time series.
            Defaults to -1 to load all four segments.
@@ -144,41 +209,19 @@ def load_ukbb_dset_path(
     elif segment <= 4:
         segment = [segment]
 
-    # get the participant IDs to use
-    with h5py.File(path, "r") as h5file:
-        participant_id = list(h5file["ukbb"].keys())
-
-    if n_sample == -1:
-        pass
-    elif n_sample < len(participant_id):
-        total_proportion_sample = n_sample / len(participant_id)
-        participant_id, _ = train_test_split(
-            participant_id,
-            test_size=(1 - total_proportion_sample),
-            random_state=random_state,
-        )
-
     # construct path
     subject_path_template = (
         "/ukbb/{sub}/{sub}_task-rest_{atlas_desc}_{seg}timeseries"
     )
-    data_list = []
+    h5_path = []
     for sub in participant_id:
         for seg in segment:
             seg = f"seg-{seg}_" if seg is not None else ""
             cur_sub_path = subject_path_template.format(
                 sub=sub, seg=seg, atlas_desc=atlas_desc
             )
-            data_list.append(cur_sub_path)
-    # train-test-val split
-    train, test = train_test_split(
-        data_list, test_size=test_set, random_state=random_state
-    )
-    # calculate the proportion of val_set in the training loop
-    train, val = train_test_split(
-        train, test_size=val_set / (1 - test_set), random_state=random_state
-    )
-    return {"train": train, "val": val, "test": test}
+            h5_path.append(cur_sub_path)
+    return h5_path
 
 
 def load_data(
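
Two pieces of the new logic are easiest to see with concrete numbers; the values and the subject ID below are illustrative assumptions, not data from the commit. All patients go into the hold-out set, so train_test_split only has to fill the remaining hold-out slots with controls, and the split fraction must be recomputed over controls alone:

# Hold-out correction with made-up numbers: 1000 participants,
# 100 patients, requested hold-out proportion 0.25.
n_sample, n_patients, hold_out_set = 1000, 100, 0.25
n_control = n_sample - n_patients                                      # 900
n_control_in_hold_out_set = int(n_sample * hold_out_set - n_patients)  # 150
corrected_hold_out_set = n_control_in_hold_out_set / n_control         # ~0.167

# Path template: for segment 1 and the 197-parcel MIST atlas, one entry
# of the list returned by load_ukbb_dset_path would look like this
# (the subject ID is hypothetical).
template = "/ukbb/{sub}/{sub}_task-rest_{atlas_desc}_{seg}timeseries"
print(template.format(sub="sub-01", seg="seg-1_", atlas_desc="atlas-MIST_desc-197"))
# /ukbb/sub-01/sub-01_task-rest_atlas-MIST_desc-197_seg-1_timeseries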
