"""XGBoost model for Tox21.

Provides ``Tox21XGBClassifier``, which keeps one XGBoost binary classifier per
Tox21 task. The classifiers are trained and queried on precomputed molecule
feature arrays; converting SMILES strings into features is not done in this
module.
"""

# ---------------------------------------------------------------------------------------
# Dependencies
import os

import joblib
import numpy as np
from xgboost import XGBClassifier

from .utils import TASKS


# ---------------------------------------------------------------------------------------
class Tox21XGBClassifier:
    """A collection of XGBoost classifiers, one per Tox21 task.

    Each classifier is fitted and queried independently via ``fit`` and
    ``predict`` using the task name as key. ``feature_processors`` is an
    initially empty dict that is persisted alongside the models; this class
    never fills it itself (presumably callers store per-task preprocessing
    objects in it -- verify against callers).
    """

    def __init__(
        self,
        seed: int = 42,
        task_configs: dict | None = None,
        n_jobs: int = 8,
    ) -> None:
        """Initialize an XGBoost classifier for each task in ``TASKS``.

        Args:
            seed (int, optional): seed passed to XGBoost as ``random_state`` to
                ensure reproducibility. Defaults to 42.
            task_configs (dict | None, optional): dictionary mapping each task
                name to a dict of task-specific hyperparameters. If None,
                default hyperparameters are used for all tasks. Defaults to
                None.
            n_jobs (int, optional): number of parallel threads handed to every
                classifier. Defaults to 8.

        Raises:
            KeyError: if ``task_configs`` is given but lacks an entry for one
                of the tasks.
        """
        self.tasks = TASKS
        self.model = {}
        for task in self.tasks:
            if task_configs is None:
                params = {}
            else:
                # "var_threshold" is stripped before the config reaches
                # XGBClassifier (presumably it configures feature
                # preprocessing elsewhere -- confirm against the caller).
                params = {
                    key: value
                    for key, value in task_configs[task].items()
                    if key != "var_threshold"
                }
            self.model[task] = XGBClassifier(
                **params, random_state=seed, n_jobs=n_jobs
            )
        self.feature_processors: dict = {}

    def load_model(self, dir: str) -> None:
        """Load the models and feature processors from a given directory.

        Reads ``xgb_alltasks.joblib`` and ``feature_processors.pkl`` from
        ``dir`` and replaces ``self.model`` and ``self.feature_processors``
        with their contents. ``self.tasks`` is left untouched.

        Args:
            dir (str): directory to load model from
        """
        # NOTE(security): joblib.load unpickles its input and can execute
        # arbitrary code; only load artifacts from a trusted location.
        self.model = joblib.load(os.path.join(dir, "xgb_alltasks.joblib"))
        self.feature_processors = joblib.load(
            os.path.join(dir, "feature_processors.pkl")
        )

    def save_model(self, dir: str) -> None:
        """Save the models and feature processors to a given directory.

        Writes ``xgb_alltasks.joblib`` and ``feature_processors.pkl`` into
        ``dir``, creating the directory if it does not exist yet.

        Args:
            dir (str): directory to save model to
        """
        os.makedirs(dir, exist_ok=True)
        joblib.dump(self.model, os.path.join(dir, "xgb_alltasks.joblib"))
        joblib.dump(
            self.feature_processors,
            os.path.join(dir, "feature_processors.pkl"),
        )

    def fit(
        self, task: str, input_features: np.ndarray, labels: np.ndarray, **kwargs
    ) -> None:
        """Train the XGBoost classifier of a given task.

        Args:
            task (str): task to train
            input_features (np.ndarray): training features
            labels (np.ndarray): training labels
            **kwargs: forwarded unchanged to ``XGBClassifier.fit``

        Raises:
            ValueError: if ``task`` is not one of ``self.tasks``.
        """
        # Explicit raise instead of assert so the check survives `python -O`.
        if task not in self.tasks:
            raise ValueError(f"Unknown task: {task}")
        self.model[task].fit(input_features, labels, **kwargs)

    def predict(self, task: str, features: np.ndarray) -> np.ndarray:
        """Predict positive-class probabilities for a given Tox21 target.

        Args:
            task (str): the Tox21 target to predict for
            features (np.ndarray): 2D array of molecule features used for
                prediction

        Returns:
            np.ndarray: predicted probability for positive class (column 1 of
                the classifier's ``predict_proba`` output)

        Raises:
            ValueError: if ``task`` is not one of ``self.tasks`` or
                ``features`` is not a 2D array.
        """
        # Explicit raises instead of asserts so the checks survive `python -O`.
        if task not in self.tasks:
            raise ValueError(f"Unknown task: {task}")
        if features.ndim != 2:
            raise ValueError(
                f"Function expects 2D np.array. Current shape: {features.shape}"
            )
        preds = self.model[task].predict_proba(features)
        return preds[:, 1]