antoniaebner's picture
refactoring of feature preprocessing
1994acc
raw
history blame
3.91 kB
"""
Script for fitting and saving any preprocessing assets, as well as the fitted RF model
"""
import os
import json
import joblib
import random
import logging
import argparse
import numpy as np
from datetime import datetime
from src.model import Tox21RFClassifier
from src.preprocess import Tox21Preprocessor
from src.utils import create_dir, normalize_config,
# Command-line interface: the only option is the path to the JSON config file
# that drives data loading, preprocessing and model fitting.
parser = argparse.ArgumentParser(description="RF Training script for Tox21 dataset")
parser.add_argument(
    "--config",
    type=str,
    default="config/config.json",
    help="Path to the JSON config file (default: %(default)s)",
)
def _load_clean_split(path):
    """Load one ``.npz`` split and keep only the entries selected by its mask.

    Every array stored in the archive (including the mask itself) is indexed
    with the archive's ``clean_mol_mask`` array.

    Args:
        path: Path of the ``.npz`` file to read.

    Returns:
        dict mapping each array name in the archive to its filtered array.
    """
    # The context manager closes the archive's file handle as soon as all
    # arrays have been read, instead of leaving it to garbage collection.
    with np.load(path) as archive:
        is_clean = archive["clean_mol_mask"]
        return {descr: array[is_clean] for descr, array in archive.items()}


def main(config):
    """Fit the preprocessor and the per-task RF models, optionally saving both.

    Args:
        config: Mapping providing the keys read here: ``log_folder``,
            ``model_configs``, ``seed``, ``data_folder``, ``descriptors``,
            ``ckpt_path``, ``feature_selection``, ``feature_quantilization``,
            ``max_samples``, ``scaler`` and ``debug``. A falsy ``ckpt_path``
            disables saving; a truthy ``debug`` stops after the first task.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # setup logger: one timestamped file per run plus console output
    logger = logging.getLogger(__name__)
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    log_file = os.path.join(config["log_folder"], f"{script_name}_{timestamp}.log")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    )

    logger.info(f"Config: {config}")
    # The "Model configs" header is added only once, in the log call below.
    model_configs_repr = "\n".join(
        str(val) for val in config["model_configs"].values()
    )
    logger.info(f"Model configs: \n{model_configs_repr}")

    # seeding
    random.seed(config["seed"])
    np.random.seed(config["seed"])

    # load both splits, filtering out unsanitized molecules
    train_data = _load_clean_split(
        os.path.join(config["data_folder"], "tox21_train_cv4.npz")
    )
    val_data = _load_clean_split(
        os.path.join(config["data_folder"], "tox21_validation_cv4.npz")
    )

    # combine datasets: the final model is fitted on train + validation
    data = {
        descr: np.concatenate([train_data[descr], val_data[descr]], axis=0)
        for descr in config["descriptors"]
    }
    labels = np.concatenate([train_data["labels"], val_data["labels"]], axis=0)

    if config["ckpt_path"]:
        logger.info(
            f"Fitted RandomForestClassifier will be saved as: {config['ckpt_path']}"
        )
        # Create the target directory up front so that a missing folder cannot
        # make the final dump fail after all the training work is done.
        ckpt_dir = os.path.dirname(config["ckpt_path"])
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
    else:
        logger.info("Fitted RandomForestClassifier will NOT be saved.")

    model = Tox21RFClassifier(seed=config["seed"], config=config["model_configs"])

    # setup processors: fitted once on all combined samples
    preprocessor = Tox21Preprocessor(
        feature_selection_config=config["feature_selection"],
        feature_quantilization_config=config["feature_quantilization"],
        descriptors=config["descriptors"],
        max_samples=config["max_samples"],
        scaler=config["scaler"],
    )
    preprocessor.fit(data)

    logger.info("Start training.")
    for i, task in enumerate(model.tasks):
        # Column i of the label matrix belongs to task i; NaN marks samples
        # without a label for this task, which are excluded from its fit.
        task_labels = labels[:, i]
        label_mask = ~np.isnan(task_labels)
        logger.info(f"Fit task {task} using {int(label_mask.sum())} samples")

        task_data = {key: val[label_mask] for key, val in data.items()}
        task_labels = task_labels[label_mask].astype(int)

        task_data = preprocessor.transform(task_data)
        model.fit(task, task_data, task_labels)

        if config["debug"]:
            # debug mode: fit only the first task
            break

    logger.info("Finished training.")

    if config["ckpt_path"]:
        # Persist the exported state of both components in one joblib file.
        ckpt = {
            "preprocessor": preprocessor.__getstate__(),
            "models": model.get_state(),
        }
        joblib.dump(ckpt, config["ckpt_path"])
        logger.info(f"Save model as: {config['ckpt_path']}")
# Script entry point: parse the CLI, load and normalize the config, then train.
if __name__ == "__main__":
    args = parser.parse_args()

    # Read the JSON config as UTF-8 explicitly rather than relying on the
    # platform's default text encoding.
    with open(args.config, "r", encoding="utf-8") as f:
        config = json.load(f)
    config = normalize_config(config)

    # main() attaches a FileHandler inside the log folder, so the folder is
    # prepared first (presumably create_dir creates it if missing).
    create_dir(config["log_folder"])

    main(config)