Spaces:
Sleeping
Sleeping
| """ | |
| Script for fitting and saving any preprocessing assets, as well as the fitted XGBoost model | |
| """ | |
| import os | |
| import argparse | |
| import numpy as np | |
| from tabulate import tabulate | |
| from sklearn.feature_selection import VarianceThreshold | |
| from sklearn.metrics import roc_auc_score | |
| from sklearn.preprocessing import StandardScaler | |
| from model import Tox21XGBClassifier | |
SEED = 999  # seed passed to the model and used for the early-stopping split
DATA_FOLDER = "data/"  # directory expected to contain tox21_data.npz

# Command-line interface of the training script.
parser = argparse.ArgumentParser(description="XGBoost Training script for Tox21 dataset")
parser.add_argument(
    "--model_dir",
    type=str,
    default="assets",
    help="Directory where the fitted model and its preprocessing assets are saved (default: assets).",
)
def _load_splits(data_path):
    """Load the Tox21 arrays and split them into training and test parts.

    Columns that contain any inf/NaN are zeroed out entirely instead of being
    dropped. A constant column has zero variance, so every task's
    VarianceThreshold (all thresholds are > 0) removes it later, while the
    column indices stay aligned with the raw feature layout.

    Args:
        data_path: path to an ``.npz`` archive with ``features``, ``labels``
            and ``sets`` arrays.

    Returns:
        ``(train_X, train_y, test_X, test_y)``: rows whose ``sets`` entry is
        not ``"test"`` are used for training; ``"test"`` rows are held out for
        evaluation.
    """
    # Context manager closes the underlying archive file handle.
    with np.load(data_path, allow_pickle=True) as full_data:
        features = full_data["features"]
        labels = full_data["labels"]
        sets = full_data["sets"]

    bad_cols = np.any(~np.isfinite(features), axis=0)
    if np.any(bad_cols):
        features[:, bad_cols] = 0.0

    # NOTE(review): the original flagged this as a temporary choice ("should be
    # 'validation'?") -- every non-test row, including any validation rows, is
    # used for training and the test rows are used for the reported scores.
    train_mask = sets != "test"
    test_mask = sets == "test"
    return features[train_mask], labels[train_mask], features[test_mask], labels[test_mask]


def _build_task_configs(train_y):
    """Return the per-task XGBoost configuration with runtime settings filled in.

    The tuned hyperparameters below are merged with fixed training settings
    (histogram trees, early stopping on AUC, CPU) and a per-task
    ``scale_pos_weight`` = (#negatives / #positives) over the labelled
    training rows (positives clamped to at least 1).

    NOTE: the key order of the returned dict defines the label column order
    (task ``i`` <-> ``train_y[:, i]``); it must match the column layout of
    ``labels`` in the data file.
    """
    task_config = {
        "NR-AR": {
            "max_depth": 4,
            "min_child_weight": 1.1005779061921914,
            "gamma": 0.1317988706679324,
            "learning_rate": 0.039645108160965156,
            "subsample": 0.7296241662412439,
            "colsample_bytree": 0.8021365422870282,
            "reg_alpha": 3.3237336705963336e-06,
            "reg_lambda": 0.5602005185114373,
            "colsample_bylevel": 0.6436881915714322,
            "max_bin": 320,
            "grow_policy": "depthwise",
            "var_threshold": 0.007666987709838448,
        },
        "NR-AR-LBD": {
            "max_depth": 4,
            "min_child_weight": 4.1987212703698695,
            "gamma": 1.2762015931613548,
            "learning_rate": 0.15154599977311695,
            "subsample": 0.6695940698634157,
            "colsample_bytree": 0.7739932636137854,
            "reg_alpha": 0.07898626960219088,
            "reg_lambda": 8.571012949754111,
            "colsample_bylevel": 0.9853057670318977,
            "max_bin": 512,
            "grow_policy": "lossguide",
            "var_threshold": 0.00037667540735397795,
        },
        "NR-AhR": {
            "max_depth": 5,
            "min_child_weight": 6.689827023187083,
            "gamma": 0.05246277760115231,
            "learning_rate": 0.04756606141238733,
            "subsample": 0.8679211962117436,
            "colsample_bytree": 0.6095873089337578,
            "reg_alpha": 2.9267916989096844e-05,
            "reg_lambda": 0.16597411475484836,
            "colsample_bylevel": 0.6109587378961451,
            "max_bin": 192,
            "grow_policy": "lossguide",
            "var_threshold": 0.006450426707708987,
        },
        "NR-Aromatase": {
            "max_depth": 3,
            "min_child_weight": 3.2876314247596152,
            "gamma": 0.19699266508924895,
            "learning_rate": 0.05088088932843542,
            "subsample": 0.7865649204014827,
            "colsample_bytree": 0.7251861382401115,
            "reg_alpha": 1.5663141562519894e-05,
            "reg_lambda": 0.8079227014059855,
            "colsample_bylevel": 0.6264563203168154,
            "max_bin": 320,
            "grow_policy": "lossguide",
            "var_threshold": 0.008210794229202779,
        },
        "NR-ER": {
            "max_depth": 4,
            "min_child_weight": 5.780102015649284,
            "gamma": 1.4129142474001934,
            "learning_rate": 0.030962338755374925,
            "subsample": 0.6495287204129598,
            "colsample_bytree": 0.6052286799267346,
            "reg_alpha": 2.350761568396455e-08,
            "reg_lambda": 0.09630529926179951,
            "colsample_bylevel": 0.7431813327243276,
            "max_bin": 384,
            "grow_policy": "lossguide",
            "var_threshold": 0.0023810780862365695,
        },
        "NR-ER-LBD": {
            "max_depth": 5,
            "min_child_weight": 9.173052917805649,
            "gamma": 1.0722539699322629,
            "learning_rate": 0.04237749698413915,
            "subsample": 0.7066072339657229,
            "colsample_bytree": 0.6813795582720684,
            "reg_alpha": 0.00023207537137377197,
            "reg_lambda": 15.088634424806914,
            "colsample_bylevel": 0.7799437417755278,
            "max_bin": 384,
            "grow_policy": "depthwise",
            "var_threshold": 0.0019169350680113165,
        },
        "NR-PPAR-gamma": {
            "max_depth": 6,
            "min_child_weight": 5.174007598815524,
            "gamma": 1.9912192366255241,
            "learning_rate": 0.05540828755212913,
            "subsample": 0.6903953157523113,
            "colsample_bytree": 0.8663027348173384,
            "reg_alpha": 2.083339410970234e-08,
            "reg_lambda": 0.015396790332761562,
            "colsample_bylevel": 0.9751745752733803,
            "max_bin": 320,
            "grow_policy": "lossguide",
            "var_threshold": 0.0029616070252124786,
        },
        "SR-ARE": {
            "max_depth": 7,
            "min_child_weight": 9.1659526731455,
            "gamma": 0.697265411436678,
            "learning_rate": 0.06570769871964029,
            "subsample": 0.9905868520803529,
            "colsample_bytree": 0.9320468198902392,
            "reg_alpha": 0.0015832053017691588,
            "reg_lambda": 0.05920338550334178,
            "colsample_bylevel": 0.9881491817036743,
            "max_bin": 128,
            "grow_policy": "lossguide",
            "var_threshold": 0.002817440527458996,
        },
        "SR-ATAD5": {
            "max_depth": 8,
            "min_child_weight": 3.840348891355251,
            "gamma": 1.6154505675458388,
            "learning_rate": 0.13247082849598005,
            "subsample": 0.8051455662822469,
            "colsample_bytree": 0.8812075918541051,
            "reg_alpha": 1.0831755964182738e-08,
            "reg_lambda": 27.095693383578947,
            "colsample_bylevel": 0.636617995280427,
            "max_bin": 256,
            "grow_policy": "depthwise",
            "var_threshold": 0.009669430411280284,
        },
        "SR-HSE": {
            "max_depth": 9,
            "min_child_weight": 6.413184249228777,
            "gamma": 1.033704331418744,
            "learning_rate": 0.05274739499143931,
            "subsample": 0.8865620043291726,
            "colsample_bytree": 0.6816866072800449,
            "reg_alpha": 0.058835365152010946,
            "reg_lambda": 0.020754661410877756,
            "colsample_bylevel": 0.9110208090854688,
            "max_bin": 512,
            "grow_policy": "lossguide",
            "var_threshold": 0.005674926071804129,
        },
        "SR-MMP": {
            "max_depth": 5,
            "min_child_weight": 9.817728618387365,
            "gamma": 1.174192311657815,
            "learning_rate": 0.0469463693712702,
            "subsample": 0.7551958380501903,
            "colsample_bytree": 0.7909988895785574,
            "reg_alpha": 0.00015815798249652454,
            "reg_lambda": 0.07975430070894152,
            "colsample_bylevel": 0.6649592956153568,
            "max_bin": 128,
            "grow_policy": "depthwise",
            "var_threshold": 0.006024127982297082,
        },
        "SR-p53": {
            "max_depth": 8,
            "min_child_weight": 5.038486734836349,
            "gamma": 1.807085258740345,
            "learning_rate": 0.1096533837056875,
            "subsample": 0.71588646279992,
            "colsample_bytree": 0.8086559814485024,
            "reg_alpha": 3.864250735509029e-08,
            "reg_lambda": 0.03548737332001143,
            "colsample_bylevel": 0.7740614694930106,
            "max_bin": 128,
            "grow_policy": "depthwise",
            "var_threshold": 0.008637178477182731,
        },
    }

    for col, config in enumerate(task_config.values()):
        task_labels = train_y[:, col]
        npos = np.nansum(task_labels)
        nneg = np.sum(~np.isnan(task_labels)) - npos
        config.update({
            "tree_method": "hist",
            "n_estimators": 10_000,
            "early_stopping_rounds": 50,
            "eval_metric": "auc",
            # Re-weight positives to counter the class imbalance of each task.
            "scale_pos_weight": nneg / max(npos, 1),
            "device": "cpu",
        })
    return task_config


def _fit_task(model, task, train_X, task_labels, var_threshold):
    """Fit the feature preprocessing and the XGBoost model for one task.

    Rows with a NaN label are dropped. A VarianceThreshold selector and a
    StandardScaler are fitted on the remaining rows and stored in
    ``model.feature_processors[task]``. Then ``model.fit`` is called on a
    random ~90% of those rows, with the other ~10% passed as the
    early-stopping ``eval_set``.
    """
    label_mask = ~np.isnan(task_labels)
    task_data = train_X[label_mask]
    task_labels = task_labels[label_mask].astype(int)

    # Remove low-variance features (including the zeroed inf/NaN columns) and scale.
    selector = VarianceThreshold(threshold=var_threshold)
    task_data = selector.fit_transform(task_data)
    scaler = StandardScaler()
    task_data = scaler.fit_transform(task_data)
    model.feature_processors[task] = {
        "selector": selector,
        "scaler": scaler,
    }

    # A fresh RandomState(SEED) produces exactly the same draws as
    # np.random.seed(SEED) + np.random.rand(...), without mutating the global RNG.
    draws = np.random.RandomState(SEED).rand(task_data.shape[0])
    es_val_mask = draws < 0.1
    es_train_mask = ~es_val_mask

    print(f"Fit task {task} using {np.count_nonzero(label_mask)} samples and {task_data.shape[1]} features")
    model.fit(
        task,
        task_data[es_train_mask],
        task_labels[es_train_mask],
        eval_set=[(task_data[es_val_mask], task_labels[es_val_mask])],
        verbose=False,
    )


def _evaluate_task(model, task, test_X, task_labels):
    """Return the ROC-AUC of ``task`` on the labelled rows of the test split.

    The stored selector and scaler are applied before prediction. Returns NaN
    (with a warning) when the labelled test rows do not contain both classes,
    since ROC-AUC is undefined there and ``roc_auc_score`` would raise, which
    would abort the run before the model is saved.
    """
    label_mask = ~np.isnan(task_labels)
    y_true = task_labels[label_mask].astype(int)
    if np.unique(y_true).size < 2:
        print(f"Warning: test labels of task {task} contain fewer than two classes; AUC set to NaN")
        return float("nan")

    processors = model.feature_processors[task]
    task_data = processors["selector"].transform(test_X[label_mask])
    task_data = processors["scaler"].transform(task_data)
    # NOTE(review): AUC needs scores/probabilities; confirm that
    # Tox21XGBClassifier.predict returns positive-class probabilities, not hard labels.
    pred = model.predict(task, task_data)
    return roc_auc_score(y_true=y_true, y_score=pred)


def main(args):
    """Train one XGBoost model per Tox21 task, save it, and print test ROC-AUCs.

    Args:
        args: parsed CLI namespace; ``args.model_dir`` is the directory passed
            to ``model.save_model``.
    """
    print("Preprocess train molecules")
    data_path = os.path.join(DATA_FOLDER, "tox21_data.npz")
    train_X, train_y, test_X, test_y = _load_splits(data_path)

    task_config = _build_task_configs(train_y)
    # Single task -> label-column mapping shared by config building and
    # training, so class weights and training labels always use the same column.
    label_col = {task: col for col, task in enumerate(task_config)}

    model = Tox21XGBClassifier(seed=SEED, task_configs=task_config)
    print("Start training.")
    results = {}
    for task in model.tasks:
        col = label_col[task]
        _fit_task(model, task, train_X, train_y[:, col], task_config[task]["var_threshold"])
        results[task] = [_evaluate_task(model, task, test_X, test_y[:, col])]

    print(f"Save model under {args.model_dir}")
    model.save_model(args.model_dir)
    print("Results:")
    print(tabulate(results, headers="keys"))
    # nanmean skips tasks whose AUC was undefined on the test split.
    print("Average: ", np.nanmean([scores[0] for scores in results.values()]))
if __name__ == "__main__":
    # Run the training script with the options supplied on the command line.
    main(parser.parse_args())