File size: 4,097 Bytes
3fd3838
 
 
 
 
db0fcf9
136540f
3fd3838
 
 
89c9197
3fd3838
 
 
 
89c9197
 
3fd3838
 
 
 
db0fcf9
3fd3838
db0fcf9
3fd3838
 
db0fcf9
1994acc
3fd3838
 
 
 
 
 
 
 
 
 
 
1994acc
3fd3838
 
 
 
 
 
 
1994acc
97697e0
60782a2
db0fcf9
97697e0
3fd3838
 
1994acc
 
 
 
 
 
 
 
 
 
 
 
60782a2
 
 
 
 
 
 
 
 
1994acc
 
3fd3838
1994acc
3fd3838
 
 
 
60782a2
1994acc
 
89c9197
1994acc
 
 
 
 
9b322e1
1994acc
3fd3838
 
 
 
 
db0fcf9
3fd3838
1994acc
3fd3838
 
1994acc
3fd3838
1994acc
9fabbe2
3fd3838
 
 
 
1994acc
97697e0
1994acc
3fd3838
60782a2
 
 
 
 
3fd3838
 
 
 
db0fcf9
1994acc
 
db0fcf9
1994acc
3fd3838
1994acc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Script for fitting and saving any preprocessing assets, as well as the fitted RF model
"""

import os
import json
import random
import logging
import argparse

import joblib
import numpy as np
from datetime import datetime

from src.model import Tox21RFClassifier
from src.preprocess import FeaturePreprocessor
from src.utils import create_dir, normalize_config

# Command-line interface: a single optional flag pointing at the JSON config.
parser = argparse.ArgumentParser(
    description="RF Training script for Tox21 dataset",
)
parser.add_argument("--config", default="config/config.json", type=str)


def _load_clean_split(path):
    """Load one ``.npz`` split and keep only the rows flagged as clean.

    Every array stored in the archive is indexed with the archive's own
    ``clean_mol_mask`` entry and collected into a plain dict, so the
    archive can be closed before returning.

    Args:
        path: Location of the ``.npz`` file.

    Returns:
        dict mapping each archive key to its masked array.
    """
    # np.load on an .npz keeps the underlying file open until it is closed,
    # so use it as a context manager instead of leaking the handle.
    with np.load(path) as archive:
        is_clean = archive["clean_mol_mask"]
        return {descr: array[is_clean] for descr, array in archive.items()}


def main(config):
    """Fit the feature preprocessor and the per-task RF model, then save them.

    Args:
        config: Mapping with the run settings. Keys read here:
            ``log_folder``, ``model_config``, ``seed``, ``data_folder``,
            ``merge_train_val``, ``descriptors``, ``ckpt_path``,
            ``feature_selection``, ``feature_quantilization``,
            ``max_samples``, ``scaler``, ``debug`` and ``preprocessor_path``.
            ``config["log_folder"]`` must already exist, because the log
            file is opened inside it.

    Returns:
        None. The model is saved to ``config["ckpt_path"]`` and the
        preprocessor state to ``config["preprocessor_path"]`` when those
        values are truthy.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # setup logger: one timestamped file per run plus console output
    logger = logging.getLogger(__name__)
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    log_file = os.path.join(config["log_folder"], f"{script_name}_{timestamp}.log")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    )

    logger.info(f"Config: {config}")
    # One model-config value per line; the "Model config:" prefix is added
    # exactly once, by the log call below.
    model_config_repr = "\n".join(
        str(val) for val in config["model_config"].values()
    )
    logger.info(f"Model config: \n{model_config_repr}")

    # seeding (Python's `random` module and NumPy's global RNG)
    random.seed(config["seed"])
    np.random.seed(config["seed"])

    # load both splits with unsanitized molecules already filtered out
    train_data = _load_clean_split(
        os.path.join(config["data_folder"], "tox21_train_cv4.npz")
    )
    val_data = _load_clean_split(
        os.path.join(config["data_folder"], "tox21_validation_cv4.npz")
    )

    if config["merge_train_val"]:
        # fit on train + validation stacked along the sample axis
        data = {
            descr: np.concatenate([train_data[descr], val_data[descr]], axis=0)
            for descr in config["descriptors"]
        }
        labels = np.concatenate([train_data["labels"], val_data["labels"]], axis=0)
    else:
        data = {descr: train_data[descr] for descr in config["descriptors"]}
        labels = train_data["labels"]

    if config["ckpt_path"]:
        logger.info(
            f"Fitted RandomForestClassifier will be saved as: {config['ckpt_path']}"
        )
    else:
        logger.info("Fitted RandomForestClassifier will NOT be saved.")

    model = Tox21RFClassifier(seed=config["seed"], config=config["model_config"])

    # setup processors; the preprocessor is fitted once on all selected samples
    preprocessor = FeaturePreprocessor(
        feature_selection_config=config["feature_selection"],
        feature_quantilization_config=config["feature_quantilization"],
        descriptors=config["descriptors"],
        max_samples=config["max_samples"],
        scaler=config["scaler"],
    )
    preprocessor.fit(data)

    logger.info("Start training.")
    for i, task in enumerate(model.tasks):
        # column i of the label matrix belongs to the i-th task; NaN marks
        # samples without a label for that task
        task_labels = labels[:, i]
        label_mask = ~np.isnan(task_labels)
        n_labelled = int(np.count_nonzero(label_mask))
        logger.info(f"Fit task {task} using {n_labelled} samples")

        task_data = {key: val[label_mask] for key, val in data.items()}
        task_labels = task_labels[label_mask].astype(int)

        task_data = preprocessor.transform(task_data)
        model.fit(task, task_data, task_labels)
        if config["debug"]:
            # debug runs stop after fitting the first task
            break

    logger.info("Finished training.")

    if config["ckpt_path"]:
        model.save(config["ckpt_path"])
        logger.info(f"Save model as: {config['ckpt_path']}")

    if config["preprocessor_path"]:
        state = preprocessor.get_state()
        joblib.dump(state, config["preprocessor_path"])
        logger.info(f"Save preprocessor as: {config['preprocessor_path']}")


if __name__ == "__main__":
    # Script entry point: read the config file named on the command line,
    # normalize it, and start training.
    args = parser.parse_args()

    # JSON is UTF-8 by specification; pin the encoding so decoding does not
    # depend on the platform's locale default.
    with open(args.config, "r", encoding="utf-8") as f:
        config = json.load(f)
    config = normalize_config(config)

    # main() opens its log file inside this folder, so it is prepared first
    # (create_dir is presumably responsible for making it exist).
    create_dir(config["log_folder"])

    main(config)