Spaces:
Sleeping
Sleeping
File size: 4,097 Bytes
3fd3838 db0fcf9 136540f 3fd3838 89c9197 3fd3838 89c9197 3fd3838 db0fcf9 3fd3838 db0fcf9 3fd3838 db0fcf9 1994acc 3fd3838 1994acc 3fd3838 1994acc 97697e0 60782a2 db0fcf9 97697e0 3fd3838 1994acc 60782a2 1994acc 3fd3838 1994acc 3fd3838 60782a2 1994acc 89c9197 1994acc 9b322e1 1994acc 3fd3838 db0fcf9 3fd3838 1994acc 3fd3838 1994acc 3fd3838 1994acc 9fabbe2 3fd3838 1994acc 97697e0 1994acc 3fd3838 60782a2 3fd3838 db0fcf9 1994acc db0fcf9 1994acc 3fd3838 1994acc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
"""
Script for fitting and saving any preprocessing assets, as well as the fitted RF model
"""
import os
import json
import random
import logging
import argparse
import joblib
import numpy as np
from datetime import datetime
from src.model import Tox21RFClassifier
from src.preprocess import FeaturePreprocessor
from src.utils import create_dir, normalize_config
# CLI: a single optional --config pointing at the JSON run configuration.
parser = argparse.ArgumentParser(description="RF Training script for Tox21 dataset")
parser.add_argument(
    "--config",
    type=str,
    default="config/config.json",
    # FIX: original argument had no help text, making --help uninformative.
    help="Path to the JSON configuration file (default: %(default)s).",
)
def main(config):
    """Fit the preprocessing assets and the per-task RF model on Tox21 data.

    Loads the train (and optionally validation) ``.npz`` splits, drops
    molecules that failed sanitization, fits a ``FeaturePreprocessor`` on the
    pooled features, then fits one ``Tox21RFClassifier`` task at a time on the
    label-complete subset. Optionally saves the model checkpoint and the
    preprocessor state.

    Args:
        config: Normalized configuration dict. Keys read here: "seed",
            "log_folder", "data_folder", "descriptors", "merge_train_val",
            "model_config", "feature_selection", "feature_quantilization",
            "max_samples", "scaler", "ckpt_path", "preprocessor_path",
            "debug".
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # setup logger: duplicate output to a timestamped file and to the console
    logger = logging.getLogger(__name__)
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler(
                os.path.join(
                    config["log_folder"],
                    f"{script_name}_{timestamp}.log",
                )
            ),
            logging.StreamHandler(),
        ],
    )
    logger.info(f"Config: {config}")
    # FIX: the original prepended "Model config: \n" both here and inside the
    # logger.info f-string below, duplicating the prefix in every log record.
    model_config_repr = "\n".join(
        str(val) for val in config["model_config"].values()
    )
    logger.info(f"Model config: \n{model_config_repr}")
    # seeding for reproducibility
    random.seed(config["seed"])
    np.random.seed(config["seed"])
    train_data = np.load(os.path.join(config["data_folder"], "tox21_train_cv4.npz"))
    val_data = np.load(os.path.join(config["data_folder"], "tox21_validation_cv4.npz"))
    # filter out unsanitized molecules
    train_is_clean = train_data["clean_mol_mask"]
    val_is_clean = val_data["clean_mol_mask"]
    train_data = {descr: array[train_is_clean] for descr, array in train_data.items()}
    val_data = {descr: array[val_is_clean] for descr, array in val_data.items()}
    if config["merge_train_val"]:
        # train on the train + validation folds pooled together
        data = {
            descr: np.concatenate([train_data[descr], val_data[descr]], axis=0)
            for descr in config["descriptors"]
        }
        labels = np.concatenate([train_data["labels"], val_data["labels"]], axis=0)
    else:
        data = {descr: train_data[descr] for descr in config["descriptors"]}
        labels = train_data["labels"]
    if config["ckpt_path"]:
        logger.info(
            f"Fitted RandomForestClassifier will be saved as: {config['ckpt_path']}"
        )
    else:
        logger.info("Fitted RandomForestClassifier will NOT be saved.")
    model = Tox21RFClassifier(seed=config["seed"], config=config["model_config"])
    # setup processors; the preprocessor is fitted once on the pooled data so
    # every task shares the same feature transform
    preprocessor = FeaturePreprocessor(
        feature_selection_config=config["feature_selection"],
        feature_quantilization_config=config["feature_quantilization"],
        descriptors=config["descriptors"],
        max_samples=config["max_samples"],
        scaler=config["scaler"],
    )
    preprocessor.fit(data)
    logger.info("Start training.")
    for i, task in enumerate(model.tasks):
        task_labels = labels[:, i]
        # NaN labels mark molecules without a measurement for this task
        label_mask = ~np.isnan(task_labels)
        # FIX: use the NumPy reduction instead of Python-level sum() over
        # the boolean array (same count, no per-element Python iteration).
        logger.info(f"Fit task {task} using {label_mask.sum()} samples")
        task_data = {key: val[label_mask] for key, val in data.items()}
        task_labels = task_labels[label_mask].astype(int)
        task_data = preprocessor.transform(task_data)
        model.fit(task, task_data, task_labels)
        if config["debug"]:
            # debug mode: fit only the first task to keep runs fast
            break
    # FIX: dropped the pointless no-placeholder f-string bound to a throwaway
    # local; log the message directly.
    logger.info("Finished training.")
    if config["ckpt_path"]:
        model.save(config["ckpt_path"])
        logger.info(f"Save model as: {config['ckpt_path']}")
    if config["preprocessor_path"]:
        state = preprocessor.get_state()
        joblib.dump(state, config["preprocessor_path"])
        logger.info(f"Save preprocessor as: {config['preprocessor_path']}")
if __name__ == "__main__":
    # Parse CLI args, load and normalize the JSON config, make sure the log
    # directory exists, then run training.
    cli_args = parser.parse_args()
    with open(cli_args.config, "r") as cfg_file:
        raw_config = json.load(cfg_file)
    run_config = normalize_config(raw_config)
    create_dir(run_config["log_folder"])
    main(run_config)
|