# Last commit 97697e0 (author: antoniaebner): adapt load/saving, preprocessing, app, readme, modelcard
"""
Script for fitting and saving the preprocessing assets and the random-forest (RF) model for the Tox21 dataset.
"""
import os
import json
import random
import logging
import argparse
import joblib
import numpy as np
from datetime import datetime
from src.model import Tox21RFClassifier
from src.preprocess import FeaturePreprocessor
from src.utils import create_dir, normalize_config
# Command-line interface: a single option pointing at the JSON config file.
parser = argparse.ArgumentParser(
    description="RF Training script for Tox21 dataset",
)
parser.add_argument("--config", default="config/config.json", type=str)
def _setup_logging(log_folder, timestamp):
    """Send INFO-level logs to the console and to a timestamped file.

    The file is ``<log_folder>/<script_name>_<timestamp>.log``, where
    ``script_name`` is this file's base name without extension. ``log_folder``
    must already exist (the ``__main__`` guard calls ``create_dir`` for it).
    """
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    log_path = os.path.join(log_folder, f"{script_name}_{timestamp}.log")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.FileHandler(log_path), logging.StreamHandler()],
    )


def _load_clean_split(path):
    """Load an ``.npz`` split, keeping only rows flagged in ``clean_mol_mask``.

    Every array in the archive (the mask included) is indexed with the mask,
    so the returned arrays stay row-aligned. The archive is closed on return.
    """
    with np.load(path) as npz:
        is_clean = npz["clean_mol_mask"]
        return {descr: npz[descr][is_clean] for descr in npz.files}


def _select_training_data(train_data, val_data, descriptors, merge_train_val):
    """Return ``(data, labels)`` used for fitting.

    ``data`` maps each name in ``descriptors`` to its feature array. When
    ``merge_train_val`` is truthy, validation rows are appended after the
    training rows (axis 0); otherwise only the training rows are used.
    """
    if not merge_train_val:
        data = {descr: train_data[descr] for descr in descriptors}
        return data, train_data["labels"]
    data = {
        descr: np.concatenate([train_data[descr], val_data[descr]], axis=0)
        for descr in descriptors
    }
    labels = np.concatenate([train_data["labels"], val_data["labels"]], axis=0)
    return data, labels


def main(config):
    """Fit the feature preprocessor and the per-task RF model, then save them.

    Args:
        config: normalized configuration dict. Keys read here: ``log_folder``,
            ``model_config``, ``seed``, ``data_folder``, ``merge_train_val``,
            ``descriptors``, ``ckpt_path``, ``feature_selection``,
            ``feature_quantilization``, ``max_samples``, ``scaler``,
            ``debug`` and ``preprocessor_path``.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    _setup_logging(config["log_folder"], timestamp)
    logger = logging.getLogger(__name__)

    logger.info("Config: %s", config)
    # One "key: value" line per model hyper-parameter, under a single header.
    model_config_repr = "\n".join(
        f"{key}: {val}" for key, val in config["model_config"].items()
    )
    logger.info("Model config: \n%s", model_config_repr)

    # seeding
    random.seed(config["seed"])
    np.random.seed(config["seed"])

    # Load both splits, dropping unsanitized molecules.
    train_data = _load_clean_split(
        os.path.join(config["data_folder"], "tox21_train_cv4.npz")
    )
    val_data = _load_clean_split(
        os.path.join(config["data_folder"], "tox21_validation_cv4.npz")
    )
    data, labels = _select_training_data(
        train_data, val_data, config["descriptors"], config["merge_train_val"]
    )

    if config["ckpt_path"]:
        logger.info(
            "Fitted RandomForestClassifier will be saved as: %s",
            config["ckpt_path"],
        )
    else:
        logger.info("Fitted RandomForestClassifier will NOT be saved.")

    model = Tox21RFClassifier(seed=config["seed"], config=config["model_config"])

    # The preprocessor is fitted once on all selected rows; each task below
    # only transforms its own labelled subset.
    preprocessor = FeaturePreprocessor(
        feature_selection_config=config["feature_selection"],
        feature_quantilization_config=config["feature_quantilization"],
        descriptors=config["descriptors"],
        max_samples=config["max_samples"],
        scaler=config["scaler"],
    )
    preprocessor.fit(data)

    logger.info("Start training.")
    for i, task in enumerate(model.tasks):
        task_labels = labels[:, i]
        # Rows whose label for this task is NaN are excluded from its fit.
        label_mask = ~np.isnan(task_labels)
        logger.info("Fit task %s using %d samples", task, int(label_mask.sum()))
        task_data = preprocessor.transform(
            {key: val[label_mask] for key, val in data.items()}
        )
        model.fit(task, task_data, task_labels[label_mask].astype(int))
        if config["debug"]:
            # Debug mode: stop after the first task.
            break
    logger.info("Finished training.")

    if config["ckpt_path"]:
        model.save(config["ckpt_path"])
        logger.info("Save model as: %s", config["ckpt_path"])
    if config["preprocessor_path"]:
        joblib.dump(preprocessor.get_state(), config["preprocessor_path"])
        logger.info("Save preprocessor as: %s", config["preprocessor_path"])
if __name__ == "__main__":
    args = parser.parse_args()
    # JSON is UTF-8 by specification; don't rely on the platform's locale
    # encoding when reading the config.
    with open(args.config, "r", encoding="utf-8") as f:
        config = json.load(f)
    config = normalize_config(config)
    # main() opens a log file inside this folder, so it must exist first.
    create_dir(config["log_folder"])
    main(config)