File size: 5,199 Bytes
5a87d8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import sys
sys.path.append("./BranchSBM")
import yaml
import string
import secrets
import os
import torch
import wandb
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from torchdyn.core import NeuralODE
from utils import plot_images_trajectory
from networks.utils import flow_model_torch_wrapper


def load_config(path):
    """Read a YAML configuration file and return its parsed content.

    Args:
        path: Filesystem path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file content (for a
        typical config file this is a ``dict``; an empty file yields ``None``).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding (which ``open`` would otherwise use).
    with open(path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    return config


def merge_config(args, config_updates):
    """Overwrite attributes of ``args`` with the entries of ``config_updates``.

    Entries are applied one at a time in the mapping's iteration order. Only
    attributes that already exist on ``args`` may be overridden; the first key
    that does not exist aborts the merge (entries applied before it stay
    applied).

    Args:
        args: Object holding the current settings as attributes.
        config_updates: Mapping of attribute name to new value.

    Returns:
        The same ``args`` object, mutated in place.

    Raises:
        ValueError: If a key in ``config_updates`` is not an attribute of
            ``args``.
    """
    for name in config_updates:
        if hasattr(args, name):
            setattr(args, name, config_updates[name])
            continue
        message = f"Unknown configuration parameter '{name}' found in the config file."
        raise ValueError(message)
    return args


def generate_group_string(length=16):
    """Return a random identifier made of ASCII letters and digits.

    Characters are drawn independently with ``secrets.choice``.

    Args:
        length: Number of characters to generate (defaults to 16).

    Returns:
        A string of ``length`` characters from ``[A-Za-z0-9]``.
    """
    pool = string.ascii_letters + string.digits
    picked = [secrets.choice(pool) for _ in range(length)]
    return "".join(picked)


def dataset_name2datapath(dataset_name, working_dir):
    """Map a dataset name to the path of its data file.

    Args:
        dataset_name: One of the recognised dataset identifiers listed in the
            table below.
        working_dir: First component handed to ``os.path.join``.

    Returns:
        The joined path of the data file for ``dataset_name``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the recognised names.
    """
    # NOTE(review): this directory is an absolute path, and os.path.join
    # discards every component that precedes an absolute one, so on POSIX
    # `working_dir` does not appear in the result. Confirm that is intended.
    data_root = "/raid/st512/branchsbm/data"

    # (accepted dataset names, data file name)
    files_by_dataset = (
        (("lidar", "lidarsingle"), "rainier2-thin.las"),
        (("mouse",), "mouse_hematopoiesis.csv"),
        (
            (
                "clonidine50D",
                "clonidine100D",
                "clonidine150D",
                "clonidine50Dsingle",
                "clonidine100Dsingle",
                "clonidine150Dsingle",
            ),
            "pca_and_leiden_labels.csv",
        ),
        (
            ("trametinib", "trametinibsingle"),
            "Trametinib_5.0uM_pca_and_leidenumap_labels.csv",
        ),
    )

    for accepted_names, filename in files_by_dataset:
        if dataset_name in accepted_names:
            return os.path.join(working_dir, data_root, filename)
    raise ValueError("Dataset not recognized")


def create_callbacks(args, phase, data_type, run_id, datamodule=None):
    """Build the checkpoint and early-stopping callbacks for one training phase.

    Checkpoints are written to
    ``<args.working_dir>/checkpoints/<data_type>/<run_id>/<phase>_model``.

    Args:
        args: Settings object; ``working_dir`` is always read, plus
            ``patience_geopath`` for the "geopath" phase or ``patience`` for
            the other phases.
        phase: One of "geopath", "flow", "growth" or "joint".
        data_type: Sub-directory name used in the checkpoint path.
        run_id: Run identifier; converted with ``str`` for the checkpoint path.
        datamodule: Accepted but not used by this function.

    Returns:
        A two-element list ``[ModelCheckpoint, EarlyStopping]``.

    Raises:
        ValueError: If ``phase`` is not one of the supported names.
    """
    checkpoint_dir = os.path.join(
        args.working_dir, "checkpoints", data_type, str(run_id), f"{phase}_model"
    )

    # (phase, early-stopping metric, patience attribute on args,
    #  whether the checkpoint callback is also given the metric as `monitor`)
    # NOTE(review): only the geopath checkpoint receives `monitor`; the other
    # phases pass mode="min" without a monitored metric. Confirm that is
    # intended.
    phase_specs = (
        ("geopath", "BranchPathNet/val_loss_geopath", "patience_geopath", True),
        ("flow", "FlowNet/val_loss_cfm", "patience", False),
        ("growth", "GrowthNet/val_loss", "patience", False),
        ("joint", "JointTrain/val_loss", "patience", False),
    )

    for known_phase, metric, patience_attr, checkpoint_uses_metric in phase_specs:
        if phase == known_phase:
            break
    else:
        raise ValueError("Unknown phase")

    stopper = EarlyStopping(
        monitor=metric,
        patience=getattr(args, patience_attr),
        mode="min",
    )

    checkpoint_options = {"dirpath": checkpoint_dir, "mode": "min", "save_top_k": 1}
    if checkpoint_uses_metric:
        checkpoint_options["monitor"] = metric
    saver = ModelCheckpoint(**checkpoint_options)

    return [saver, stopper]


class PlottingCallback(Callback):
    """Callback that periodically logs flow trajectories to Weights & Biases.

    Every ``plot_interval`` epochs (never at epoch 0) the module's ``flow_net``
    is wrapped in a ``NeuralODE``, integrated from the first 15 samples of the
    datamodule's ``train_x0`` and ``val_x0``, and the resulting figures are
    logged as ``wandb.Image`` objects.
    """

    def __init__(self, plot_interval, datamodule):
        """Store the plotting period (in epochs) and the datamodule to sample from."""
        self.datamodule = datamodule
        self.plot_interval = plot_interval

    def on_train_epoch_end(self, trainer, pl_module):
        """Plot and log trajectories when the current epoch is a plotting epoch.

        ``pl_module.flow_net`` is switched out of training mode on entry and
        back into training mode on exit, whether or not a plot was produced.
        """
        current_epoch = trainer.current_epoch
        flow_net = pl_module.flow_net
        flow_net.train(mode=False)

        is_plot_epoch = current_epoch % self.plot_interval == 0 and current_epoch != 0
        if is_plot_epoch:
            wrapped_flow = flow_model_torch_wrapper(flow_net).to(self.datamodule.device)
            ode = NeuralODE(
                wrapped_flow,
                solver="tsit5",
                sensitivity="adjoint",
                atol=1e-5,
                rtol=1e-5,
            )

            for split in ("train", "val"):
                # Only the first 15 starting points of each split are plotted.
                starts = getattr(self.datamodule, f"{split}_x0")[:15]
                figure = self.trajectory_and_plot(starts, ode, self.datamodule)
                wandb.log({f"Trajectories {split.capitalize()}": wandb.Image(figure)})

        flow_net.train(mode=True)

    def trajectory_and_plot(self, x0, node, datamodule):
        """Integrate ``node`` from ``x0`` and render the trajectories as a figure.

        Args:
            x0: Starting points; at most the first 15 entries are used.
            node: Object exposing ``trajectory(x, t_span=...)``.
            datamodule: Supplies ``device``, ``dim``, ``vae`` and ``process``.

        Returns:
            The value returned by ``plot_images_trajectory``.
        """
        device = datamodule.device
        with torch.no_grad():
            paths = node.trajectory(
                x0[:15].to(device),
                t_span=torch.linspace(0, 1, 100).to(device),
            )

        # Swap the first two axes, then replace the trailing axes with
        # datamodule.dim while keeping those two leading sizes.
        # NOTE(review): presumably (time, batch, ...) -> (batch, time, ...);
        # confirm against NeuralODE.trajectory.
        paths = paths.transpose(0, 1)
        leading_sizes = paths.shape[:2]
        paths = paths.reshape(*leading_sizes, *datamodule.dim)

        return plot_images_trajectory(
            paths.to(device),
            datamodule.vae.to(device),
            datamodule.process,
            num_steps=5,
        )