| import sys | |
| sys.path.append("./BranchSBM") | |
| import yaml | |
| import string | |
| import secrets | |
| import os | |
| import torch | |
| import wandb | |
| from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint | |
| from torchdyn.core import NeuralODE | |
| from utils import plot_images_trajectory | |
| from networks.utils import flow_model_torch_wrapper | |
def load_config(path):
    """Read a YAML configuration file and return its parsed content.

    Args:
        path: Location of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's content.
    """
    # Pin the encoding so that parsing does not depend on the platform's
    # locale default (YAML files are conventionally UTF-8).
    with open(path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
def merge_config(args, config_updates):
    """Overlay configuration values onto an existing arguments object.

    Args:
        args: Object whose attributes hold the current settings. It is
            mutated in place.
        config_updates: Mapping of attribute name to new value. ``None`` or
            an empty mapping leaves ``args`` untouched.

    Returns:
        The same ``args`` object, after the updates have been applied.

    Raises:
        ValueError: If a key in ``config_updates`` is not already an
            attribute of ``args``. In that case no attribute is modified.
    """
    # Nothing to merge (e.g. the config source was empty).
    if not config_updates:
        return args

    # Validate every key up front so that one bad entry cannot leave
    # ``args`` half-updated.
    for key in config_updates:
        if not hasattr(args, key):
            raise ValueError(
                f"Unknown configuration parameter '{key}' found in the config file."
            )

    for key, value in config_updates.items():
        setattr(args, key, value)
    return args
def generate_group_string(length=16):
    """Return a random alphanumeric identifier.

    Args:
        length: Number of characters to generate.

    Returns:
        A string of ``length`` characters, each drawn with ``secrets.choice``
        from the ASCII letters and digits.
    """
    pool = string.ascii_letters + string.digits
    picked = [secrets.choice(pool) for _ in range(length)]
    return "".join(picked)
def dataset_name2datapath(dataset_name, working_dir):
    """Map a dataset name to the path of its data file.

    Args:
        dataset_name: Identifier of the dataset (see the table below for the
            accepted names).
        working_dir: First component handed to ``os.path.join``.

    Returns:
        The joined path of ``working_dir``, the data directory and the file
        name associated with ``dataset_name``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the known names.
    """
    # NOTE(review): this directory is absolute, and ``os.path.join`` drops
    # every component that precedes an absolute one, so ``working_dir`` has
    # no effect on the result on POSIX systems -- confirm this is intended.
    data_root = "/raid/st512/branchsbm/data"

    # (accepted dataset names, file name inside the data directory)
    files_by_dataset = (
        (("lidar", "lidarsingle"), "rainier2-thin.las"),
        (("mouse",), "mouse_hematopoiesis.csv"),
        (
            (
                "clonidine50D",
                "clonidine100D",
                "clonidine150D",
                "clonidine50Dsingle",
                "clonidine100Dsingle",
                "clonidine150Dsingle",
            ),
            "pca_and_leiden_labels.csv",
        ),
        (
            ("trametinib", "trametinibsingle"),
            "Trametinib_5.0uM_pca_and_leidenumap_labels.csv",
        ),
    )

    for names, filename in files_by_dataset:
        if dataset_name in names:
            return os.path.join(working_dir, data_root, filename)
    raise ValueError("Dataset not recognized")
def create_callbacks(args, phase, data_type, run_id, datamodule=None):
    """Build the checkpointing and early-stopping callbacks for a training phase.

    Args:
        args: Settings object. Reads ``working_dir`` and, depending on the
            phase, ``patience_geopath`` (phase ``"geopath"``) or ``patience``
            (all other phases).
        phase: One of ``"geopath"``, ``"flow"``, ``"growth"`` or ``"joint"``.
        data_type: Name used as a sub-directory of the checkpoint folder.
        run_id: Run identifier; converted with ``str`` and used as a
            sub-directory of the checkpoint folder.
        datamodule: Unused here; kept so existing callers keep working.

    Returns:
        ``[ModelCheckpoint, EarlyStopping]``, both tracking the phase's
        validation metric in ``"min"`` mode.

    Raises:
        ValueError: If ``phase`` is not one of the supported names.
    """
    # phase -> (validation metric to track, name of the patience attribute on args)
    phase_settings = {
        "geopath": ("BranchPathNet/val_loss_geopath", "patience_geopath"),
        "flow": ("FlowNet/val_loss_cfm", "patience"),
        "growth": ("GrowthNet/val_loss", "patience"),
        "joint": ("JointTrain/val_loss", "patience"),
    }
    if phase not in phase_settings:
        raise ValueError("Unknown phase")
    monitor, patience_attr = phase_settings[phase]

    dirpath = os.path.join(
        args.working_dir,
        "checkpoints",
        data_type,
        str(run_id),
        f"{phase}_model",
    )

    early_stop_callback = EarlyStopping(
        monitor=monitor,
        patience=getattr(args, patience_attr),
        mode="min",
    )
    # The checkpoint monitors the same metric as early stopping in every
    # phase. Without ``monitor`` the ``mode="min"`` / ``save_top_k=1``
    # settings cannot select the lowest-loss checkpoint, which is what the
    # "geopath" phase already did.
    checkpoint_callback = ModelCheckpoint(
        dirpath=dirpath,
        monitor=monitor,
        mode="min",
        save_top_k=1,
    )
    return [checkpoint_callback, early_stop_callback]
class PlottingCallback(Callback):
    """Callback that periodically logs flow trajectories to Weights & Biases.

    On every ``plot_interval``-th epoch (epoch 0 excluded) it integrates
    ``pl_module.flow_net`` from t=0 to t=1, starting from the first samples
    of the datamodule's ``train_x0`` and ``val_x0``, and logs the figure
    produced by ``plot_images_trajectory`` for each of them.

    The datamodule is expected to expose ``device``, ``train_x0``,
    ``val_x0``, ``dim``, ``vae`` and ``process``; those are the attributes
    this class reads.
    """

    # Number of leading samples of x0 that are integrated and plotted.
    NUM_SAMPLES = 15
    # Number of time points between t=0 and t=1 handed to the ODE solver.
    NUM_TIME_POINTS = 100

    def __init__(self, plot_interval, datamodule):
        """Store the plotting period (in epochs) and the datamodule to sample from."""
        super().__init__()
        self.plot_interval = plot_interval
        self.datamodule = datamodule

    def on_train_epoch_end(self, trainer, pl_module):
        """Log train/val trajectory figures when the current epoch is a plotting epoch.

        Args:
            trainer: Object providing ``current_epoch``.
            pl_module: Module providing the ``flow_net`` to integrate.
        """
        epoch = trainer.current_epoch
        if epoch == 0 or epoch % self.plot_interval != 0:
            return

        flow_net = pl_module.flow_net
        # Remember the current mode so it can be restored afterwards; fall
        # back to training mode if the attribute is not available.
        was_training = getattr(flow_net, "training", True)
        flow_net.train(mode=False)
        try:
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net).to(self.datamodule.device),
                solver="tsit5",
                sensitivity="adjoint",
                atol=1e-5,
                rtol=1e-5,
            )
            for mode in ("train", "val"):
                x0 = getattr(self.datamodule, f"{mode}_x0")
                # trajectory_and_plot keeps only the first NUM_SAMPLES entries.
                fig = self.trajectory_and_plot(x0, node, self.datamodule)
                wandb.log({f"Trajectories {mode.capitalize()}": wandb.Image(fig)})
        finally:
            # Restore the mode even if integration, plotting or logging
            # fails, so training never continues with the network in eval mode.
            flow_net.train(mode=was_training)

    def trajectory_and_plot(self, x0, node, datamodule):
        """Integrate the first samples of ``x0`` and plot the resulting trajectory.

        Args:
            x0: Starting points; only the first ``NUM_SAMPLES`` entries are used.
            node: Object whose ``trajectory(x, t_span=...)`` integrates the flow.
            datamodule: Provides ``device``, ``dim``, ``vae`` and ``process``.

        Returns:
            The figure returned by ``plot_images_trajectory``.
        """
        selected_images = x0[0 : self.NUM_SAMPLES]
        device = datamodule.device
        with torch.no_grad():
            traj = node.trajectory(
                selected_images.to(device),
                t_span=torch.linspace(0, 1, self.NUM_TIME_POINTS).to(device),
            )

        # Swap the first two axes -- presumably (time, batch, ...) ->
        # (batch, time, ...); confirm against torchdyn -- then restore the
        # per-sample shape given by ``datamodule.dim``.
        traj = traj.transpose(0, 1)
        traj = traj.reshape(*traj.shape[0:2], *datamodule.dim)

        fig = plot_images_trajectory(
            traj.to(device),
            datamodule.vae.to(device),
            datamodule.process,
            num_steps=5,
        )
        return fig