Spaces:
Sleeping
Sleeping
| """ | |
| This file includes an XGBoost model for Tox21. | |
| As an input it takes a list of SMILES and it outputs a nested dictionary with | |
| SMILES and target names as keys. | |
| """ | |
| # --------------------------------------------------------------------------------------- | |
| # Dependencies | |
| import os | |
| import joblib | |
| import numpy as np | |
| from xgboost import XGBClassifier | |
| from .utils import TASKS | |
| # --------------------------------------------------------------------------------------- | |
| class Tox21XGBClassifier: | |
| """A XGBoost classifier that assigns a toxicity score to a given SMILES string.""" | |
| def __init__(self, seed: int = 42, task_configs: dict | None = None) -> None: | |
| """Initialize an XGBoost classifier for each of the 12 Tox21 tasks. | |
| Args: | |
| seed (int, optional): seed for XGBoost to ensure reproducibility. Defaults to 42. | |
| task_configs (dict | None, optional): dictionary containing task-specific | |
| hyperparameters. If None, default hyperparameters are used for all tasks. | |
| Defaults to None. | |
| """ | |
| self.tasks = TASKS | |
| self.model = { | |
| task: ( | |
| XGBClassifier(random_state=seed, n_jobs=8) | |
| if task_configs is None | |
| else XGBClassifier( | |
| **{ | |
| k: v | |
| for k, v in task_configs[task].items() | |
| if k != "var_threshold" | |
| }, | |
| random_state=seed, | |
| n_jobs=8, | |
| ) | |
| ) | |
| for task in self.tasks | |
| } | |
| self.feature_processors = {} | |
| def load_model(self, dir: str) -> None: | |
| """Loads the model from a given directory | |
| Args: | |
| dir (str): directory to load model from | |
| """ | |
| self.model = joblib.load(os.path.join(dir, "xgb_alltasks.joblib")) | |
| self.feature_processors = joblib.load( | |
| os.path.join(dir, "feature_processors.pkl") | |
| ) | |
| def save_model(self, dir: str) -> None: | |
| """Saves the model to a given directory | |
| Args: | |
| dir (str): directory to save model to | |
| """ | |
| model_path = os.path.join(dir, "xgb_alltasks.joblib") | |
| feature_processor_path = os.path.join(dir, "feature_processors.pkl") | |
| os.makedirs(dir, exist_ok=True) | |
| joblib.dump(self.model, model_path) | |
| joblib.dump(self.feature_processors, feature_processor_path) | |
| def fit( | |
| self, task: str, input_features: np.ndarray, labels: np.ndarray, **kwargs | |
| ) -> None: | |
| """Train XGBoost for a given task | |
| Args: | |
| task (str): task to train | |
| input_features (np.ndarray): training features | |
| labels (np.ndarray): training labels | |
| """ | |
| assert task in self.tasks, f"Unknown task: {task}" | |
| self.model[task].fit(input_features, labels, **kwargs) | |
| def predict(self, task: str, features: np.ndarray) -> np.ndarray: | |
| """Predicts labels for a given Tox21 target using molecule features | |
| Args: | |
| task (str): the Tox21 target to predict for | |
| features (np.ndarray): molecule features used for prediction | |
| Returns: | |
| np.ndarray: predicted probability for positive class | |
| """ | |
| assert task in self.tasks, f"Unknown task: {task}" | |
| assert ( | |
| len(features.shape) == 2 | |
| ), f"Function expects 2D np.array. Current shape: {features.shape}" | |
| preds = self.model[task].predict_proba(features) | |
| return preds[:, 1] | |