antoniaebner committed
Commit a3a1ae9 · Parent(s): b308ee1

add code and assets

Files changed (11):
  1. .gitignore +1 -0
  2. Dockerfile +16 -0
  3. README.md +92 -1
  4. app.py +78 -0
  5. predict.py +54 -0
  6. requirements.txt +10 -0
  7. src/__init__.py +0 -0
  8. src/data.py +198 -0
  9. src/model.py +79 -0
  10. src/train.py +92 -0
  11. src/utils.py +441 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
Dockerfile ADDED
@@ -0,0 +1,16 @@
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # you will also find guides on how best to write your Dockerfile
+
+ FROM python:3.11
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -9,4 +9,95 @@ license: apache-2.0
 short_description: XGBoost baseline classifier for Tox21
 ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Tox21 XGBoost Classifier
+
+ This repository hosts a Hugging Face Space that provides an exemplary API for submitting models to the [Tox21 Leaderboard](https://huggingface.co/spaces/tschouis/tox21_leaderboard).
+
+ In this example, we train an XGBoost classifier on the Tox21 targets and save the trained model in the `assets/` folder.
+
+ **Important:** For leaderboard submission, your Space does not need to include training code. It only needs to implement inference in the `predict()` function inside `predict.py`. The `predict()` function must keep the provided skeleton: it takes a list of SMILES strings as input and returns a prediction dictionary with SMILES and targets as keys. Any preprocessing of SMILES strings must therefore be executed on the fly during inference; see the sketch below.
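+
+ A minimal sketch of the required skeleton (`TARGETS` and the dummy score stand in for your own task names and model call):
+
+ ```python
+ def predict(smiles_list: list[str]) -> dict:
+     predictions = {}
+     for smiles in smiles_list:
+         # preprocess the SMILES string and run your model here
+         predictions[smiles] = {target: 0.0 for target in TARGETS}
+     return predictions
+ ```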
+
+ # Repository Structure
+ - `predict.py` - Defines the `predict()` function required by the leaderboard (entry point for inference).
+ - `app.py` - FastAPI application wrapper (can be used as-is).
+
+ - `src/` - Core model & preprocessing logic:
+   - `data.py` - SMILES preprocessing pipeline
+   - `model.py` - XGBoost classifier wrapper
+   - `train.py` - Script to train the classifier
+   - `utils.py` - Constants and helper functions
+
+ # Quickstart with Spaces
+
+ You can easily adapt this project in your own Hugging Face account:
+
+ - Open this Space on Hugging Face.
+
+ - Click "Duplicate this Space" (top-right corner).
+
+ - Modify `src/` for your preprocessing pipeline and model class.
+
+ - Modify `predict()` inside `predict.py` to perform model inference while keeping the function skeleton unchanged to remain compatible with the leaderboard.
+
+ That’s it: your model will be available as an API endpoint for the Tox21 Leaderboard.
+
+ # Installation
+ To run (and train) the XGBoost classifier, clone the repository and install the dependencies:
+
+ ```bash
+ git clone https://huggingface.co/spaces/tschouis/tox21_xgboost_classifier
+ cd tox21_xgboost_classifier
+
+ conda create -n tox21_xgb_cls python=3.11
+ conda activate tox21_xgb_cls
+ pip install -r requirements.txt
+ ```
+
+ # Training
+
+ To train the XGBoost model from scratch:
+
+ ```bash
+ python -m src.train
+ ```
+
+ This will:
+
+ 1. Load and preprocess the Tox21 training dataset.
+ 2. Train an XGBoost classifier.
+ 3. Save the trained model to the `assets/` folder.
+ 4. Evaluate the trained XGBoost classifier on the validation split.
+
+
+ # Inference
+
+ For inference, you only need `predict.py`.
+
+ Example usage inside Python:
+
+ ```python
+ from predict import predict
+
+ smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]
+ results = predict(smiles_list)
+
+ print(results)
+ ```
+
+ The output will be a nested dictionary in the format:
+
+ ```python
+ {
+     "CCO": {"target1": 0, "target2": 1, ..., "target12": 0},
+     "c1ccccc1": {"target1": 1, "target2": 0, ..., "target12": 1},
+     "CC(=O)O": {"target1": 0, "target2": 0, ..., "target12": 0}
+ }
+ ```
+
+ # Notes
+
+ - Only adapting `predict.py` for your model inference is required for leaderboard submission.
+
+ - Training (`src/train.py`) is provided for reproducibility.
+
+ - Preprocessing (here inside `src/data.py`) must be applied at inference time, not just during training.
app.py ADDED
@@ -0,0 +1,81 @@
+ """
+ This is the main entry point for the FastAPI application.
+ The app handles requests to predict toxicity for a list of SMILES strings.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies and global variable definition
+ import os
+ from typing import List, Dict, Optional
+ from fastapi import FastAPI, Header, HTTPException
+ from pydantic import BaseModel, Field
+
+ from predict import predict as predict_func
+
+ API_KEY = os.getenv("API_KEY")  # set via Space Secrets
+
+
+ # ---------------------------------------------------------------------------------------
+ class Request(BaseModel):
+     smiles: List[str] = Field(min_items=1, max_items=1000)
+
+
+ class Response(BaseModel):
+     predictions: dict
+     model_info: Dict[str, str] = {}
+
+
+ app = FastAPI(title="toxicity-api")
+
+
+ @app.get("/")
+ def root():
+     return {
+         "message": "Toxicity Prediction API",
+         "endpoints": {
+             "/metadata": "GET - API metadata and capabilities",
+             "/healthz": "GET - Health check",
+             "/predict": "POST - Predict toxicity for SMILES",
+         },
+         "usage": "Send POST to /predict with {'smiles': ['your_smiles_here']} and Authorization header",
+     }
+
+
+ @app.get("/metadata")
+ def metadata():
+     return {
+         "name": "AwesomeTox",
+         "version": "1.0.0",
+         "max_batch_size": 256,
+         "tox_endpoints": [
+             "NR-AR",
+             "NR-AR-LBD",
+             "NR-AhR",
+             "NR-Aromatase",
+             "NR-ER",
+             "NR-ER-LBD",
+             "NR-PPAR-gamma",
+             "SR-ARE",
+             "SR-ATAD5",
+             "SR-HSE",
+             "SR-MMP",
+             "SR-p53",
+         ],
+     }
+
+
+ @app.get("/healthz")
+ def healthz():
+     return {"ok": True}
+
+
+ @app.post("/predict", response_model=Response)
+ def predict(request: Request, authorization: Optional[str] = Header(None)):
+     # verify the Authorization header against the API key configured in the Space secrets
+     if API_KEY and authorization != f"Bearer {API_KEY}":
+         raise HTTPException(status_code=401, detail="Invalid or missing API key")
+     predictions = predict_func(request.smiles)
+     return {
+         "predictions": predictions,
+         "model_info": {"name": "tox21_xgb_classifier", "version": "1.0.0"},
+     }
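A minimal sketch of calling the deployed endpoint (the Space URL is a placeholder, the token must match the `API_KEY` secret, and the `requests` package is assumed to be available):

```python
import requests

resp = requests.post(
    "https://<your-space>.hf.space/predict",  # placeholder: your Space's URL
    headers={"Authorization": "Bearer <API_KEY>"},  # placeholder: your API key
    json={"smiles": ["CCO", "c1ccccc1"]},
)
resp.raise_for_status()
print(resp.json()["predictions"])
```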
predict.py ADDED
@@ -0,0 +1,56 @@
+ """
+ This file includes the predict function for Tox21.
+ It takes a list of SMILES as input and outputs a nested dictionary with
+ SMILES and target names as keys.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies
+ from collections import defaultdict
+
+ from src.data import preprocess_molecules
+ from src.model import Tox21XGBClassifier
+
+ # ---------------------------------------------------------------------------------------
+
+
+ def predict(smiles_list: list[str]) -> dict:
+     """Applies the classifier to a list of SMILES strings. Returns prediction=0.0 for
+     any molecule that could not be cleaned.
+
+     Args:
+         smiles_list (list[str]): list of SMILES strings
+
+     Returns:
+         dict: nested prediction dictionary, following {'<smiles>': {'<target>': <pred>}}
+     """
+     print(f"Received {len(smiles_list)} SMILES strings")
+     # preprocessing pipeline
+     features, clean_mol_mask = preprocess_molecules(
+         smiles_list,
+         load_ecdf_path="assets/ecdfs.pkl",
+         load_scaler_path="assets/scaler.pkl",
+     )
+     # indices of molecules that were dropped during cleaning
+     removed_idxs = {i for i, is_clean in enumerate(clean_mol_mask) if not is_clean}
+     print(f"{len(removed_idxs)} molecules removed during cleaning")
+
+     # setup model
+     model = Tox21XGBClassifier(seed=42)
+     model.load_model("assets/xgb_alltasks.joblib")
+
+     # make predictions
+     predictions = defaultdict(dict)
+     # make smiles list with same num_samples as features
+     clean_smiles = [smi for i, smi in enumerate(smiles_list) if i not in removed_idxs]
+     no_pred_smiles = [smi for i, smi in enumerate(smiles_list) if i in removed_idxs]
+
+     for target in model.tasks:
+         target_pred = model.predict(target, features)
+         for i, smiles in enumerate(clean_smiles):
+             predictions[smiles][target] = target_pred[i]
+
+         for smiles in no_pred_smiles:
+             predictions[smiles][target] = 0.0
+
+     return predictions
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ fastapi
+ uvicorn[standard]
+ statsmodels
+ rdkit
+ numpy
+ scikit-learn==1.7.1
+ joblib
+ tabulate
+ datasets
+ xgboost==3.0.5
src/__init__.py ADDED
File without changes
src/data.py ADDED
@@ -0,0 +1,198 @@
+ # pipeline taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
+
+ """
+ This file includes the data preprocessing for Tox21.
+ It takes a list of SMILES as input and outputs normalized feature vectors
+ (ECFP fingerprints concatenated with RDKit descriptor quantiles).
+ """
+
+ import os
+
+ import numpy as np
+
+ from sklearn.preprocessing import StandardScaler
+ from statsmodels.distributions.empirical_distribution import ECDF
+
+ from rdkit import Chem, DataStructs
+ from rdkit.Chem import Descriptors, rdFingerprintGenerator
+ from rdkit.Chem.rdchem import Mol
+
+ from src.utils import USED_200_DESCR, Standardizer, load_pickle, write_pickle
+
+
+ def preprocess_molecules(
+     smiles_list: list[str],
+     load_ecdf_path: str = "",
+     load_scaler_path: str = "",
+     save_ecdf_path: str = "",
+     save_scaler_path: str = "",
+ ) -> tuple[np.ndarray, list[bool]]:
+     """Preprocessing pipeline for a list of molecules.
+
+     Args:
+         smiles_list (list[str]): list of SMILES
+         load_ecdf_path (str, optional): Path to load ECDFs from. Defaults to "".
+         load_scaler_path (str, optional): Path to load fitted StandardScaler from. Defaults to "".
+         save_ecdf_path (str, optional): Path to save calculated ECDFs. Defaults to "".
+         save_scaler_path (str, optional): Path to save fitted StandardScaler. Defaults to "".
+
+     Returns:
+         np.ndarray: normalized ECFP fingerprints and RDKit descriptor quantiles
+         list[bool]: mask that contains False at index `i` if the molecule in `smiles_list` at
+             index `i` could not be cleaned and was removed.
+     """
+
+     assert not (
+         load_ecdf_path and save_ecdf_path
+     ), "Cannot pass 'load_ecdf_path' and 'save_ecdf_path' simultaneously"
+     assert not (
+         load_scaler_path and save_scaler_path
+     ), "Cannot pass 'load_scaler_path' and 'save_scaler_path' simultaneously"
+
+     ecdfs = (
+         load_pickle(load_ecdf_path)
+         if load_ecdf_path and os.path.exists(load_ecdf_path)
+         else None
+     )
+     scaler = (
+         load_pickle(load_scaler_path)
+         if load_scaler_path and os.path.exists(load_scaler_path)
+         else None
+     )
+
+     # Create cleaned RDKit mol objects
+     mols, clean_mol_mask = create_cleaned_mol_objects(smiles_list)
+     print("Cleaned molecules")
+
+     # Create fingerprints and descriptors
+     ecfps = create_ecfp_fps(mols)
+     print("Created ECFP fingerprints")
+     rdkit_descrs = create_rdkit_descriptors(mols)
+     print("Created RDKit descriptors")
+
+     # Create and save ECDFs
+     if ecdfs is None:
+         print("Create ECDFs")
+         ecdfs = []
+         for column in range(rdkit_descrs.shape[1]):
+             raw_values = rdkit_descrs[:, column].reshape(-1)
+             ecdfs.append(ECDF(raw_values))
+         if save_ecdf_path:
+             write_pickle(save_ecdf_path, ecdfs)
+             print(f"Saved ECDFs under {save_ecdf_path}")
+
+     # Create quantiles
+     rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
+     print("Created quantiles of RDKit descriptors")
+
+     # Concatenate features
+     raw_features = np.concatenate((ecfps, rdkit_descr_quantiles), axis=1)
+
+     if scaler is None:
+         scaler = StandardScaler()
+         scaler.fit(raw_features)
+         print("Fitted the StandardScaler")
+         if save_scaler_path:
+             write_pickle(save_scaler_path, scaler)
+             print(f"Saved the StandardScaler under {save_scaler_path}")
+
+     # Normalize feature vectors
+     normalized_features = scaler.transform(raw_features)
+     print("Normalized the molecule features")
+
+     return normalized_features, clean_mol_mask
+
+
+ def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], list[bool]]:
+     """This function creates cleaned RDKit mol objects from a list of SMILES.
+
+     Args:
+         smiles (list[str]): list of SMILES
+
+     Returns:
+         list[Mol]: list of cleaned molecules
+         list[bool]: mask that contains False at index `i` if the molecule in `smiles` at
+             index `i` could not be cleaned and was removed.
+     """
+     sm = Standardizer(canon_taut=True)
+
+     clean_mol_mask = list()
+     mols = list()
+     for smile in smiles:
+         mol = Chem.MolFromSmiles(smile)
+         standardized_mol, _ = sm.standardize_mol(mol)
+         is_cleaned = standardized_mol is not None
+         clean_mol_mask.append(is_cleaned)
+         if not is_cleaned:
+             continue
+         can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
+         mols.append(can_mol)
+
+     return mols, clean_mol_mask
+
+
+ def create_ecfp_fps(mols: list[Mol]) -> np.ndarray:
+     """This function creates ECFP fingerprints for a list of molecules.
+
+     Args:
+         mols (list[Mol]): list of molecules
+
+     Returns:
+         np.ndarray: ECFP fingerprints of molecules
+     """
+     ecfps = list()
+
+     for mol in mols:
+         fp_sparse_vec = rdFingerprintGenerator.GetCountFPs(
+             [mol], fpType=rdFingerprintGenerator.MorganFP
+         )[0]
+         fp = np.zeros((0,), np.int8)
+         DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)
+
+         ecfps.append(fp)
+
+     return np.array(ecfps)
+
+
+ def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
+     """This function creates RDKit descriptors for a list of molecules.
+
+     Args:
+         mols (list[Mol]): list of molecules
+
+     Returns:
+         np.ndarray: RDKit descriptors of molecules
+     """
+     rdkit_descriptors = list()
+
+     for mol in mols:
+         descrs = []
+         for _, descr_calc_fn in Descriptors._descList:
+             descrs.append(descr_calc_fn(mol))
+
+         descrs = np.array(descrs)
+         descrs = descrs[USED_200_DESCR]
+         rdkit_descriptors.append(descrs)
+
+     return np.array(rdkit_descriptors)
+
+
+ def create_quantiles(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
+     """Create quantile values for the given features using the column-wise ECDFs.
+
+     Args:
+         raw_features (np.ndarray): values to put into quantiles
+         ecdfs (list): ECDFs to use
+
+     Returns:
+         np.ndarray: computed quantiles
+     """
+     quantiles = np.zeros_like(raw_features)
+
+     for column in range(raw_features.shape[1]):
+         raw_values = raw_features[:, column].reshape(-1)
+         ecdf = ecdfs[column]
+         q = ecdf(raw_values)
+         quantiles[:, column] = q
+
+     return quantiles
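A minimal sketch of the fit/reuse pattern for these preprocessing assets (the paths are the repo defaults; `train_smiles` and `test_smiles` are placeholder lists of SMILES strings):

```python
from src.data import preprocess_molecules

# at training time: fit the ECDFs and scaler on the training SMILES and persist them
train_features, train_mask = preprocess_molecules(
    train_smiles,
    save_ecdf_path="assets/ecdfs.pkl",
    save_scaler_path="assets/scaler.pkl",
)

# at inference time: reuse the persisted assets so features match training
test_features, test_mask = preprocess_molecules(
    test_smiles,
    load_ecdf_path="assets/ecdfs.pkl",
    load_scaler_path="assets/scaler.pkl",
)
```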
src/model.py ADDED
@@ -0,0 +1,80 @@
+ """
+ This file includes an XGBoost model wrapper for Tox21.
+ It holds one XGBClassifier per Tox21 task and makes predictions from
+ precomputed molecule features.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies
+ import os
+ import joblib
+
+ import numpy as np
+ from xgboost import XGBClassifier
+
+ from src.utils import TASKS
+
+
+ # ---------------------------------------------------------------------------------------
+ class Tox21XGBClassifier:
+     """An XGBoost classifier that assigns a toxicity score to a given SMILES string."""
+
+     def __init__(self, seed: int = 42):
+         """Initialize an XGBoost classifier for each of the 12 Tox21 tasks.
+
+         Args:
+             seed (int, optional): seed for XGBoost to ensure reproducibility. Defaults to 42.
+         """
+         self.tasks = TASKS
+         self.model = {
+             task: XGBClassifier(n_estimators=1000, random_state=seed, n_jobs=8)
+             for task in self.tasks
+         }
+
+     def load_model(self, path: str) -> None:
+         """Loads the model from a given path
+
+         Args:
+             path (str): path to model checkpoint
+         """
+         self.model = joblib.load(path)
+
+     def save_model(self, path: str) -> None:
+         """Saves the model to a given path
+
+         Args:
+             path (str): path to save model to
+         """
+         dirname = os.path.dirname(path)
+         if dirname:
+             os.makedirs(dirname, exist_ok=True)
+
+         joblib.dump(self.model, path)
+
+     def fit(self, task: str, input_features: np.ndarray, labels: np.ndarray) -> None:
+         """Train XGBoost for a given task
+
+         Args:
+             task (str): task to train
+             input_features (np.ndarray): training features
+             labels (np.ndarray): training labels
+         """
+         assert task in self.tasks, f"Unknown task: {task}"
+         self.model[task].fit(input_features, labels)
+
+     def predict(self, task: str, features: np.ndarray) -> np.ndarray:
+         """Predicts labels for a given Tox21 target using molecule features
+
+         Args:
+             task (str): the Tox21 target to predict for
+             features (np.ndarray): molecule features used for prediction
+
+         Returns:
+             np.ndarray: predicted probability for positive class
+         """
+         assert task in self.tasks, f"Unknown task: {task}"
+         assert (
+             len(features.shape) == 2
+         ), f"Function expects 2D np.array. Current shape: {features.shape}"
+         preds = self.model[task].predict_proba(features)
+         return preds[:, 1]
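A minimal sketch of the per-task API (assumes the preprocessing assets and the checkpoint written by `src/train.py` exist):

```python
from src.data import preprocess_molecules
from src.model import Tox21XGBClassifier

# features must come from the same preprocessing pipeline used for training
features, _ = preprocess_molecules(
    ["CCO", "c1ccccc1"],
    load_ecdf_path="assets/ecdfs.pkl",
    load_scaler_path="assets/scaler.pkl",
)

model = Tox21XGBClassifier(seed=42)
model.load_model("assets/xgb_alltasks.joblib")

for task in model.tasks:
    print(task, model.predict(task, features))  # positive-class probabilities
```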
src/train.py ADDED
@@ -0,0 +1,92 @@
+ """
+ Script for fitting and saving any preprocessing assets, as well as the fitted XGBoost model
+ """
+
+ import argparse
+
+ import numpy as np
+
+ from tabulate import tabulate
+ from datasets import load_dataset
+ from sklearn.metrics import roc_auc_score
+
+ from src.data import preprocess_molecules
+ from src.model import Tox21XGBClassifier
+ from src.utils import HF_TOKEN
+
+ parser = argparse.ArgumentParser(description="XGBoost training script for the Tox21 dataset")
+
+ parser.add_argument(
+     "--save_path_model",
+     type=str,
+     default="assets/xgb_alltasks.joblib",
+ )
+
+ parser.add_argument(
+     "--path_ecdfs",
+     type=str,
+     default="assets/ecdfs.pkl",
+ )
+
+ parser.add_argument(
+     "--path_scaler",
+     type=str,
+     default="assets/scaler.pkl",
+ )
+
+
+ def main(args):
+     ds = load_dataset("tschouis/tox21", token=HF_TOKEN)
+
+     print("Preprocess train molecules")
+     train_smiles = list(ds["train"]["smiles"])
+
+     train_features, train_mol_mask = preprocess_molecules(
+         train_smiles,
+         save_ecdf_path=args.path_ecdfs,
+         save_scaler_path=args.path_scaler,
+     )
+
+     print("Preprocess validation molecules")
+     val_smiles = list(ds["validation"]["smiles"])
+     val_features, val_mol_mask = preprocess_molecules(
+         val_smiles,
+         load_ecdf_path=args.path_ecdfs,
+         load_scaler_path=args.path_scaler,
+     )
+
+     model = Tox21XGBClassifier(seed=42)
+     print("Start training.")
+     for task in model.tasks:
+         task_labels = ds["train"].to_pandas()[task].to_numpy()
+         task_labels = task_labels[train_mol_mask]
+
+         label_mask = ~np.isnan(task_labels)
+
+         print(f"Fit task {task} using {sum(label_mask)} samples")
+         model.fit(task, train_features[label_mask], task_labels[label_mask].astype(int))
+
+     print(f"Save model under {args.save_path_model}")
+     model.save_model(args.save_path_model)
+
+     print("Evaluate model")
+     results = {}
+     for task in model.tasks:
+         task_labels = ds["validation"].to_pandas()[task].to_numpy()
+         task_labels = task_labels[val_mol_mask]
+
+         label_mask = ~np.isnan(task_labels)
+
+         pred = model.predict(task, val_features[label_mask])
+         results[task] = [
+             roc_auc_score(y_true=task_labels[label_mask].astype(int), y_score=pred)
+         ]
+
+     print("Results:")
+     print(tabulate(results, headers="keys"))
+     print("Average: ", sum([val[0] for val in results.values()]) / len(results))
+
+
+ if __name__ == "__main__":
+     args = parser.parse_args()
+     main(args)
src/utils.py ADDED
@@ -0,0 +1,265 @@
+ ## These MolStandardizer classes are due to Paolo Tosco
+ ## They were taken from the FS-Mol GitHub repository
+ ## (https://github.com/microsoft/FS-Mol/blob/main/fs_mol/preprocessing/utils/
+ ## standardizer.py)
+ ## They ensure that a sequence of standardization operations is applied
+ ## https://gist.github.com/ptosco/7e6b9ab9cc3e44ba0919060beaed198e
+
+ import os
+ import pickle
+
+ from rdkit import Chem
+ from rdkit.Chem.MolStandardize import rdMolStandardize
+
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+
+ TASKS = [
+     "NR-AR",
+     "NR-AR-LBD",
+     "NR-AhR",
+     "NR-Aromatase",
+     "NR-ER",
+     "NR-ER-LBD",
+     "NR-PPAR-gamma",
+     "SR-ARE",
+     "SR-ATAD5",
+     "SR-HSE",
+     "SR-MMP",
+     "SR-p53",
+ ]
+
+ # Indices of the 200 RDKit descriptors used as features (0-16 and 25-207)
+ USED_200_DESCR = [
+     0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
+     10, 11, 12, 13, 14, 15, 16, 25, 26, 27,
+     28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
+     38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+     48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
+     58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
+     68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
+     78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
+     88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
+     98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
+     108, 109, 110, 111, 112, 113, 114, 115, 116, 117,
+     118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+     128, 129, 130, 131, 132, 133, 134, 135, 136, 137,
+     138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
+     148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
+     158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
+     168, 169, 170, 171, 172, 173, 174, 175, 176, 177,
+     178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
+     188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
+     198, 199, 200, 201, 202, 203, 204, 205, 206, 207,
+ ]
+
+
+ class Standardizer:
+     """
+     Simple wrapper class around the RDKit MolStandardize facilities.
+     """
+
+     DEFAULT_CANON_TAUT = False
+     DEFAULT_METAL_DISCONNECT = False
+     MAX_TAUTOMERS = 100
+     MAX_TRANSFORMS = 100
+     MAX_RESTARTS = 200
+     PREFER_ORGANIC = True
+
+     def __init__(
+         self,
+         metal_disconnect=None,
+         canon_taut=None,
+     ):
+         """
+         Constructor.
+         All parameters are optional.
+         :param metal_disconnect: if True, metallorganic complexes are
+                                  disconnected
+         :param canon_taut: if True, molecules are converted to their
+                            canonical tautomer
+         """
+         super().__init__()
+         if metal_disconnect is None:
+             metal_disconnect = self.DEFAULT_METAL_DISCONNECT
+         if canon_taut is None:
+             canon_taut = self.DEFAULT_CANON_TAUT
+         self._canon_taut = canon_taut
+         self._metal_disconnect = metal_disconnect
+         self._taut_enumerator = None
+         self._uncharger = None
+         self._lfrag_chooser = None
+         self._metal_disconnector = None
+         self._normalizer = None
+         self._reionizer = None
+         self._params = None
+
+     @property
+     def params(self):
+         """Return the MolStandardize CleanupParameters."""
+         if self._params is None:
+             self._params = rdMolStandardize.CleanupParameters()
+             self._params.maxTautomers = self.MAX_TAUTOMERS
+             self._params.maxTransforms = self.MAX_TRANSFORMS
+             self._params.maxRestarts = self.MAX_RESTARTS
+             self._params.preferOrganic = self.PREFER_ORGANIC
+             self._params.tautomerRemoveSp3Stereo = False
+         return self._params
+
+     @property
+     def canon_taut(self):
+         """Return whether tautomer canonicalization will be done."""
+         return self._canon_taut
+
+     @property
+     def metal_disconnect(self):
+         """Return whether metallorganic complexes will be disconnected."""
+         return self._metal_disconnect
+
+     @property
+     def taut_enumerator(self):
+         """Return the TautomerEnumerator object."""
+         if self._taut_enumerator is None:
+             self._taut_enumerator = rdMolStandardize.TautomerEnumerator(self.params)
+         return self._taut_enumerator
+
+     @property
+     def uncharger(self):
+         """Return the Uncharger object."""
+         if self._uncharger is None:
+             self._uncharger = rdMolStandardize.Uncharger()
+         return self._uncharger
+
+     @property
+     def lfrag_chooser(self):
+         """Return the LargestFragmentChooser object."""
+         if self._lfrag_chooser is None:
+             self._lfrag_chooser = rdMolStandardize.LargestFragmentChooser(
+                 self.params.preferOrganic
+             )
+         return self._lfrag_chooser
+
+     @property
+     def metal_disconnector(self):
+         """Return the MetalDisconnector object."""
+         if self._metal_disconnector is None:
+             self._metal_disconnector = rdMolStandardize.MetalDisconnector()
+         return self._metal_disconnector
+
+     @property
+     def normalizer(self):
+         """Return the Normalizer object."""
+         if self._normalizer is None:
+             self._normalizer = rdMolStandardize.Normalizer(
+                 self.params.normalizationsFile, self.params.maxRestarts
+             )
+         return self._normalizer
+
+     @property
+     def reionizer(self):
+         """Return the Reionizer object."""
+         if self._reionizer is None:
+             self._reionizer = rdMolStandardize.Reionizer(self.params.acidbaseFile)
+         return self._reionizer
+
+     def charge_parent(self, mol_in):
+         """Sequentially apply a series of MolStandardize operations:
+         * MetalDisconnector
+         * Normalizer
+         * Reionizer
+         * LargestFragmentChooser
+         * Uncharger
+         The net result is that a desalted, normalized, neutral
+         molecule with implicit Hs is returned.
+         """
+         params = Chem.RemoveHsParameters()
+         params.removeAndTrackIsotopes = True
+         mol_in = Chem.RemoveHs(mol_in, params, sanitize=False)
+         if self._metal_disconnect:
+             mol_in = self.metal_disconnector.Disconnect(mol_in)
+         normalized = self.normalizer.normalize(mol_in)
+         Chem.SanitizeMol(normalized)
+         normalized = self.reionizer.reionize(normalized)
+         Chem.AssignStereochemistry(normalized)
+         normalized = self.lfrag_chooser.choose(normalized)
+         normalized = self.uncharger.uncharge(normalized)
+         # need this to reassess aromaticity on things like
+         # cyclopentadienyl, tropylium, azolium, etc.
+         Chem.SanitizeMol(normalized)
+         return Chem.RemoveHs(Chem.AddHs(normalized))
+
+     def standardize_mol(self, mol_in):
+         """
+         Standardize a single molecule.
+         :param mol_in: a Chem.Mol
+         :return: * (standardized Chem.Mol, n_taut) tuple
+                    if success. n_taut will be negative if
+                    tautomer enumeration was aborted due
+                    to reaching a limit
+                  * (None, error_msg) if failure
+         This calls self.charge_parent() and, if self._canon_taut
+         is True, runs tautomer canonicalization.
+         """
+         n_tautomers = 0
+         if isinstance(mol_in, Chem.Mol):
+             name = None
+             try:
+                 name = mol_in.GetProp("_Name")
+             except KeyError:
+                 pass
+             if not name:
+                 name = "NONAME"
+         else:
+             error = f"Expected SMILES or Chem.Mol as input, got {str(type(mol_in))}"
+             return None, error
+         try:
+             mol_out = self.charge_parent(mol_in)
+         except Exception as e:
+             error = f"charge_parent FAILED: {str(e).strip()}"
+             return None, error
+         if self._canon_taut:
+             try:
+                 res = self.taut_enumerator.Enumerate(mol_out, False)
+             except TypeError:
+                 # we are still on the pre-2021 RDKit API
+                 res = self.taut_enumerator.Enumerate(mol_out)
+             except Exception as e:
+                 # something else went wrong
+                 error = f"canon_taut FAILED: {str(e).strip()}"
+                 return None, error
+             n_tautomers = len(res)
+             if hasattr(res, "status"):
+                 completed = (
+                     res.status == rdMolStandardize.TautomerEnumeratorStatus.Completed
+                 )
+             else:
+                 # we are still on the pre-2021 RDKit API
+                 completed = len(res) < 1000
+             if not completed:
+                 n_tautomers = -n_tautomers
+             try:
+                 mol_out = self.taut_enumerator.PickCanonical(res)
+             except AttributeError:
+                 # we are still on the pre-2021 RDKit API
+                 mol_out = max(
+                     [(self.taut_enumerator.ScoreTautomer(m), m) for m in res]
+                 )[1]
+             except Exception as e:
+                 # something else went wrong
+                 error = f"canon_taut FAILED: {str(e).strip()}"
+                 return None, error
+         mol_out.SetProp("_Name", name)
+         return mol_out, n_tautomers
+
+
+ def load_pickle(path: str):
+     with open(path, "rb") as file:
+         content = pickle.load(file)
+     return content
+
+
+ def write_pickle(path: str, obj: object):
+     # create the parent directory if it does not exist yet
+     dirname = os.path.dirname(path)
+     if dirname:
+         os.makedirs(dirname, exist_ok=True)
+     with open(path, "wb") as file:
+         pickle.dump(obj, file)
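A minimal sketch of using the `Standardizer` on its own, mirroring how `src/data.py` calls it ("CCO" is an arbitrary example molecule):

```python
from rdkit import Chem
from src.utils import Standardizer

sm = Standardizer(canon_taut=True)
mol, n_taut = sm.standardize_mol(Chem.MolFromSmiles("CCO"))
if mol is not None:
    print(Chem.MolToSmiles(mol))  # standardized, canonical SMILES
else:
    print(f"standardization failed: {n_taut}")  # on failure the second value is an error message
```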