antoniaebner committed
Commit 28424e6 · 1 Parent(s): 3057490

update requirements and add preprocessing
Files changed (5)
  1. data/tox_smarts.json +0 -0
  2. preprocess.py +193 -0
  3. requirements.txt +3 -3
  4. src/data.py +313 -74
  5. src/utils.py +9 -0
data/tox_smarts.json ADDED
The diff for this file is too large to render. See raw diff
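
Note: get_tox_patterns in src/data.py (below) reads this file as a JSON list of [name, SMARTS] pairs and uses only the second element of each pair, which may combine subpatterns with NOT/AND/OR. The expected shape is therefore something like the following Python literal — the names and SMARTS here are hypothetical illustrations, not entries from the actual file:

    tox_smarts = [
        ["hypothetical_nitro", "[N+](=O)[O-]"],
        ["hypothetical_combined", "NOT [OH] AND c1ccccc1"],
    ]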
 
preprocess.py ADDED
@@ -0,0 +1,193 @@
+# pipeline taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
+
+"""
+This file implements the data preprocessing for Tox21.
+As input it takes lists of SMILES and, for each dataset split, it outputs an
+NPZ file holding the computed features and labels.
+"""
+
+import os
+import argparse
+
+import numpy as np
+
+from src.data import create_descriptors, get_tox21_split
+from src.utils import (
+    TASKS,
+    HF_TOKEN,
+    write_pickle,
+    create_dir,
+)
+
+parser = argparse.ArgumentParser(
+    description="Data preprocessing script for the Tox21 dataset"
+)
+
+parser.add_argument(
+    "--save_folder",
+    type=str,
+    default="data/",
+    help="Folder to which the preprocessed CSV and NPZ data files should be saved.",
+)
+
+parser.add_argument(
+    "--cv_fold",
+    type=int,
+    default=4,
+    help="Fold to use as the validation set.",
+)
+
+parser.add_argument(
+    "--feature_selection",
+    type=int,
+    default=1,
+    help="True (=1) to use feature selection.",
+)
+
+parser.add_argument(
+    "--feature_selection_path",
+    type=str,
+    default="feat_selection.npz",
+    help="Filename for saving feature selections.",
+)
+
+parser.add_argument(
+    "--min_var",
+    type=float,
+    default=0.01,
+    help="Minimum variance threshold for selecting features.",
+)
+
+parser.add_argument(
+    "--max_corr",
+    type=float,
+    default=0.95,
+    help="Maximum correlation threshold for selecting features.",
+)
+
+parser.add_argument(
+    "--ecdfs_path",
+    type=str,
+    default="ecdfs.pkl",
+    help="Filename to save ECDFs.",
+)
+
+parser.add_argument(
+    "--ecfps_radius",
+    type=int,
+    default=3,
+    help="Radius used for creating ECFPs.",
+)
+
+parser.add_argument(
+    "--ecfps_folds",
+    type=int,
+    default=8192,
+    help="Fold size (number of bits) used for creating ECFPs.",
+)
+
+parser.add_argument(
+    "--ecdfs",
+    type=int,
+    default=1,
+    help="True (=1) to use ECDFs for creating quantiles of the RDKit descriptors.",
+)
+
+
+def main(args):
+    """Preprocess train/val data to use for TabPFN.
+
+    1. Download the Tox21 train/val data from HF.
+    2. Preprocess the dataset splits.
+    """
+    ds = get_tox21_split(HF_TOKEN, cvfold=args.cv_fold)
+
+    feature_creation_kwargs = {
+        "radius": args.ecfps_radius,
+        "fpsize": args.ecfps_folds,
+        "min_var": args.min_var,
+        "max_corr": args.max_corr,
+    }
+    removed_mols = 0
+
+    splits = ["train", "validation", "test"]
+    for split in splits:
+
+        print(f"Preprocess {split} molecules")
+
+        if split != "test":
+            ds_split = ds[split]
+            smiles = list(ds_split["smiles"])
+        else:
+            import pandas as pd
+
+            ds_split = pd.read_csv("data/tox21_test_cv4.csv")
+            smiles = ds_split["smiles"]
+
+        features, clean_mol_mask = create_descriptors(smiles, **feature_creation_kwargs)
+
+        # if split == "train":
+        #     output = create_descriptors(
+        #         smiles,
+        #         return_feature_selection=True,
+        #         return_ecdfs=True,
+        #         **feature_creation_kwargs,
+        #     )
+        #     features = output.pop("features")
+
+        #     if args.feature_selection:
+        #         feature_selection = output.pop("feature_selection")
+        #         np.savez(
+        #             args.feature_selection_path,
+        #             ecfps_selec=feature_selection["ecfps_selec"],
+        #             tox_selec=feature_selection["tox_selec"],
+        #         )
+
+        #         print(f"Saved feature selection under {args.feature_selection_path}")
+
+        #     if args.ecdfs:
+        #         ecdfs = output.pop("ecdfs")
+        #         write_pickle(args.ecdfs_path, ecdfs)
+        #         print(f"Saved ECDFs under {args.ecdfs_path}")
+
+        # else:
+        #     features = create_descriptors(
+        #         smiles,
+        #         ecdfs=ecdfs,
+        #         feature_selection=feature_selection,
+        #         **feature_creation_kwargs,
+        #     )["features"]
+        removed_mols += (~clean_mol_mask).sum()
+
+        labels = []
+        for task in TASKS:
+            labels.append(ds_split[task].to_numpy())
+        labels = np.stack(labels, axis=1)
+
+        save_path = os.path.join(args.save_folder, f"tox21_{split}_cv{args.cv_fold}.npz")
+        with open(save_path, "wb") as f:
+            np.savez(
+                f,
+                labels=labels[clean_mol_mask, :],
+                features=features,
+                # **features,
+            )
+        print(f"Saved preprocessed {split} split under {save_path}")
+    print(f"{removed_mols} molecules were removed during cleaning across all splits")
+    print("Preprocessing finished successfully")
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    # args.ecdfs_path = os.path.join(args.save_folder, args.ecdfs_path)
+    # args.feature_selection_path = os.path.join(
+    #     args.save_folder, args.feature_selection_path
+    # )
+
+    create_dir(args.save_folder)
+    # create_dir(args.ecdfs_path, is_file=True)
+    # create_dir(args.feature_selection_path, is_file=True)
+
+    main(args)
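
Note: each split is written as an NPZ with "features" and "labels" arrays, so a minimal sanity check of the output could look like the following sketch (assuming the default save folder and fold, which yield data/tox21_train_cv4.npz):

    import numpy as np

    data = np.load("data/tox21_train_cv4.npz")
    features, labels = data["features"], data["labels"]
    # one row per molecule that survived cleaning; one label column per task
    assert features.shape[0] == labels.shape[0]
    print(features.shape, labels.shape)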
requirements.txt CHANGED
@@ -1,10 +1,10 @@
 fastapi
 uvicorn[standard]
 statsmodels
-rdkit
-numpy
+rdkit==2025.09.1
+numpy==2.3.3
 scikit-learn==1.7.1
-joblib
+joblib==1.5.2
 tabulate
 datasets
 torch==2.8.0
src/data.py CHANGED
@@ -6,85 +6,324 @@ As an input it takes a list of SMILES and it outputs a nested dictionary with
 SMILES and target names as keys.
 """
 
-from typing import Iterable, Literal
+import json
 
 import numpy as np
-import torch
+import pandas as pd
 
-from .preprocess import normalize_features
+from datasets import load_dataset
+from sklearn.feature_selection import VarianceThreshold
+from statsmodels.distributions.empirical_distribution import ECDF
 
-KNOWN_DESCR = ["ecfps", "rdkit_descr_quantiles", "maccs", "tox"]
+from rdkit import Chem, DataStructs
+from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
+from rdkit.Chem.rdchem import Mol
 
+from .utils import (
+    USED_200_DESCR,
+    TOX_SMARTS_PATH,
+    Standardizer,
+)
 
-def get_descriptor_dataset(
-    data_path: str,
-    descriptors: Iterable[str] | Literal["all"],
-    scaler=None,
-    save_scaler_path: str = "data/scaler.pkl",
-    verbose=True,
-    normalize=True,
-):
-    if descriptors == "all":
-        descriptors = KNOWN_DESCR
-
-    assert isinstance(descriptors, Iterable), "Passed descriptors are not iterable!"
-    assert all(
-        [descr in KNOWN_DESCR for descr in descriptors]
-    ), f"Passed descriptors contains unknown descriptor types. Allowed descriptors: {KNOWN_DESCR}"
-
-    datafile = np.load(data_path)
-
-    if not isinstance(datafile, np.ndarray):
-        # concatenate all descriptors and normalize
-        data = np.concatenate([datafile[descr] for descr in descriptors], axis=1)
-        labels = datafile["labels"]
-
-    else:
-        print("NPY file passed, cannot select specific descriptors")
-        data, labels = datafile[:, :-12], datafile[:, -12:]
-
-    if normalize:
-        data, scaler = normalize_features(
-            data,
-            scaler=scaler,
-            save_scaler_path=save_scaler_path,
-            verbose=verbose,
-        )
-
-    # filter out unsanitized molecules
-    mask = ~np.isnan(data).any(axis=1)
-    data = data[mask]
-    labels = labels[mask]
-
-    assert data.shape[0] == labels.shape[0], (
-        f"Mismatch between data and labels: "
-        f"data has {data.shape[0]} samples, but labels has {labels.shape[0]} samples."
-    )
-
-    return (data, labels, scaler)
-
-
-def get_torch_descriptor_dataset(
-    data_path: str,
-    descriptors: list[str],
-    scaler=None,
-    save_scaler_path: str = "data/scaler.pkl",
-    nan_to_num: int = -100,
-    verbose=True,
-    normalize=True,
-) -> torch.utils.data.TensorDataset:
-    data, labels, scaler = get_descriptor_dataset(
-        data_path,
-        descriptors,
-        scaler,
-        save_scaler_path,
-        verbose=verbose,
-        normalize=normalize,
-    )
-
-    labels = np.nan_to_num(labels, nan=nan_to_num)
-
-    dataset = torch.utils.data.TensorDataset(
-        torch.FloatTensor(data), torch.LongTensor(labels)
-    )
-    return dataset, scaler
+
+
+def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
+    """This function creates cleaned RDKit mol objects from a list of SMILES.
+
+    Args:
+        smiles (list[str]): list of SMILES
+
+    Returns:
+        list[Mol]: list of cleaned molecules
+        np.ndarray[bool]: mask that contains False at index `i` if the molecule in
+            `smiles` at index `i` could not be cleaned and was removed.
+    """
+    sm = Standardizer(canon_taut=True)
+
+    clean_mol_mask = list()
+    mols = list()
+    for smile in smiles:
+        mol = Chem.MolFromSmiles(smile)
+        standardized_mol, _ = sm.standardize_mol(mol)
+        is_cleaned = standardized_mol is not None
+        clean_mol_mask.append(is_cleaned)
+        if not is_cleaned:
+            continue
+        # round-trip through canonical SMILES to get a canonical mol object
+        can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
+        mols.append(can_mol)
+
+    return mols, np.array(clean_mol_mask)
+
+
+def create_ecfp_fps(mols: list[Mol], radius=3, fpsize=2048, **kwargs) -> np.ndarray:
+    """This function creates ECFP fingerprints for a list of molecules.
+
+    Args:
+        mols (list[Mol]): list of molecules
+
+    Returns:
+        np.ndarray: ECFP fingerprints of the molecules
+    """
+    # the generator is stateless, so it is created once and reused for all molecules
+    gen = rdFingerprintGenerator.GetMorganGenerator(
+        countSimulation=True, fpSize=fpsize, radius=radius
+    )
+
+    ecfps = list()
+    for mol in mols:
+        fp_sparse_vec = gen.GetCountFingerprint(mol)
+
+        fp = np.zeros((0,), np.int8)
+        DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)
+
+        ecfps.append(fp)
+
+    return np.array(ecfps)
+
+
+def create_maccs_keys(mols: list[Mol]) -> np.ndarray:
+    """This function creates MACCS keys for a list of molecules."""
+    maccs = [MACCSkeys.GenMACCSKeys(x) for x in mols]
+    return np.array(maccs)
+
+
+def get_tox_patterns(filepath: str) -> list:
+    """This function parses the tox SMARTS patterns defined in tox_smarts.json.
+
+    Args:
+        filepath (str): path to the tox SMARTS JSON file
+
+    Returns:
+        list: one (patterns, negations, merge_any) triple per tox feature
+    """
+    # load patterns
+    with open(filepath) as f:
+        smarts_list = [s[1] for s in json.load(f)]
+
+    # The code below does not handle SMARTS that mix AND and OR.
+    assert len([s for s in smarts_list if ("AND" in s) and ("OR" in s)]) == 0
+
+    # Chem.MolFromSmarts takes a long time, so it pays off to parse all the SMARTS
+    # first and then reuse them for all molecules. This gives a huge speedup over
+    # existing code. Each entry holds a list of patterns, whether to negate each
+    # match result, and how to join them into one boolean value.
+    all_patterns = []
+    for smarts in smarts_list:
+        patterns = []  # list of SMARTS patterns
+        # one value per pattern above; negates the corresponding match later
+        negations = []
+
+        if " AND " in smarts:
+            smarts = smarts.split(" AND ")
+            merge_any = False  # if an ' AND ' is found, all 'subsmarts' have to match
+        else:
+            # If an ' OR ' is present, it's enough if any of the 'subsmarts' match.
+            # This branch also collects SMARTS with neither ' OR ' nor ' AND '.
+            smarts = smarts.split(" OR ")
+            merge_any = True
+
+        # for each subsmarts, check whether it is preceded by 'NOT '
+        for s in smarts:
+            neg = s.startswith("NOT ")
+            if neg:
+                s = s[4:]
+            patterns.append(Chem.MolFromSmarts(s))
+            negations.append(neg)
+
+        all_patterns.append((patterns, negations, merge_any))
+    return all_patterns
+
+
+def create_tox_features(mols: list[Mol], patterns: list) -> np.ndarray:
+    """Matches the tox patterns against each molecule. Returns a boolean array."""
+    tox_data = []
+    for mol in mols:
+        mol_features = []
+        for patts, negations, merge_any in patterns:
+            matches = [mol.HasSubstructMatch(p) for p in patts]
+            # XOR with the negation flag flips negated matches
+            matches = [m != n for m, n in zip(matches, negations)]
+            if merge_any:
+                pres = any(matches)
+            else:
+                pres = all(matches)
+            mol_features.append(pres)
+
+        tox_data.append(np.array(mol_features))
+
+    return np.array(tox_data)
+
+
+def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
+    """This function creates RDKit descriptors for a list of molecules.
+
+    Args:
+        mols (list[Mol]): list of molecules
+
+    Returns:
+        np.ndarray: RDKit descriptors of the molecules
+    """
+    rdkit_descriptors = list()
+
+    for mol in mols:
+        descrs = []
+        for _, descr_calc_fn in Descriptors._descList:
+            descrs.append(descr_calc_fn(mol))
+
+        # keep only the 200 descriptors indexed by USED_200_DESCR
+        descrs = np.array(descrs)
+        descrs = descrs[USED_200_DESCR]
+        rdkit_descriptors.append(descrs)
+
+    return np.array(rdkit_descriptors)
+
+
+def create_quantiles(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
+    """Creates quantile values for the given features using one ECDF per column.
+
+    Args:
+        raw_features (np.ndarray): values to transform into quantiles
+        ecdfs (list): ECDFs to use, one per feature column
+
+    Returns:
+        np.ndarray: computed quantiles
+    """
+    quantiles = np.zeros_like(raw_features)
+
+    for column in range(raw_features.shape[1]):
+        raw_values = raw_features[:, column].reshape(-1)
+        ecdf = ecdfs[column]
+        q = ecdf(raw_values)
+        quantiles[:, column] = q
+
+    return quantiles
+
+
+def fill(features, mask, value=np.nan):
+    """Expands `features` to one row per molecule, writing `value` where `mask` is True."""
+    n_mols = len(mask)
+    n_features = features.shape[1]
+
+    data = np.zeros(shape=(n_mols, n_features))
+    data.fill(value)
+    data[~mask] = features
+    return data
+
+
+def create_descriptors(
+    smiles,
+    ecdfs=None,
+    feature_selection=None,
+    return_ecdfs=False,
+    return_feature_selection=False,
+    **kwargs,
+):
+    # Create cleaned RDKit mol objects
+    mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
+    print("Cleaned molecules")
+
+    tox_patterns = get_tox_patterns(TOX_SMARTS_PATH)
+
+    # Create fingerprints and descriptors
+    ecfps = create_ecfp_fps(mols, **kwargs)
+    # expand using mol_mask
+    # ecfps = fill(ecfps, ~clean_mol_mask)
+    print("Created ECFP fingerprints")
+    # print("ecfps features:", ecfps.shape)
+
+    tox = create_tox_features(mols, tox_patterns)
+    # tox = fill(tox, ~clean_mol_mask)
+    print("Created Tox features")
+    # print("tox features:", tox.shape)
+
+    # Create and save feature selection for ecfps and tox
+    # if feature_selection is None:
+    #     print("Create Feature selection")
+    #     ecfps_selec = get_feature_selection(ecfps, **kwargs)
+    #     tox_selec = get_feature_selection(tox, **kwargs)
+    #     feature_selection = {"ecfps_selec": ecfps_selec, "tox_selec": tox_selec}
+
+    # else:
+    #     ecfps_selec = feature_selection["ecfps_selec"]
+    #     tox_selec = feature_selection["tox_selec"]
+
+    # ecfps = ecfps[:, ecfps_selec]
+    # tox = tox[:, tox_selec]
+
+    maccs = create_maccs_keys(mols)
+    # maccs = fill(maccs, ~clean_mol_mask)
+    print("Created MACCS keys")
+
+    rdkit_descrs = create_rdkit_descriptors(mols)
+    # rdkit_descrs = fill(rdkit_descrs, ~clean_mol_mask)
+    print("Created RDKit descriptors")
+
+    # # Create and save ecdfs
+    # if ecdfs is None:
+    #     print("Create ECDFs")
+    #     ecdfs = []
+    #     for column in range(rdkit_descrs.shape[1]):
+    #         raw_values = rdkit_descrs[:, column].reshape(-1)
+    #         ecdfs.append(ECDF(raw_values))
+
+    # # Create quantiles
+    # rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
+    # # expand using mol_mask
+    # rdkit_descr_quantiles = fill(rdkit_descr_quantiles, ~clean_mol_mask)
+    # print("Created quantiles of RDKit descriptors")
+
+    # concatenate features
+    # features = {
+    #     "ecfps": ecfps,
+    #     "tox": tox,
+    #     "maccs": maccs,
+    #     "rdkit_descr_quantiles": rdkit_descr_quantiles,
+    # }
+    # for feat in [ecfps, tox, maccs, rdkit_descrs]:
+    #     print(feat.shape)
+    features = np.concat((ecfps, tox, maccs, rdkit_descrs), axis=1)
+    # return_dict = {"features": features}
+    # if return_ecdfs:
+    #     return_dict["ecdfs"] = ecdfs
+    # if return_feature_selection:
+    #     return_dict["feature_selection"] = feature_selection
+    return features, clean_mol_mask
+
+
+def get_feature_selection(
+    raw_features: np.ndarray, min_var=0.01, max_corr=0.95, **kwargs
+) -> np.ndarray:
+    # select features with at least min_var variance
+    var_thresh = VarianceThreshold(threshold=min_var)
+    feature_selection = var_thresh.fit(raw_features).get_support(indices=True)
+
+    n_features_preselected = len(feature_selection)
+
+    # Remove highly correlated features
+    corr_matrix = np.corrcoef(raw_features[:, feature_selection], rowvar=False)
+    upper_tri = np.triu(corr_matrix, k=1)
+    to_keep = np.ones((n_features_preselected,), dtype=bool)
+    for i in range(upper_tri.shape[0]):
+        for j in range(upper_tri.shape[1]):
+            if upper_tri[i, j] > max_corr:
+                to_keep[j] = False
+
+    feature_selection = feature_selection[to_keep]
+    return feature_selection
+
+
+def get_tox21_split(token, cvfold=None):
+    ds = load_dataset("tschouis/tox21", token=token)
+
+    train_df = ds["train"].to_pandas()
+    val_df = ds["validation"].to_pandas()
+
+    if cvfold is None:
+        return {"train": train_df, "validation": val_df}
+
+    combined_df = pd.concat([train_df, val_df], ignore_index=True)
+
+    # create new splits
+    cvfold = float(cvfold)
+    train_df = combined_df[combined_df.CVfold != cvfold]
+    val_df = combined_df[combined_df.CVfold == cvfold]
+
+    # exclude train mols that occur in the validation split
+    val_inchikeys = set(val_df["inchikey"])
+    train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]
+
+    return {
+        "train": train_df.reset_index(drop=True),
+        "validation": val_df.reset_index(drop=True),
+    }
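
Note: a minimal sketch of how the new descriptor pipeline can be exercised directly (the two SMILES are arbitrary illustrations; the keyword arguments mirror the preprocess.py defaults):

    from src.data import create_descriptors, get_tox21_split
    from src.utils import HF_TOKEN

    smiles = ["CCO", "c1ccccc1O"]  # ethanol, phenol
    features, clean_mol_mask = create_descriptors(
        smiles, radius=3, fpsize=8192, min_var=0.01, max_corr=0.95
    )
    # features has one row per molecule that survived standardization
    assert features.shape[0] == clean_mol_mask.sum()

    # rebuild the CV split with fold 4 held out for validation
    splits = get_tox21_split(HF_TOKEN, cvfold=4)
    print(len(splits["train"]), len(splits["validation"]))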
src/utils.py CHANGED
@@ -12,6 +12,7 @@ from rdkit import Chem
 from rdkit.Chem.MolStandardize import rdMolStandardize
 
 HF_TOKEN = os.environ.get("HF_TOKEN")
+TOX_SMARTS_PATH = "data/tox_smarts.json"
 
 TASKS = [
     "NR-AR",
@@ -441,3 +442,11 @@ def load_pickle(path: str):
 def write_pickle(path: str, obj: object):
     with open(path, "wb") as file:
         pickle.dump(obj, file)
+
+
+def create_dir(path, is_file=False):
+    """Creates the parent directories if a path to a file is given, else creates the given directory."""
+    to_create = os.path.dirname(path) if is_file else path
+    # guard against an empty dirname, e.g. for a bare filename with is_file=True
+    if to_create and not os.path.exists(to_create):
+        os.makedirs(to_create)
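
Note: a small usage sketch for the new create_dir helper (the paths are illustrative):

    from src.utils import create_dir

    create_dir("data/")  # creates the directory itself if it is missing
    create_dir("data/cache/ecdfs.pkl", is_file=True)  # creates data/cache/ for the file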