Spaces:
Sleeping
Sleeping
| """ | |
| This file includes a predict function for Tox21. | |
| As an input it takes a list of SMILES and it outputs a nested dictionary with | |
| SMILES and target names as keys. | |
| """ | |
| # --------------------------------------------------------------------------------------- | |
| # Dependencies | |
| from collections import defaultdict | |
| import numpy as np | |
| from src.model import Tox21XGBClassifier | |
| from src.preprocess import create_descriptors | |
| from src.utils import TASKS | |
| # --------------------------------------------------------------------------------------- | |
def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
    """Apply the classifier to a list of SMILES strings.

    Every molecule flagged as not clean by the preprocessing step receives a
    prediction of 0.5 for each target. If the same SMILES string occurs more
    than once in the input, it maps to a single entry in the result.

    Args:
        smiles_list (list[str]): list of SMILES strings

    Returns:
        dict: nested prediction dictionary, following
            {'<smiles>': {'<target>': <pred>}}. Empty when `smiles_list`
            is empty.
    """
    print(f"Received {len(smiles_list)} SMILES strings")
    if len(smiles_list) == 0:
        # Nothing to featurize or score: skip preprocessing and model loading.
        return {}

    # preprocessing pipeline
    features, is_clean = create_descriptors(smiles_list)
    print(f"Created {features.shape[1]} descriptors for the molecules.")

    # setup model
    model = Tox21XGBClassifier(seed=42)
    model_dir = "assets/"
    model.load_model(model_dir)
    print(f"Loaded model and feature processors from {model_dir}")

    # make predictions
    # A plain dict (not a defaultdict) is returned so that a lookup of an
    # unknown SMILES by the caller raises KeyError instead of inserting {}.
    predictions: dict[str, dict[str, float]] = {}
    for target in TASKS:
        # Each task works on its own copy of the descriptor matrix.
        # NOTE(review): kept defensively; confirm whether the processors
        # can mutate their input before dropping the copy.
        X = features.copy()

        # Start from the 0.5 placeholder; clean molecules are overwritten below.
        preds = np.full(is_clean.shape, 0.5, dtype=np.float64)

        feature_processors = model.feature_processors[target]
        task_features = feature_processors["selector"].transform(X)
        task_features = feature_processors["scaler"].transform(task_features)

        # NOTE(review): assumes model.predict yields one value per clean
        # molecule, i.e. `features` holds rows for clean molecules only.
        preds[is_clean] = model.predict(target, task_features)

        for smiles, pred in zip(smiles_list, preds):
            predictions.setdefault(smiles, {})[target] = float(pred)

    return predictions