# Spaces: Sleeping  (status text captured from the hosting page; not part of the script)
"""
Script for fitting and saving the preprocessing assets as well as the fitted
random forest (RF) model.
"""
import argparse
import json
import logging
import os
import random
from datetime import datetime

import joblib
import numpy as np

from src.model import Tox21RFClassifier
from src.preprocess import FeaturePreprocessor
from src.utils import create_dir, normalize_config

# Command-line interface: the only option is the path to the JSON config file.
parser = argparse.ArgumentParser(
    description="RF Training script for Tox21 dataset",
)
parser.add_argument(
    "--config",
    type=str,
    default="config/config.json",
    help="Path to the JSON configuration file (default: %(default)s).",
)
def _configure_logging(log_folder):
    """Log INFO records to the console and to a timestamped file in *log_folder*."""
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    log_file = os.path.join(log_folder, f"{script_name}_{timestamp}.log")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    )


def _load_clean_split(path):
    """Load one ``.npz`` split and keep only the rows selected by its mask.

    Every array in the archive is indexed with the archive's
    ``clean_mol_mask`` entry (used to filter out unsanitized molecules).
    The archive is closed before returning; indexing with the mask array
    yields new arrays, so the result does not depend on the open file.
    """
    with np.load(path) as archive:
        is_clean = archive["clean_mol_mask"]
        return {descr: array[is_clean] for descr, array in archive.items()}


def main(config):
    """Fit the feature preprocessor and one RF model per task, then save them.

    Args:
        config: Mapping with the run settings. Keys read here:
            ``log_folder``, ``model_config``, ``seed``, ``data_folder``,
            ``merge_train_val``, ``descriptors``, ``ckpt_path``,
            ``feature_selection``, ``feature_quantilization``,
            ``max_samples``, ``scaler``, ``debug`` and ``preprocessor_path``.

    Side effects:
        Writes a log file into ``config["log_folder"]``; saves the model to
        ``config["ckpt_path"]`` and the preprocessor state to
        ``config["preprocessor_path"]`` when those values are truthy.
    """
    _configure_logging(config["log_folder"])
    logger = logging.getLogger(__name__)

    logger.info(f"Config: {config}")
    # One line per model-config value; the header is emitted exactly once.
    model_config_repr = "\n".join(
        str(val) for val in config["model_config"].values()
    )
    logger.info(f"Model config: \n{model_config_repr}")

    # seeding
    random.seed(config["seed"])
    np.random.seed(config["seed"])

    # Load both splits, filtering out unsanitized molecules.
    train_data = _load_clean_split(
        os.path.join(config["data_folder"], "tox21_train_cv4.npz")
    )
    val_data = _load_clean_split(
        os.path.join(config["data_folder"], "tox21_validation_cv4.npz")
    )

    if config["merge_train_val"]:
        # Fit on train + validation stacked along the sample axis.
        data = {
            descr: np.concatenate([train_data[descr], val_data[descr]], axis=0)
            for descr in config["descriptors"]
        }
        labels = np.concatenate(
            [train_data["labels"], val_data["labels"]], axis=0
        )
    else:
        data = {descr: train_data[descr] for descr in config["descriptors"]}
        labels = train_data["labels"]

    if config["ckpt_path"]:
        logger.info(
            f"Fitted RandomForestClassifier will be saved as: {config['ckpt_path']}"
        )
    else:
        logger.info("Fitted RandomForestClassifier will NOT be saved.")

    model = Tox21RFClassifier(seed=config["seed"], config=config["model_config"])

    # setup processors
    preprocessor = FeaturePreprocessor(
        feature_selection_config=config["feature_selection"],
        feature_quantilization_config=config["feature_quantilization"],
        descriptors=config["descriptors"],
        max_samples=config["max_samples"],
        scaler=config["scaler"],
    )
    # The preprocessor is fitted once on all samples, before per-task masking.
    preprocessor.fit(data)

    logger.info("Start training.")
    for i, task in enumerate(model.tasks):
        # Column i of the label matrix belongs to task i; NaN marks a
        # missing label, so those samples are left out for this task.
        task_labels = labels[:, i]
        label_mask = ~np.isnan(task_labels)
        n_labelled = int(label_mask.sum())
        logger.info(f"Fit task {task} using {n_labelled} samples")

        task_data = {key: val[label_mask] for key, val in data.items()}
        task_labels = task_labels[label_mask].astype(int)
        task_data = preprocessor.transform(task_data)
        model.fit(task, task_data, task_labels)

        if config["debug"]:
            # Debug runs stop after the first task.
            break

    logger.info("Finished training.")

    if config["ckpt_path"]:
        model.save(config["ckpt_path"])
        logger.info(f"Save model as: {config['ckpt_path']}")

    if config["preprocessor_path"]:
        state = preprocessor.get_state()
        joblib.dump(state, config["preprocessor_path"])
        logger.info(f"Save preprocessor as: {config['preprocessor_path']}")
if __name__ == "__main__":
    # Script entry point: read the JSON config named on the command line,
    # normalize it, and run the training routine.
    args = parser.parse_args()
    # Explicit encoding so parsing does not depend on the platform default.
    with open(args.config, "r", encoding="utf-8") as f:
        config = json.load(f)
    config = normalize_config(config)
    # Called before main() because main() opens a log file inside this folder
    # (create_dir presumably creates it when missing; see src.utils).
    create_dir(config["log_folder"])
    main(config)