# File size: 5,199 Bytes
# 5a87d8d
import sys
sys.path.append("./BranchSBM")
import yaml
import string
import secrets
import os
import torch
import wandb
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from torchdyn.core import NeuralODE
from utils import plot_images_trajectory
from networks.utils import flow_model_torch_wrapper
def load_config(path):
    """Read a YAML configuration file and return its parsed content.

    Args:
        path: Filesystem path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the document
        (typically a ``dict``; ``None`` when the file is empty).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the
    # platform's locale encoding.
    with open(path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
def merge_config(args, config_updates):
    """Overwrite attributes of ``args`` with values from a config mapping.

    Every key in ``config_updates`` must already exist as an attribute on
    ``args``. All keys are validated before anything is assigned, so a failed
    call leaves ``args`` untouched.

    Args:
        args: Object holding the current settings as attributes (for example
            an ``argparse.Namespace``).
        config_updates: Mapping of attribute name to new value. ``None``
            (what an empty YAML file parses to) means "no overrides".

    Returns:
        The same ``args`` object, updated in place.

    Raises:
        ValueError: If a key does not name an existing attribute of ``args``.
    """
    if config_updates is None:
        return args

    # Validate first so an unknown key cannot leave ``args`` half-updated.
    for key in config_updates:
        if not hasattr(args, key):
            raise ValueError(
                f"Unknown configuration parameter '{key}' found in the config file."
            )

    for key, value in config_updates.items():
        setattr(args, key, value)
    return args
def generate_group_string(length=16):
    """Return a random identifier made of ASCII letters and digits.

    Characters are drawn with the ``secrets`` module, one independent draw
    per position.

    Args:
        length: Number of characters to generate.

    Returns:
        A string of ``length`` characters (empty when ``length`` is not
        positive).
    """
    pool = string.ascii_letters + string.digits
    # One ``secrets.choice(pool)`` call per requested character.
    return "".join(map(secrets.choice, [pool] * length))
def dataset_name2datapath(dataset_name, working_dir, data_dir="/raid/st512/branchsbm/data"):
    """Return the path of the raw data file backing a named dataset.

    The result is ``os.path.join(working_dir, data_dir, <file name>)``.
    Because ``os.path.join`` discards everything before an absolute
    component, the default absolute ``data_dir`` makes ``working_dir``
    irrelevant; pass a relative ``data_dir`` to resolve the data directory
    against ``working_dir`` instead.

    Args:
        dataset_name: Name of the dataset (e.g. ``"lidar"``, ``"mouse"``,
            ``"clonidine50D"``, ``"trametinibsingle"``).
        working_dir: Base directory used when ``data_dir`` is relative.
        data_dir: Directory containing the data files. Defaults to the
            location that was previously hard-coded in this function.

    Returns:
        The joined path of the dataset's file.

    Raises:
        ValueError: If ``dataset_name`` is not one of the known datasets.
    """
    # Each entry: (dataset names sharing a file, file name).
    files_by_dataset = (
        (("lidar", "lidarsingle"), "rainier2-thin.las"),
        (("mouse",), "mouse_hematopoiesis.csv"),
        (
            (
                "clonidine50D",
                "clonidine100D",
                "clonidine150D",
                "clonidine50Dsingle",
                "clonidine100Dsingle",
                "clonidine150Dsingle",
            ),
            "pca_and_leiden_labels.csv",
        ),
        (
            ("trametinib", "trametinibsingle"),
            "Trametinib_5.0uM_pca_and_leidenumap_labels.csv",
        ),
    )
    for names, filename in files_by_dataset:
        if dataset_name in names:
            return os.path.join(working_dir, data_dir, filename)
    raise ValueError("Dataset not recognized")
def create_callbacks(args, phase, data_type, run_id, datamodule=None):
    """Build the checkpoint and early-stopping callbacks for a training phase.

    Checkpoints are written to
    ``<args.working_dir>/checkpoints/<data_type>/<run_id>/<phase>_model``.
    Both callbacks watch the same validation metric in ``"min"`` mode, so the
    single retained checkpoint (``save_top_k=1``) is the best one seen rather
    than merely the most recent one.

    Args:
        args: Settings object providing ``working_dir`` plus ``patience``
            (or ``patience_geopath`` for the ``"geopath"`` phase).
        phase: One of ``"geopath"``, ``"flow"``, ``"growth"`` or ``"joint"``.
        data_type: Path component identifying the dataset/experiment type.
        run_id: Run identifier; converted with ``str`` for the path.
        datamodule: Unused; accepted for call-site compatibility.

    Returns:
        ``[checkpoint_callback, early_stop_callback]``.

    Raises:
        ValueError: If ``phase`` is not one of the supported phases.
    """
    # phase -> (validation metric to monitor, name of the patience attribute on args)
    phase_settings = {
        "geopath": ("BranchPathNet/val_loss_geopath", "patience_geopath"),
        "flow": ("FlowNet/val_loss_cfm", "patience"),
        "growth": ("GrowthNet/val_loss", "patience"),
        "joint": ("JointTrain/val_loss", "patience"),
    }

    dirpath = os.path.join(
        args.working_dir,
        "checkpoints",
        data_type,
        str(run_id),
        f"{phase}_model",
    )

    if phase not in phase_settings:
        raise ValueError("Unknown phase")
    monitor, patience_attr = phase_settings[phase]

    early_stop_callback = EarlyStopping(
        monitor=monitor,
        # Looked up lazily so only the attribute needed by this phase is read.
        patience=getattr(args, patience_attr),
        mode="min",
    )
    # Without ``monitor`` the checkpoint callback ignores ``mode`` and keeps
    # only the last epoch; monitoring the early-stopping metric keeps the best.
    checkpoint_callback = ModelCheckpoint(
        dirpath=dirpath,
        monitor=monitor,
        mode="min",
        save_top_k=1,
    )
    return [checkpoint_callback, early_stop_callback]
class PlottingCallback(Callback):
    """Periodically log flow-model trajectory plots to Weights & Biases.

    At the end of every ``plot_interval``-th training epoch (epoch 0
    excluded) the module's ``flow_net`` is integrated from a few starting
    samples of the train and validation sets, and the rendered figure is
    logged with ``wandb.log``.
    """

    # Number of leading samples of each split that are integrated and plotted.
    _NUM_PLOT_SAMPLES = 15

    def __init__(self, plot_interval, datamodule):
        """Store the plotting period and the data source.

        Args:
            plot_interval: Plot every this many epochs.
            datamodule: Object providing ``train_x0``, ``val_x0``, ``device``,
                ``dim``, ``vae`` and ``process``, which are read when plotting.
        """
        super().__init__()
        self.plot_interval = plot_interval
        self.datamodule = datamodule

    def on_train_epoch_end(self, trainer, pl_module):
        """Plot and log trajectories when the current epoch is a plotting epoch.

        ``pl_module.flow_net`` is put in evaluation mode for the duration of
        the hook and is always switched back to training mode afterwards,
        including when plotting or logging raises.
        """
        epoch = trainer.current_epoch
        flow_net = pl_module.flow_net
        flow_net.train(mode=False)
        try:
            if epoch % self.plot_interval == 0 and epoch != 0:
                node = NeuralODE(
                    flow_model_torch_wrapper(flow_net).to(self.datamodule.device),
                    solver="tsit5",
                    sensitivity="adjoint",
                    atol=1e-5,
                    rtol=1e-5,
                )
                for mode in ["train", "val"]:
                    # ``trajectory_and_plot`` keeps only the leading samples,
                    # so the full split can be handed over as-is.
                    x0 = getattr(self.datamodule, f"{mode}_x0")
                    fig = self.trajectory_and_plot(x0, node, self.datamodule)
                    wandb.log({f"Trajectories {mode.capitalize()}": wandb.Image(fig)})
        finally:
            # Never leave the network in eval mode for the next epoch.
            flow_net.train(mode=True)

    def trajectory_and_plot(self, x0, node, datamodule):
        """Integrate the leading samples of ``x0`` and render the trajectory.

        Args:
            x0: Starting points; only the first 15 entries are used.
            node: ODE wrapper exposing ``trajectory(x, t_span=...)``.
            datamodule: Provides ``device``, ``dim``, ``vae`` and ``process``.

        Returns:
            The figure returned by ``plot_images_trajectory``.
        """
        selected_images = x0[0 : self._NUM_PLOT_SAMPLES]
        with torch.no_grad():
            traj = node.trajectory(
                selected_images.to(datamodule.device),
                t_span=torch.linspace(0, 1, 100).to(datamodule.device),
            )

        # Swap the first two axes, then reshape every trailing axis to
        # ``datamodule.dim`` (presumably time-major -> sample-major; confirm
        # against the torchdyn trajectory layout).
        traj = traj.transpose(0, 1)
        traj = traj.reshape(*traj.shape[0:2], *datamodule.dim)

        fig = plot_images_trajectory(
            traj.to(datamodule.device),
            datamodule.vae.to(datamodule.device),
            datamodule.process,
            num_steps=5,
        )
        return fig