import h5py
import numpy as np
import pandas as pd
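+ # used to draw confound-matched, class-balanced samples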
+ from general_class_balancer import general_class_balancer as gcb
from nilearn.connectome import ConnectivityMeasure
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
@@ -93,32 +94,96 @@ def split_data_by_site(
    return tng_data, test_data


+ def create_hold_out_sample(
+     phenotype_path: Union[Path, str],
+     phenotype_meta: Union[Path, str],
+     class_balance_confounds: List[str],
+     hold_out_set: float = 0.25,
+     random_state: int = 42,
+ ) -> Dict:
+     """Create the experiment sample, with all patients in the hold-out set.
+
+     Args:
+         phenotype_path (Union[Path, str]): Path to the tsv file.
+             Column index 0 must be participant_id.
+         phenotype_meta (Union[Path, str]): Path to the json file.
+         class_balance_confounds (List[str]): List of confounds to use
+             for class balancing.
+         hold_out_set (float, optional): Proportion of the hold-out set
+             relative to the full sample. Defaults to 0.25.
+         random_state (int, optional): Random state for reproducibility.
+             Defaults to 42.
+
+     Returns:
+         Dict: Dictionary with lists of participant IDs for the training
+             and hold-out sets, and the downstream task samples.
+     """
+     with open(phenotype_meta, "r") as f:
+         meta = json.load(f)
+
+     data = pd.read_csv(phenotype_path, sep="\t", index_col=0)
+
+     diagnosis_groups = list(meta["diagnosis"]["labels"].keys())
+     diagnosis_groups.remove("HC")
+
+     n_sample = data.shape[0]
+
+     # create a hold-out set for downstream analysis including all
+     # the patients
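+     # a participant is a patient if any diagnosis column is non-zero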
+     any_patients = data[diagnosis_groups].sum(axis=1) > 0
+     patients = list(data[any_patients].index)
+     controls = list(data[~any_patients].index)
+
+     n_patients = len(patients)
+     n_control = n_sample - n_patients
+     n_control_in_hold_out_set = int(n_sample * hold_out_set - n_patients)
+
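+     # all patients go into the hold-out set, so we only draw enough
+     # controls to top it up to `hold_out_set` of the full sample,
+     # expressed below as a fraction of the available controls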
+     corrected_hold_out_set = n_control_in_hold_out_set / n_control
+     controls_site = list(data[~any_patients]["site"])
+     train, hold_out = train_test_split(
+         controls,
+         test_size=corrected_hold_out_set,
+         random_state=random_state,
+         stratify=controls_site,
+     )
+     hold_out += patients
+
+     # get controls that match the patients' confounds
+     data_hold_out = data.loc[hold_out]
+     downstreams = {}
+     for d in diagnosis_groups:
+         select_sample = gcb.class_balance(
+             classes=data_hold_out[d].values.astype(int),
+             confounds=data_hold_out[class_balance_confounds].values.T,
+             plim=0.05,
+             random_seed=random_state,  # fix random seed for reproducibility
+         )
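+         # select_sample indexes the rows of data_hold_out that form a
+         # class-balanced subsample matched on the given confounds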
+         selected = data_hold_out.index[select_sample].tolist()
+         selected.sort()
+         downstreams[d] = selected
+     train.sort()
+     hold_out.sort()
+     return {
+         "train": train,
+         "hold_out": hold_out,
+         "test_downstreams": downstreams,
+     }
+
+
def load_ukbb_dset_path(
-     path: Union[Path, str],
+     participant_id: List[str],
    atlas_desc: str,
-     n_sample: int = 50,
-     val_set: float = 0.25,
-     test_set: float = 0.25,
    segment: Union[int, List[int]] = -1,
-     random_state: int = 42,
- ) -> Dict:
+ ) -> List[str]:
-     """Load time series of UK Biobank.
+     """Load time series paths in the h5 file of UK Biobank.

    We segmented the time series per subject as independent samples,
    hence it's important to make sure the same subject is not in both
    training and testing set.

    Args:
-         path (Union[Path, str]): Path to the hdf5 file.
+         participant_id (List[str]): List of participant IDs.
        atlas_desc (str): Regex pattern to look for suitable data,
            such as the right `desc` field for atlas,
            e.g., "atlas-MIST_desc-197".
-         n_sample (int, optional): number of subjects to use.
-             Defaults to 50, and -1 would take the full sample.
-         val_set (float, optional): proportion of the validation set
-             size in relation to the full sample. Defaults to 0.25.
-         test_set (float, optional): proportion of the test set size
-             in relation to the full sample. Defaults to 0.25.
        segment (Union[int, List[int]], optional): segments of the
            time series to use. 0 for the full time series.
            Defaults to -1 to load all four segments.
@@ -144,41 +209,19 @@ def load_ukbb_dset_path(
    elif segment <= 4:
        segment = [segment]

-     # get the participant IDs to use
-     with h5py.File(path, "r") as h5file:
-         participant_id = list(h5file["ukbb"].keys())
-
-     if n_sample == -1:
-         pass
-     elif n_sample < len(participant_id):
-         total_proportion_sample = n_sample / len(participant_id)
-         participant_id, _ = train_test_split(
-             participant_id,
-             test_size=(1 - total_proportion_sample),
-             random_state=random_state,
-         )
-
    # construct path
    subject_path_template = (
        "/ukbb/{sub}/{sub}_task-rest_{atlas_desc}_{seg}timeseries"
    )
-     data_list = []
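+     # collect the in-file h5 dataset path for every subject and segment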
+     h5_path = []
    for sub in participant_id:
        for seg in segment:
            seg = f"seg-{seg}_" if seg is not None else ""
            cur_sub_path = subject_path_template.format(
                sub=sub, seg=seg, atlas_desc=atlas_desc
            )
-             data_list.append(cur_sub_path)
-     # train-test-val split
-     train, test = train_test_split(
-         data_list, test_size=test_set, random_state=random_state
-     )
-     # calculate the proportion of val_set in the training loop
-     train, val = train_test_split(
-         train, test_size=val_set / (1 - test_set), random_state=random_state
-     )
-     return {"train": train, "val": val, "test": test}
+             h5_path.append(cur_sub_path)
+     return h5_path


def load_data(