""" Script for fitting and saving any preprocessing assets, as well as the fitted RF model """ import os import json import random import logging import argparse import joblib import numpy as np from datetime import datetime from src.model import Tox21RFClassifier from src.preprocess import FeaturePreprocessor from src.utils import create_dir, normalize_config parser = argparse.ArgumentParser(description="RF Training script for Tox21 dataset") parser.add_argument( "--config", type=str, default="config/config.json", ) def main(config): timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") # setup logger logger = logging.getLogger(__name__) script_name = os.path.splitext(os.path.basename(__file__))[0] logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[ logging.FileHandler( os.path.join( config["log_folder"], f"{script_name}_{timestamp}.log", ) ), logging.StreamHandler(), ], ) logger.info(f"Config: {config}") model_config_repr = "Model config: \n" + "\n".join( [str(val) for val in config["model_config"].values()] ) logger.info(f"Model config: \n{model_config_repr}") # seeding random.seed(config["seed"]) np.random.seed(config["seed"]) train_data = np.load(os.path.join(config["data_folder"], "tox21_train_cv4.npz")) val_data = np.load(os.path.join(config["data_folder"], "tox21_validation_cv4.npz")) # filter out unsanitized molecules train_is_clean = train_data["clean_mol_mask"] val_is_clean = val_data["clean_mol_mask"] train_data = {descr: array[train_is_clean] for descr, array in train_data.items()} val_data = {descr: array[val_is_clean] for descr, array in val_data.items()} if config["merge_train_val"]: data = { descr: np.concatenate([train_data[descr], val_data[descr]], axis=0) for descr in config["descriptors"] } labels = np.concatenate([train_data["labels"], val_data["labels"]], axis=0) else: data = {descr: train_data[descr] for descr in config["descriptors"]} labels = train_data["labels"] if config["ckpt_path"]: logger.info( f"Fitted RandomForestClassifier will be saved as: {config['ckpt_path']}" ) else: logger.info("Fitted RandomForestClassifier will NOT be saved.") model = Tox21RFClassifier(seed=config["seed"], config=config["model_config"]) # setup processors preprocessor = FeaturePreprocessor( feature_selection_config=config["feature_selection"], feature_quantilization_config=config["feature_quantilization"], descriptors=config["descriptors"], max_samples=config["max_samples"], scaler=config["scaler"], ) preprocessor.fit(data) logger.info("Start training.") for i, task in enumerate(model.tasks): task_labels = labels[:, i] label_mask = ~np.isnan(task_labels) logger.info(f"Fit task {task} using {sum(label_mask)} samples") task_data = {key: val[label_mask] for key, val in data.items()} task_labels = task_labels[label_mask].astype(int) task_data = preprocessor.transform(task_data) model.fit(task, task_data, task_labels) if config["debug"]: break log_text = f"Finished training." logger.info(log_text) if config["ckpt_path"]: model.save(config["ckpt_path"]) logger.info(f"Save model as: {config['ckpt_path']}") if config["preprocessor_path"]: state = preprocessor.get_state() joblib.dump(state, config["preprocessor_path"]) logger.info(f"Save preprocessor as: {config['preprocessor_path']}") if __name__ == "__main__": args = parser.parse_args() with open(args.config, "r") as f: config = json.load(f) config = normalize_config(config) create_dir(config["log_folder"]) main(config)