"""
Script for fitting and saving any preprocessing assets, as well as the fitted XGBoost model
"""
import os
import argparse

import numpy as np
from tabulate import tabulate
from sklearn.feature_selection import VarianceThreshold
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler

from model import Tox21XGBClassifier

SEED = 999
DATA_FOLDER = "data/"
parser = argparse.ArgumentParser(description="XGBoost training script for the Tox21 dataset")
parser.add_argument(
    "--model_dir",
    type=str,
    default="assets",
    help="Directory in which to save the fitted models and preprocessing assets",
)
def main(args):
    print("Preprocess train molecules")
    data_path = os.path.join(DATA_FOLDER, "tox21_data.npz")
    full_data = np.load(data_path, allow_pickle=True)
    features = full_data["features"]
    labels = full_data["labels"]
    sets = full_data["sets"]

    # Handle inf/NaN features: instead of dropping columns, zero out every
    # affected column so that VarianceThreshold removes them later, keeping
    # column indices aligned between training and evaluation data.
    bad_entries = np.isinf(features) | np.isnan(features)
    bad_cols = np.any(bad_entries, axis=0)
    if np.any(bad_cols):
        features[:, bad_cols] = 0.0

    # Everything not tagged "test" (i.e. train + validation) is used for
    # fitting; the "test" split is the hold-out evaluation set.
    # TODO(tmp fix): revisit whether the "validation" split should be held out.
    train_val_mask = sets != "test"
    train_X = features[train_val_mask]
    train_y = labels[train_val_mask]

    test_mask = sets == "test"
    val_X = features[test_mask]
    val_y = labels[test_mask]
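    # Optional sanity check (a sketch, not part of the original script; assumes
    # the split tags are limited to "train"/"validation"/"test", which the code
    # below does not verify):
    # assert set(np.unique(sets)) <= {"train", "validation", "test"}

    # Per-task hyperparameters, presumably the result of a per-task tuning run:
    # "var_threshold" is consumed by the preprocessing below, while the
    # remaining keys are passed through to XGBoost.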
    task_config = {
        "NR-AR": {
            "max_depth": 4,
            "min_child_weight": 1.1005779061921914,
            "gamma": 0.1317988706679324,
            "learning_rate": 0.039645108160965156,
            "subsample": 0.7296241662412439,
            "colsample_bytree": 0.8021365422870282,
            "reg_alpha": 3.3237336705963336e-06,
            "reg_lambda": 0.5602005185114373,
            "colsample_bylevel": 0.6436881915714322,
            "max_bin": 320,
            "grow_policy": "depthwise",
            "var_threshold": 0.007666987709838448
        },
        "NR-AR-LBD": {
            "max_depth": 4,
            "min_child_weight": 4.1987212703698695,
            "gamma": 1.2762015931613548,
            "learning_rate": 0.15154599977311695,
            "subsample": 0.6695940698634157,
            "colsample_bytree": 0.7739932636137854,
            "reg_alpha": 0.07898626960219088,
            "reg_lambda": 8.571012949754111,
            "colsample_bylevel": 0.9853057670318977,
            "max_bin": 512,
            "grow_policy": "lossguide",
            "var_threshold": 0.00037667540735397795
        },
        "NR-AhR": {
            "max_depth": 5,
            "min_child_weight": 6.689827023187083,
            "gamma": 0.05246277760115231,
            "learning_rate": 0.04756606141238733,
            "subsample": 0.8679211962117436,
            "colsample_bytree": 0.6095873089337578,
            "reg_alpha": 2.9267916989096844e-05,
            "reg_lambda": 0.16597411475484836,
            "colsample_bylevel": 0.6109587378961451,
            "max_bin": 192,
            "grow_policy": "lossguide",
            "var_threshold": 0.006450426707708987
        },
        "NR-Aromatase": {
            "max_depth": 3,
            "min_child_weight": 3.2876314247596152,
            "gamma": 0.19699266508924895,
            "learning_rate": 0.05088088932843542,
            "subsample": 0.7865649204014827,
            "colsample_bytree": 0.7251861382401115,
            "reg_alpha": 1.5663141562519894e-05,
            "reg_lambda": 0.8079227014059855,
            "colsample_bylevel": 0.6264563203168154,
            "max_bin": 320,
            "grow_policy": "lossguide",
            "var_threshold": 0.008210794229202779
        },
        "NR-ER": {
            "max_depth": 4,
            "min_child_weight": 5.780102015649284,
            "gamma": 1.4129142474001934,
            "learning_rate": 0.030962338755374925,
            "subsample": 0.6495287204129598,
            "colsample_bytree": 0.6052286799267346,
            "reg_alpha": 2.350761568396455e-08,
            "reg_lambda": 0.09630529926179951,
            "colsample_bylevel": 0.7431813327243276,
            "max_bin": 384,
            "grow_policy": "lossguide",
            "var_threshold": 0.0023810780862365695
        },
        "NR-ER-LBD": {
            "max_depth": 5,
            "min_child_weight": 9.173052917805649,
            "gamma": 1.0722539699322629,
            "learning_rate": 0.04237749698413915,
            "subsample": 0.7066072339657229,
            "colsample_bytree": 0.6813795582720684,
            "reg_alpha": 0.00023207537137377197,
            "reg_lambda": 15.088634424806914,
            "colsample_bylevel": 0.7799437417755278,
            "max_bin": 384,
            "grow_policy": "depthwise",
            "var_threshold": 0.0019169350680113165
        },
        "NR-PPAR-gamma": {
            "max_depth": 6,
            "min_child_weight": 5.174007598815524,
            "gamma": 1.9912192366255241,
            "learning_rate": 0.05540828755212913,
            "subsample": 0.6903953157523113,
            "colsample_bytree": 0.8663027348173384,
            "reg_alpha": 2.083339410970234e-08,
            "reg_lambda": 0.015396790332761562,
            "colsample_bylevel": 0.9751745752733803,
            "max_bin": 320,
            "grow_policy": "lossguide",
            "var_threshold": 0.0029616070252124786
        },
        "SR-ARE": {
            "max_depth": 7,
            "min_child_weight": 9.1659526731455,
            "gamma": 0.697265411436678,
            "learning_rate": 0.06570769871964029,
            "subsample": 0.9905868520803529,
            "colsample_bytree": 0.9320468198902392,
            "reg_alpha": 0.0015832053017691588,
            "reg_lambda": 0.05920338550334178,
            "colsample_bylevel": 0.9881491817036743,
            "max_bin": 128,
            "grow_policy": "lossguide",
            "var_threshold": 0.002817440527458996
        },
        "SR-ATAD5": {
            "max_depth": 8,
            "min_child_weight": 3.840348891355251,
            "gamma": 1.6154505675458388,
            "learning_rate": 0.13247082849598005,
            "subsample": 0.8051455662822469,
            "colsample_bytree": 0.8812075918541051,
            "reg_alpha": 1.0831755964182738e-08,
            "reg_lambda": 27.095693383578947,
            "colsample_bylevel": 0.636617995280427,
            "max_bin": 256,
            "grow_policy": "depthwise",
            "var_threshold": 0.009669430411280284
        },
        "SR-HSE": {
            "max_depth": 9,
            "min_child_weight": 6.413184249228777,
            "gamma": 1.033704331418744,
            "learning_rate": 0.05274739499143931,
            "subsample": 0.8865620043291726,
            "colsample_bytree": 0.6816866072800449,
            "reg_alpha": 0.058835365152010946,
            "reg_lambda": 0.020754661410877756,
            "colsample_bylevel": 0.9110208090854688,
            "max_bin": 512,
            "grow_policy": "lossguide",
            "var_threshold": 0.005674926071804129
        },
        "SR-MMP": {
            "max_depth": 5,
            "min_child_weight": 9.817728618387365,
            "gamma": 1.174192311657815,
            "learning_rate": 0.0469463693712702,
            "subsample": 0.7551958380501903,
            "colsample_bytree": 0.7909988895785574,
            "reg_alpha": 0.00015815798249652454,
            "reg_lambda": 0.07975430070894152,
            "colsample_bylevel": 0.6649592956153568,
            "max_bin": 128,
            "grow_policy": "depthwise",
            "var_threshold": 0.006024127982297082
        },
        "SR-p53": {
            "max_depth": 8,
            "min_child_weight": 5.038486734836349,
            "gamma": 1.807085258740345,
            "learning_rate": 0.1096533837056875,
            "subsample": 0.71588646279992,
            "colsample_bytree": 0.8086559814485024,
            "reg_alpha": 3.864250735509029e-08,
            "reg_lambda": 0.03548737332001143,
            "colsample_bylevel": 0.7740614694930106,
            "max_bin": 128,
            "grow_policy": "depthwise",
            "var_threshold": 0.008637178477182731
        },
    }
    results = {}

    # Add shared settings and a class-imbalance weight to every task config.
    for i, task in enumerate(task_config):
        npos = np.nansum(train_y[:, i])
        nneg = np.sum(~np.isnan(train_y[:, i])) - npos
        task_config[task].update({
            "tree_method": "hist",
            "n_estimators": 10_000,
            "early_stopping_rounds": 50,
            "eval_metric": "auc",
            "scale_pos_weight": nneg / max(npos, 1),
            "device": "cpu",
        })
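    # Example of the weighting above (illustrative numbers, not from the data):
    # with 300 positives and 5700 negatives for a task, scale_pos_weight is
    # 5700 / 300 = 19, so each positive instance is weighted 19x and the heavy
    # class imbalance typical of Tox21 tasks is compensated per task.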
    model = Tox21XGBClassifier(seed=SEED, task_configs=task_config)

    print("Start training.")
    for i, task in enumerate(model.tasks):
        # Training -----------------------
        # Keep only the molecules that have a measured label for this task.
        task_labels = train_y[:, i]
        label_mask = ~np.isnan(task_labels)
        task_data = train_X[label_mask]
        task_labels = task_labels[label_mask].astype(int)

        # Remove low-variance features, then standardize the remainder.
        var_thresh = VarianceThreshold(threshold=task_config[task]["var_threshold"])
        task_data = var_thresh.fit_transform(task_data)
        scaler = StandardScaler()
        task_data = scaler.fit_transform(task_data)
        model.feature_processors[task] = {
            "selector": var_thresh,
            "scaler": scaler,
        }

        # Split 10% off the training data as an early-stopping validation set.
        np.random.seed(SEED)
        random_numbers = np.random.rand(task_data.shape[0])
        es_val_mask = random_numbers < 0.1
        es_train_mask = random_numbers >= 0.1
        X_es_val, y_es_val = task_data[es_val_mask], task_labels[es_val_mask]
        X_es_train, y_es_train = task_data[es_train_mask], task_labels[es_train_mask]

        print(f"Fit task {task} using {label_mask.sum()} samples and {task_data.shape[1]} features")
        model.fit(task, X_es_train, y_es_train, eval_set=[(X_es_val, y_es_val)], verbose=False)
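        # model.fit presumably wraps XGBClassifier.fit for this task; with
        # early_stopping_rounds=50 in the config, boosting stops once the AUC
        # on eval_set has not improved for 50 rounds, so n_estimators=10_000
        # acts as an upper bound rather than a fixed tree count.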
        # Evaluation -----------------------
        val_task_labels = val_y[:, i]
        val_label_mask = ~np.isnan(val_task_labels)
        val_task_labels = val_task_labels[val_label_mask].astype(int)
        val_task_data = val_X[val_label_mask]
        val_task_data = model.feature_processors[task]["selector"].transform(val_task_data)
        val_task_data = model.feature_processors[task]["scaler"].transform(val_task_data)

        # Evaluate model
        pred = model.predict(task, val_task_data)
        results[task] = [roc_auc_score(y_true=val_task_labels, y_score=pred)]
print(f"Save model under {args.model_dir}")
model.save_model(args.model_dir)
print("Results:")
print(tabulate(results, headers="keys"))
print("Average: ", sum([val[0] for val in results.values()]) / len(results))

if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
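# Example invocation (a sketch; assumes the repo layout implied above, i.e.
# data/tox21_data.npz relative to the working directory and model.py on the
# import path):
#   python src/train.py --model_dir assets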