antoniaebner commited on
Commit
101fde6
·
1 Parent(s): 4a5bd87

update pipeline & add selected hyperparams

Browse files
Files changed (5) hide show
  1. predict.py +15 -4
  2. src/data.py +76 -184
  3. src/preprocess.py +405 -0
  4. src/train.py +139 -32
  5. src/utils.py +2 -0
predict.py CHANGED
@@ -10,8 +10,9 @@ from collections import defaultdict
10
 
11
  import numpy as np
12
 
13
- from src.data import preprocess_molecules
14
  from src.model import Tox21XGBClassifier
 
 
15
 
16
  # ---------------------------------------------------------------------------------------
17
 
@@ -28,11 +29,21 @@ def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
28
  """
29
  print(f"Received {len(smiles_list)} SMILES strings")
30
  # preprocessing pipeline
31
- features, mol_mask = preprocess_molecules(
 
 
 
 
 
 
 
 
32
  smiles_list,
33
- load_ecdf_path="assets/ecdfs.pkl",
34
- load_scaler_path="assets/scaler.pkl",
 
35
  )
 
36
  print(f"{len(mol_mask) - sum(mol_mask)} molecules removed during cleaning")
37
 
38
  # setup model
 
10
 
11
  import numpy as np
12
 
 
13
  from src.model import Tox21XGBClassifier
14
+ from src.preprocess import create_descriptors
15
+ from src.utils import load_pickle, KNOWN_DESCR
16
 
17
  # ---------------------------------------------------------------------------------------
18
 
 
29
  """
30
  print(f"Received {len(smiles_list)} SMILES strings")
31
  # preprocessing pipeline
32
+ ecdfs_path = "assets/ecdfs.pkl"
33
+ scaler_path = "assets/scaler.pkl"
34
+ ecdfs = load_pickle(ecdfs_path)
35
+ scaler = load_pickle(scaler_path)
36
+ print(f"Loaded ecdfs from {ecdfs_path}")
37
+ print(f"Loaded scaler from {scaler_path}")
38
+
39
+ descriptors = KNOWN_DESCR
40
+ features, mol_mask = create_descriptors(
41
  smiles_list,
42
+ ecdfs=ecdfs,
43
+ scaler=scaler,
44
+ descriptors=descriptors,
45
  )
46
+ print(f"Created descriptors {descriptors} for molecules.")
47
  print(f"{len(mol_mask) - sum(mol_mask)} molecules removed during cleaning")
48
 
49
  # setup model
src/data.py CHANGED
@@ -6,193 +6,85 @@ As an input it takes a list of SMILES and it outputs a nested dictionary with
6
  SMILES and target names as keys.
7
  """
8
 
9
- import os
10
 
11
  import numpy as np
12
-
13
- from sklearn.preprocessing import StandardScaler
14
- from statsmodels.distributions.empirical_distribution import ECDF
15
-
16
- from rdkit import Chem, DataStructs
17
- from rdkit.Chem import Descriptors, rdFingerprintGenerator
18
- from rdkit.Chem.rdchem import Mol
19
-
20
- from .utils import USED_200_DESCR, Standardizer, load_pickle, write_pickle
21
-
22
-
23
- def preprocess_molecules(
24
- smiles_list: list[str],
25
- load_ecdf_path: str = "",
26
- load_scaler_path: str = "",
27
- save_ecdf_path: str = "",
28
- save_scaler_path: str = "",
29
- ) -> tuple[np.ndarray, list[int]]:
30
- """Preprocessing pipeline for a list of molecules.
31
-
32
- Args:
33
- smiles_list (list[str]): list of SMILES
34
- load_ecdf_path (str, optional): Path to load ECDFs from. Defaults to "".
35
- load_scaler_path (str, optional): Path to load fitted StandardScaler from. Defaults to "".
36
- save_ecdf_path (str, optional): Path to save calculated ECDFs. Defaults to "".
37
- save_scaler_path (str, optional): Path to save fitted StandardScaler. Defaults to "".
38
-
39
- Returns:
40
- np.ndarray: normalized ECFPs fingerprints and RDKit descriptor quantiles
41
- list[bool]: mask that contains False at index `i`, if molecule in `smiles_list` at
42
- index `i` could not be cleaned and was removed.
43
- """
44
-
45
- assert not (
46
- load_ecdf_path and save_ecdf_path
47
- ), "Cannot pass 'load_ecdf_path' and 'save_ecdf_path' simultaneously"
48
- assert not (
49
- load_scaler_path and save_scaler_path
50
- ), "Cannot pass 'load_scaler_path' and 'save_scaler_path' simultaneously"
51
-
52
- ecdfs = (
53
- load_pickle(load_ecdf_path)
54
- if load_ecdf_path and os.path.exists(load_ecdf_path)
55
- else None
56
- )
57
- scaler = (
58
- load_pickle(load_scaler_path)
59
- if load_scaler_path and os.path.exists(load_scaler_path)
60
- else None
 
61
  )
62
 
63
- # Create cleanded rdkit mol objects
64
- mols, clean_mol_mask = create_cleaned_mol_objects(smiles_list)
65
- print("Cleaned molecules")
66
-
67
- # Create fingerprints and descriptors
68
- ecfps = create_ecfp_fps(mols)
69
- print("Created ECFP fingerprints")
70
- rdkit_descrs = create_rdkit_descriptors(mols)
71
- print("Created RDKit descriptors")
72
-
73
- # Create and save ecdfs
74
- if ecdfs is None:
75
- print("Create ECDFs")
76
- ecdfs = []
77
- for column in range(rdkit_descrs.shape[1]):
78
- raw_values = rdkit_descrs[:, column].reshape(-1)
79
- ecdfs.append(ECDF(raw_values))
80
- if save_ecdf_path:
81
- write_pickle(save_ecdf_path, ecdfs)
82
- print(f"Saved ECDFs under {save_ecdf_path}")
83
-
84
- # Create quantiles
85
- rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
86
- print("Created quantiles of RDKit descriptors")
87
-
88
- # Concatenate features
89
- raw_features = np.concatenate((ecfps, rdkit_descr_quantiles), axis=1)
90
-
91
- if scaler is None:
92
- scaler = StandardScaler()
93
- scaler.fit(raw_features)
94
- print("Fitted the StandardScaler")
95
- if save_scaler_path:
96
- write_pickle(save_scaler_path, scaler)
97
- print(f"Saved the StandardScaler under {save_scaler_path}")
98
-
99
- # Normalize feature vectors
100
- normalized_features = scaler.transform(raw_features)
101
- print("Normalized the molecule features")
102
-
103
- return normalized_features, clean_mol_mask
104
-
105
-
106
- def create_cleaned_mol_objects(smiles: list[str]) -> list[Mol]:
107
- """This function creates cleaned RDKit mol objects from a list of SMILES.
108
-
109
- Args:
110
- smiles (list[str]): list of SMILES
111
-
112
- Returns:
113
- list[Mol]: list of cleaned molecules
114
- list[bool]: mask that contains False at index `i`, if molecule in `smiles` at
115
- index `i` could not be cleaned and was removed.
116
- """
117
- sm = Standardizer(canon_taut=True)
118
-
119
- clean_mol_mask = list()
120
- mols = list()
121
- for i, smile in enumerate(smiles):
122
- mol = Chem.MolFromSmiles(smile)
123
- standardized_mol, _ = sm.standardize_mol(mol)
124
- is_cleaned = standardized_mol is not None
125
- clean_mol_mask.append(is_cleaned)
126
- if not is_cleaned:
127
- continue
128
- can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
129
- mols.append(can_mol)
130
-
131
- return mols, clean_mol_mask
132
-
133
-
134
- def create_ecfp_fps(mols: list[Mol]) -> np.ndarray:
135
- """This function ECFP fingerprints for a list of molecules.
136
-
137
- Args:
138
- mols (list[Mol]): list of molecules
139
-
140
- Returns:
141
- np.ndarray: ECFP fingerprints of molecules
142
- """
143
- ecfps = list()
144
-
145
- for mol in mols:
146
- fp_sparse_vec = rdFingerprintGenerator.GetCountFPs(
147
- [mol], fpType=rdFingerprintGenerator.MorganFP
148
- )[0]
149
- fp = np.zeros((0,), np.int8)
150
- DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)
151
-
152
- ecfps.append(fp)
153
-
154
- return np.array(ecfps)
155
-
156
-
157
- def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
158
- """This function creates RDKit descriptors for a list of molecules.
159
-
160
- Args:
161
- mols (list[Mol]): list of molecules
162
-
163
- Returns:
164
- np.ndarray: RDKit descriptors of molecules
165
- """
166
- rdkit_descriptors = list()
167
-
168
- for mol in mols:
169
- descrs = []
170
- for _, descr_calc_fn in Descriptors._descList:
171
- descrs.append(descr_calc_fn(mol))
172
-
173
- descrs = np.array(descrs)
174
- descrs = descrs[USED_200_DESCR]
175
- rdkit_descriptors.append(descrs)
176
-
177
- return np.array(rdkit_descriptors)
178
-
179
-
180
- def create_quantiles(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
181
- """Create quantile values for given features using the columns
182
-
183
- Args:
184
- raw_features (np.ndarray): values to put into quantiles
185
- ecdfs (list): ECDFs to use
186
-
187
- Returns:
188
- np.ndarray: computed quantiles
189
- """
190
- quantiles = np.zeros_like(raw_features)
191
 
192
- for column in range(raw_features.shape[1]):
193
- raw_values = raw_features[:, column].reshape(-1)
194
- ecdf = ecdfs[column]
195
- q = ecdf(raw_values)
196
- quantiles[:, column] = q
197
 
198
- return quantiles
 
 
 
 
6
  SMILES and target names as keys.
7
  """
8
 
9
+ from typing import Iterable, Literal
10
 
11
  import numpy as np
12
+ import torch
13
+
14
+ from .preprocess import normalize_features
15
+
16
+ KNOWN_DESCR = ["ecfps", "rdkit_descr_quantiles", "maccs", "tox"]
17
+
18
+
19
+ def get_descriptor_dataset(
20
+ data_path: str,
21
+ descriptors: Iterable[str] | Literal["all"],
22
+ scaler=None,
23
+ save_scaler_path: str = "data/scaler.pkl",
24
+ verbose=True,
25
+ normalize=True,
26
+ ):
27
+ if descriptors == "all":
28
+ descriptors = KNOWN_DESCR
29
+
30
+ assert isinstance(descriptors, Iterable), "Passed descriptors are not iterable!"
31
+ assert all(
32
+ [descr in KNOWN_DESCR for descr in descriptors]
33
+ ), f"Passed descriptors contains unknown descriptor types. Allowed descriptors: {KNOWN_DESCR}"
34
+
35
+ datafile = np.load(data_path)
36
+
37
+ if not isinstance(datafile, np.ndarray):
38
+ # concatenate all descriptors and normalize
39
+ data = np.concatenate([datafile[descr] for descr in descriptors], axis=1)
40
+ labels = datafile["labels"]
41
+
42
+ else:
43
+ print("NPY file passed, cannot select specific descriptors")
44
+ data, labels = datafile[:, :-12], datafile[:, -12:]
45
+
46
+ if normalize:
47
+ data, scaler = normalize_features(
48
+ data,
49
+ scaler=scaler,
50
+ save_scaler_path=save_scaler_path,
51
+ verbose=verbose,
52
+ )
53
+
54
+ # filter out unsanitized molecules
55
+ mask = ~np.isnan(data).any(axis=1)
56
+ data = data[mask]
57
+ labels = labels[mask]
58
+
59
+ assert data.shape[0] == labels.shape[0], (
60
+ f"Mismatch between data and labels: "
61
+ f"data has {data.shape[0]} samples, but labels has {labels.shape[0]} samples."
62
  )
63
 
64
+ return (data, labels, scaler)
65
+
66
+
67
+ def get_torch_descriptor_dataset(
68
+ data_path: str,
69
+ descriptors: list[str],
70
+ scaler=None,
71
+ save_scaler_path: str = "data/scaler.pkl",
72
+ nan_to_num: int = -100,
73
+ verbose=True,
74
+ normalize=True,
75
+ ) -> torch.utils.data.TensorDataset:
76
+ data, labels, scaler = get_descriptor_dataset(
77
+ data_path,
78
+ descriptors,
79
+ scaler,
80
+ save_scaler_path,
81
+ verbose=verbose,
82
+ normalize=normalize,
83
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ labels = np.nan_to_num(labels, nan=nan_to_num)
 
 
 
 
86
 
87
+ dataset = torch.utils.data.TensorDataset(
88
+ torch.FloatTensor(data), torch.LongTensor(labels)
89
+ )
90
+ return dataset, scaler
src/preprocess.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pipeline taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
2
+
3
+ """
4
+ This files includes a the data processing for Tox21.
5
+ As an input it takes a list of SMILES and it outputs a nested dictionary with
6
+ SMILES and target names as keys.
7
+ """
8
+
9
+ import os
10
+ import argparse
11
+ import json
12
+ from typing import Iterable
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+
17
+ from sklearn.preprocessing import StandardScaler
18
+ from statsmodels.distributions.empirical_distribution import ECDF
19
+ from datasets import load_dataset
20
+
21
+ from rdkit import Chem, DataStructs
22
+ from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
23
+ from rdkit.Chem.rdchem import Mol
24
+
25
+ from src.utils import (
26
+ TASKS,
27
+ KNOWN_DESCR,
28
+ HF_TOKEN,
29
+ USED_200_DESCR,
30
+ Standardizer,
31
+ load_pickle,
32
+ write_pickle,
33
+ )
34
+
35
+ parser = argparse.ArgumentParser(
36
+ description="Data preprocessing script for the Tox21 dataset"
37
+ )
38
+
39
+ parser.add_argument(
40
+ "--save_folder",
41
+ type=str,
42
+ default="data/",
43
+ )
44
+
45
+ parser.add_argument(
46
+ "--use_hf",
47
+ type=int,
48
+ default=0,
49
+ )
50
+
51
+ parser.add_argument(
52
+ "--path_ecdfs",
53
+ type=str,
54
+ default="data/ecdfs.pkl",
55
+ )
56
+
57
+ parser.add_argument(
58
+ "--tox_smarts_filepath",
59
+ type=str,
60
+ default="data/tox_smarts.json",
61
+ )
62
+
63
+
64
+ def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
65
+ """This function creates cleaned RDKit mol objects from a list of SMILES.
66
+
67
+ Args:
68
+ smiles (list[str]): list of SMILES
69
+
70
+ Returns:
71
+ list[Mol]: list of cleaned molecules
72
+ np.ndarray[bool]: mask that contains False at index `i`, if molecule in `smiles` at
73
+ index `i` could not be cleaned and was removed.
74
+ """
75
+ sm = Standardizer(canon_taut=True)
76
+
77
+ clean_mol_mask = list()
78
+ mols = list()
79
+ for i, smile in enumerate(smiles):
80
+ mol = Chem.MolFromSmiles(smile)
81
+ standardized_mol, _ = sm.standardize_mol(mol)
82
+ is_cleaned = standardized_mol is not None
83
+ clean_mol_mask.append(is_cleaned)
84
+ if not is_cleaned:
85
+ continue
86
+ can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
87
+ mols.append(can_mol)
88
+
89
+ return mols, np.array(clean_mol_mask)
90
+
91
+
92
+ def create_ecfp_fps(mols: list[Mol]) -> np.ndarray:
93
+ """This function ECFP fingerprints for a list of molecules.
94
+
95
+ Args:
96
+ mols (list[Mol]): list of molecules
97
+
98
+ Returns:
99
+ np.ndarray: ECFP fingerprints of molecules
100
+ """
101
+ ecfps = list()
102
+
103
+ for mol in mols:
104
+ fp_sparse_vec = rdFingerprintGenerator.GetCountFPs(
105
+ [mol], fpType=rdFingerprintGenerator.MorganFP
106
+ )[0]
107
+ fp = np.zeros((0,), np.int8)
108
+ DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)
109
+
110
+ ecfps.append(fp)
111
+
112
+ return np.array(ecfps)
113
+
114
+
115
+ def create_maccs_keys(mols: list[Mol]) -> np.ndarray:
116
+ maccs = [MACCSkeys.GenMACCSKeys(x) for x in mols]
117
+ return np.array(maccs)
118
+
119
+
120
+ def get_tox_patterns(filepath: str):
121
+ """This calculates tox features defined in tox_smarts.json.
122
+ Args:
123
+ mols: A list of Mol
124
+ n_jobs: If >1 multiprocessing is used
125
+ """
126
+ # load patterns
127
+ with open(filepath) as f:
128
+ smarts_list = [s[1] for s in json.load(f)]
129
+
130
+ # Code does not work for this case
131
+ assert len([s for s in smarts_list if ("AND" in s) and ("OR" in s)]) == 0
132
+
133
+ # Chem.MolFromSmarts takes a long time so it pays of to parse all the smarts first
134
+ # and then use them for all molecules. This gives a huge speedup over existing code.
135
+ # a list of patterns, whether to negate the match result and how to join them to obtain one boolean value
136
+ all_patterns = []
137
+ for smarts in smarts_list:
138
+ patterns = [] # list of smarts-patterns
139
+ # value for each of the patterns above. Negates the values of the above later.
140
+ negations = []
141
+
142
+ if " AND " in smarts:
143
+ smarts = smarts.split(" AND ")
144
+ merge_any = False # If an ' AND ' is found all 'subsmarts' have to match
145
+ else:
146
+ # If there is an ' OR ' present it's enough is any of the 'subsmarts' match.
147
+ # This also accumulates smarts where neither ' OR ' nor ' AND ' occur
148
+ smarts = smarts.split(" OR ")
149
+ merge_any = True
150
+
151
+ # for all subsmarts check if they are preceded by 'NOT '
152
+ for s in smarts:
153
+ neg = s.startswith("NOT ")
154
+ if neg:
155
+ s = s[4:]
156
+ patterns.append(Chem.MolFromSmarts(s))
157
+ negations.append(neg)
158
+
159
+ all_patterns.append((patterns, negations, merge_any))
160
+ return all_patterns
161
+
162
+
163
+ def create_tox_features(mols: list[Mol], patterns: list) -> np.ndarray:
164
+ """Matches the tox patterns against a molecule. Returns a boolean array"""
165
+ tox_data = []
166
+ for mol in mols:
167
+ mol_features = []
168
+ for patts, negations, merge_any in patterns:
169
+ matches = [mol.HasSubstructMatch(p) for p in patts]
170
+ matches = [m != n for m, n in zip(matches, negations)]
171
+ if merge_any:
172
+ pres = any(matches)
173
+ else:
174
+ pres = all(matches)
175
+ mol_features.append(pres)
176
+
177
+ tox_data.append(np.array(mol_features))
178
+
179
+ return np.array(tox_data)
180
+
181
+
182
+ def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
183
+ """This function creates RDKit descriptors for a list of molecules.
184
+
185
+ Args:
186
+ mols (list[Mol]): list of molecules
187
+
188
+ Returns:
189
+ np.ndarray: RDKit descriptors of molecules
190
+ """
191
+ rdkit_descriptors = list()
192
+
193
+ for mol in mols:
194
+ descrs = []
195
+ for _, descr_calc_fn in Descriptors._descList:
196
+ descrs.append(descr_calc_fn(mol))
197
+
198
+ descrs = np.array(descrs)
199
+ descrs = descrs[USED_200_DESCR]
200
+ rdkit_descriptors.append(descrs)
201
+
202
+ return np.array(rdkit_descriptors)
203
+
204
+
205
+ def create_quantiles(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
206
+ """Create quantile values for given features using the columns
207
+
208
+ Args:
209
+ raw_features (np.ndarray): values to put into quantiles
210
+ ecdfs (list): ECDFs to use
211
+
212
+ Returns:
213
+ np.ndarray: computed quantiles
214
+ """
215
+ quantiles = np.zeros_like(raw_features)
216
+
217
+ for column in range(raw_features.shape[1]):
218
+ raw_values = raw_features[:, column].reshape(-1)
219
+ ecdf = ecdfs[column]
220
+ q = ecdf(raw_values)
221
+ quantiles[:, column] = q
222
+
223
+ return quantiles
224
+
225
+
226
+ def fill(features, mask, value=np.nan):
227
+ n_mols = len(mask)
228
+ n_features = features.shape[1]
229
+
230
+ data = np.zeros(shape=(n_mols, n_features))
231
+ data.fill(value)
232
+ data[~mask] = features
233
+ return data
234
+
235
+
236
+ def normalize_features(
237
+ raw_features,
238
+ scaler=None,
239
+ save_scaler_path: str = "",
240
+ verbose=True,
241
+ ):
242
+ if scaler is None:
243
+ scaler = StandardScaler()
244
+ scaler.fit(raw_features)
245
+ if verbose:
246
+ print("Fitted the StandardScaler")
247
+ if save_scaler_path:
248
+ write_pickle(save_scaler_path, scaler)
249
+ if verbose:
250
+ print(f"Saved the StandardScaler under {save_scaler_path}")
251
+
252
+ # Normalize feature vectors
253
+ normalized_features = scaler.transform(raw_features)
254
+ if verbose:
255
+ print("Normalized molecule features")
256
+ return normalized_features, scaler
257
+
258
+
259
+ def create_descriptors(
260
+ smiles,
261
+ ecdfs=None,
262
+ scaler=None,
263
+ descriptors: Iterable = KNOWN_DESCR,
264
+ ):
265
+ # Create cleanded rdkit mol objects
266
+ mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
267
+ print("Cleaned molecules")
268
+
269
+ features = []
270
+ if "ecfps" in descriptors:
271
+ # Create fingerprints and descriptors
272
+ ecfps = create_ecfp_fps(mols)
273
+ # expand using mol_mask
274
+ ecfps = fill(ecfps, ~clean_mol_mask)
275
+ features.append(ecfps)
276
+ print("Created ECFP fingerprints")
277
+
278
+ if "rdkit_descr_quantiles" in descriptors:
279
+ rdkit_descrs = create_rdkit_descriptors(mols)
280
+ print("Created RDKit descriptors")
281
+
282
+ # Create and save ecdfs
283
+ if ecdfs is None:
284
+ print("Create ECDFs")
285
+ ecdfs = []
286
+ for column in range(rdkit_descrs.shape[1]):
287
+ raw_values = rdkit_descrs[:, column].reshape(-1)
288
+ ecdfs.append(ECDF(raw_values))
289
+
290
+ # Create quantiles
291
+ rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
292
+ # expand using mol_mask
293
+ rdkit_descr_quantiles = fill(rdkit_descr_quantiles, ~clean_mol_mask)
294
+ features.append(rdkit_descr_quantiles)
295
+ print("Created quantiles of RDKit descriptors")
296
+
297
+ if "maccs" in descriptors:
298
+ maccs = create_maccs_keys(mols)
299
+ maccs = fill(maccs, ~clean_mol_mask)
300
+ features.append(maccs)
301
+ print("Created MACCS keys")
302
+
303
+ if "tox" in descriptors:
304
+ tox_patterns = get_tox_patterns("assets/tox_smarts.json")
305
+ tox = create_tox_features(mols, tox_patterns)
306
+ tox = fill(tox, ~clean_mol_mask)
307
+ features.append(tox)
308
+ print("Created Tox features")
309
+
310
+ # concatenate features
311
+ raw_features = np.concatenate(features, axis=1)
312
+
313
+ # normalize with scaler if scaler is passed, else create scaler
314
+ features, _ = normalize_features(
315
+ raw_features,
316
+ scaler=scaler,
317
+ verbose=True,
318
+ )
319
+
320
+ return features, clean_mol_mask
321
+
322
+
323
+ def main(args):
324
+ splits = ["train", "validation"]
325
+ ds = load_dataset("tschouis/tox21", token=HF_TOKEN)
326
+
327
+ for split in splits:
328
+
329
+ print(f"Preprocess {split} molecules")
330
+ smiles = list(ds[split]["smiles"])
331
+
332
+ # Create cleanded rdkit mol objects
333
+ mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
334
+ print("Cleaned molecules")
335
+
336
+ tox_patterns = get_tox_patterns(args.tox_smarts_filepath)
337
+
338
+ # Create fingerprints and descriptors
339
+ ecfps = create_ecfp_fps(mols)
340
+ # expand using mol_mask
341
+ ecfps = fill(ecfps, ~clean_mol_mask)
342
+ print("Created ECFP fingerprints")
343
+
344
+ rdkit_descrs = create_rdkit_descriptors(mols)
345
+ print("Created RDKit descriptors")
346
+
347
+ # Create and save ecdfs
348
+ if split == "train":
349
+ print("Create ECDFs")
350
+ ecdfs = []
351
+ for column in range(rdkit_descrs.shape[1]):
352
+ raw_values = rdkit_descrs[:, column].reshape(-1)
353
+ ecdfs.append(ECDF(raw_values))
354
+
355
+ write_pickle(args.path_ecdfs, ecdfs)
356
+ print(f"Saved ECDFs under {args.path_ecdfs}")
357
+ else:
358
+ print(f"Load ECDFs from {args.path_ecdfs}")
359
+ ecdfs = load_pickle(args.path_ecdfs)
360
+
361
+ # Create quantiles
362
+ rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
363
+ # expand using mol_mask
364
+ rdkit_descr_quantiles = fill(rdkit_descr_quantiles, ~clean_mol_mask)
365
+ print("Created quantiles of RDKit descriptors")
366
+
367
+ maccs = create_maccs_keys(mols)
368
+ maccs = fill(maccs, ~clean_mol_mask)
369
+ print("Created MACCS keys")
370
+
371
+ tox = create_tox_features(mols, tox_patterns)
372
+ tox = fill(tox, ~clean_mol_mask)
373
+ print("Created Tox features")
374
+
375
+ labels = []
376
+ for task in TASKS:
377
+ datasplit = ds[split].to_pandas() if args.use_hf else ds[split]
378
+ labels.append(datasplit[task].to_numpy())
379
+ labels = np.stack(labels, axis=1)
380
+
381
+ save_path = os.path.join(args.save_folder, f"tox21_{split}.npz")
382
+ with open(save_path, "wb") as f:
383
+ np.savez(
384
+ f,
385
+ labels=labels,
386
+ ecfps=ecfps,
387
+ rdkit_descr_quantiles=rdkit_descr_quantiles,
388
+ maccs=maccs,
389
+ tox=tox,
390
+ )
391
+ print(f"Saved preprocessed {split} split under {save_path}")
392
+
393
+ print("Preprocessing finished successfully")
394
+
395
+
396
+ if __name__ == "__main__":
397
+ args = parser.parse_args()
398
+
399
+ if not os.path.exists(args.save_folder):
400
+ os.makedirs(args.save_folder)
401
+
402
+ if not os.path.exists(os.path.dirname(args.path_ecdfs)):
403
+ os.makedirs(os.path.dirname(args.path_ecdfs))
404
+
405
+ main(args)
src/train.py CHANGED
@@ -2,17 +2,19 @@
2
  Script for fitting and saving any preprocessing assets, as well as the fitted XGBoost model
3
  """
4
 
 
5
  import argparse
6
 
7
  import numpy as np
8
 
9
  from tabulate import tabulate
10
- from datasets import load_dataset
11
  from sklearn.metrics import roc_auc_score
12
 
13
- from .data import preprocess_molecules
14
  from .model import Tox21XGBClassifier
15
- from .utils import HF_TOKEN
 
 
16
 
17
  parser = argparse.ArgumentParser(description="XGBoost Trainig script for Tox21 dataset")
18
 
@@ -36,51 +38,156 @@ parser.add_argument(
36
 
37
 
38
  def main(args):
39
- ds = load_dataset("tschouis/tox21", token=HF_TOKEN)
40
-
41
  print("Preprocess train molecules")
42
- train_smiles = list(ds["train"]["smiles"])
43
-
44
- train_features, train_mol_mask = preprocess_molecules(
45
- train_smiles,
46
- save_ecdf_path=args.path_ecdfs,
47
- save_scaler_path=args.path_scaler,
48
  )
49
-
50
- print("Preprocess validation molecules")
51
- val_smiles = list(ds["validation"]["smiles"])
52
- val_features, val_mol_mask = preprocess_molecules(
53
- val_smiles,
54
- load_ecdf_path=args.path_ecdfs,
55
- load_scaler_path=args.path_scaler,
56
  )
57
 
58
- model = Tox21XGBClassifier(seed=42)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  print("Start training.")
60
- for task in model.tasks:
61
- task_labels = ds["train"].to_pandas()[task].to_numpy()
62
- task_labels = task_labels[train_mol_mask]
63
-
64
  label_mask = ~np.isnan(task_labels)
65
 
 
 
 
66
  print(f"Fit task {task} using {sum(label_mask)} samples")
67
- model.fit(task, train_features[label_mask], task_labels[label_mask].astype(int))
68
 
69
  print(f"Save model under {args.save_path_model}")
70
  model.save_model(args.save_path_model)
71
 
72
  print("Evaluate model")
73
  results = {}
74
- for task in model.tasks:
75
- task_labels = ds["validation"].to_pandas()[task].to_numpy()
76
- task_labels = task_labels[val_mol_mask]
77
-
78
  label_mask = ~np.isnan(task_labels)
79
 
80
- pred = model.predict(task, val_features[label_mask])
81
- results[task] = [
82
- roc_auc_score(y_true=task_labels[label_mask].astype(int), y_score=pred)
83
- ]
 
84
 
85
  print("Results:")
86
  print(tabulate(results, headers="keys"))
 
2
  Script for fitting and saving any preprocessing assets, as well as the fitted XGBoost model
3
  """
4
 
5
+ import os
6
  import argparse
7
 
8
  import numpy as np
9
 
10
  from tabulate import tabulate
 
11
  from sklearn.metrics import roc_auc_score
12
 
13
+ from .data import get_descriptor_dataset
14
  from .model import Tox21XGBClassifier
15
+
16
+ SEED = 42
17
+ DATA_FOLDER = "data/"
18
 
19
  parser = argparse.ArgumentParser(description="XGBoost Trainig script for Tox21 dataset")
20
 
 
38
 
39
 
40
  def main(args):
 
 
41
  print("Preprocess train molecules")
42
+ # load datasets
43
+ train_X, train_y, scaler = get_descriptor_dataset(
44
+ os.path.join(DATA_FOLDER, "tox21_train.npz"),
45
+ descriptors="all",
46
+ save_scaler_path="data/scaler.pkl",
 
47
  )
48
+ val_X, val_y, _ = get_descriptor_dataset(
49
+ os.path.join(DATA_FOLDER, "tox21_validation.npz"),
50
+ descriptors="all",
51
+ scaler=scaler,
 
 
 
52
  )
53
 
54
+ task_config = {
55
+ "NR-AR": {
56
+ "colsample_bytree": 0.5,
57
+ "learning_rate": 0.05,
58
+ "max_depth": 12,
59
+ "min_child_weight": 2,
60
+ "n_estimators": 1000,
61
+ "scale_pos_weight": 80,
62
+ "subsample": 0.4,
63
+ },
64
+ "NR-AR-LBD": {
65
+ "colsample_bytree": 0.8,
66
+ "learning_rate": 0.04,
67
+ "max_depth": 10,
68
+ "min_child_weight": 8,
69
+ "n_estimators": 1000,
70
+ "scale_pos_weight": 10,
71
+ "subsample": 0.4,
72
+ },
73
+ "NR-AhR": {
74
+ "colsample_bytree": 0.8,
75
+ "learning_rate": 0.05,
76
+ "max_depth": 16,
77
+ "min_child_weight": 2,
78
+ "n_estimators": 1000,
79
+ "scale_pos_weight": 80,
80
+ "subsample": 1,
81
+ },
82
+ "NR-Aromatase": {
83
+ "colsample_bytree": 0.7,
84
+ "learning_rate": 0.05,
85
+ "max_depth": 16,
86
+ "min_child_weight": 1,
87
+ "n_estimators": 1000,
88
+ "scale_pos_weight": 50,
89
+ "subsample": 0.7,
90
+ },
91
+ "NR-ER": {
92
+ "colsample_bytree": 0.7,
93
+ "learning_rate": 0.05,
94
+ "max_depth": 10,
95
+ "min_child_weight": 4,
96
+ "n_estimators": 1000,
97
+ "scale_pos_weight": 25,
98
+ "subsample": 0.4,
99
+ },
100
+ "NR-ER-LBD": {
101
+ "colsample_bytree": 0.7,
102
+ "learning_rate": 0.05,
103
+ "max_depth": 16,
104
+ "min_child_weight": 4,
105
+ "n_estimators": 1000,
106
+ "scale_pos_weight": 10,
107
+ "subsample": 0.4,
108
+ },
109
+ "NR-PPAR-gamma": {
110
+ "colsample_bytree": 0.8,
111
+ "learning_rate": 0.01,
112
+ "max_depth": 12,
113
+ "min_child_weight": 2,
114
+ "n_estimators": 1000,
115
+ "scale_pos_weight": 80,
116
+ "subsample": 0.4,
117
+ },
118
+ "SR-ARE": {
119
+ "colsample_bytree": 0.7,
120
+ "learning_rate": 0.05,
121
+ "max_depth": 16,
122
+ "min_child_weight": 8,
123
+ "n_estimators": 1000,
124
+ "scale_pos_weight": 10,
125
+ "subsample": 0.7,
126
+ },
127
+ "SR-ATAD5": {
128
+ "colsample_bytree": 0.5,
129
+ "learning_rate": 0.02,
130
+ "max_depth": 12,
131
+ "min_child_weight": 8,
132
+ "n_estimators": 1000,
133
+ "scale_pos_weight": 10,
134
+ "subsample": 0.4,
135
+ },
136
+ "SR-HSE": {
137
+ "colsample_bytree": 0.8,
138
+ "learning_rate": 0.02,
139
+ "max_depth": 6,
140
+ "min_child_weight": 1,
141
+ "n_estimators": 1000,
142
+ "scale_pos_weight": 25,
143
+ "subsample": 1,
144
+ },
145
+ "SR-MMP": {
146
+ "colsample_bytree": 0.5,
147
+ "learning_rate": 0.02,
148
+ "max_depth": 16,
149
+ "min_child_weight": 2,
150
+ "n_estimators": 1000,
151
+ "scale_pos_weight": 10,
152
+ "subsample": 0.7,
153
+ },
154
+ "SR-p53": {
155
+ "colsample_bytree": 0.5,
156
+ "learning_rate": 0.02,
157
+ "max_depth": 12,
158
+ "min_child_weight": 8,
159
+ "n_estimators": 1000,
160
+ "scale_pos_weight": 10,
161
+ "subsample": 0.4,
162
+ },
163
+ }
164
+
165
+ model = Tox21XGBClassifier(seed=42, task_config=task_config)
166
  print("Start training.")
167
+ for i, task in enumerate(model.tasks):
168
+ task_labels = train_y[:, i]
 
 
169
  label_mask = ~np.isnan(task_labels)
170
 
171
+ task_data = train_X[label_mask]
172
+ task_labels = task_labels[label_mask].astype(int)
173
+
174
  print(f"Fit task {task} using {sum(label_mask)} samples")
175
+ model.fit(task, task_data, task_labels)
176
 
177
  print(f"Save model under {args.save_path_model}")
178
  model.save_model(args.save_path_model)
179
 
180
  print("Evaluate model")
181
  results = {}
182
+ for i, task in enumerate(model.tasks):
183
+ task_labels = val_y[:, i]
 
 
184
  label_mask = ~np.isnan(task_labels)
185
 
186
+ task_data = val_X[label_mask]
187
+ task_labels = task_labels[label_mask].astype(int)
188
+
189
+ pred = model.predict(task, task_data)
190
+ results[task] = [roc_auc_score(y_true=task_labels, y_score=pred)]
191
 
192
  print("Results:")
193
  print(tabulate(results, headers="keys"))
src/utils.py CHANGED
@@ -28,6 +28,8 @@ TASKS = [
28
  "SR-p53",
29
  ]
30
 
 
 
31
  USED_200_DESCR = [
32
  0,
33
  1,
 
28
  "SR-p53",
29
  ]
30
 
31
+ KNOWN_DESCR = ["ecfps", "rdkit_descr_quantiles", "maccs", "tox"]
32
+
33
  USED_200_DESCR = [
34
  0,
35
  1,