"""Training helpers: config loading, dataset path lookup and Lightning callbacks."""

import os
import secrets
import string
import sys

# Must run before the project-local imports below so that ``utils`` and
# ``networks`` resolve from the BranchSBM checkout.
sys.path.append("./BranchSBM")

import yaml
import torch
import wandb
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from torchdyn.core import NeuralODE

from utils import plot_images_trajectory
from networks.utils import flow_model_torch_wrapper


# Directory that holds every dataset file known to ``dataset_name2datapath``.
_DATA_DIR = "/raid/st512/branchsbm/data"

# Dataset name -> file name inside ``_DATA_DIR``.
_DATASET_FILES = {
    "lidar": "rainier2-thin.las",
    "lidarsingle": "rainier2-thin.las",
    "mouse": "mouse_hematopoiesis.csv",
    "clonidine50D": "pca_and_leiden_labels.csv",
    "clonidine100D": "pca_and_leiden_labels.csv",
    "clonidine150D": "pca_and_leiden_labels.csv",
    "clonidine50Dsingle": "pca_and_leiden_labels.csv",
    "clonidine100Dsingle": "pca_and_leiden_labels.csv",
    "clonidine150Dsingle": "pca_and_leiden_labels.csv",
    "trametinib": "Trametinib_5.0uM_pca_and_leidenumap_labels.csv",
    "trametinibsingle": "Trametinib_5.0uM_pca_and_leidenumap_labels.csv",
}

# Training phase -> (metric watched by EarlyStopping,
#                    name of the ``args`` attribute holding the patience,
#                    whether ModelCheckpoint monitors that metric too).
_PHASE_SETTINGS = {
    "geopath": ("BranchPathNet/val_loss_geopath", "patience_geopath", True),
    "flow": ("FlowNet/val_loss_cfm", "patience", False),
    "growth": ("GrowthNet/val_loss", "patience", False),
    "joint": ("JointTrain/val_loss", "patience", False),
}


def load_config(path):
    """Parse the YAML file at *path* and return its content.

    The file is read as UTF-8 and parsed with ``yaml.safe_load``. Whatever
    ``safe_load`` produces is returned unchanged; for an empty document that
    is ``None``.
    """
    with open(path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


def merge_config(args, config_updates):
    """Copy every entry of *config_updates* onto *args* as attributes.

    Args:
        args: Object whose attributes are overwritten (e.g. an argparse
            namespace). Only attributes that already exist may be set.
        config_updates: Mapping of attribute name to new value, or ``None``
            (what ``load_config`` returns for an empty file), in which case
            nothing is changed.

    Returns:
        The same *args* object, mutated in place.

    Raises:
        ValueError: If a key in *config_updates* is not an existing attribute
            of *args*. All keys are checked before anything is written, so
            *args* is left untouched when this is raised.
    """
    if config_updates is None:
        return args

    # Validate everything up front so a bad key cannot leave ``args``
    # half-updated.
    for key in config_updates:
        if not hasattr(args, key):
            raise ValueError(
                f"Unknown configuration parameter '{key}' found in the config file."
            )

    for key, value in config_updates.items():
        setattr(args, key, value)
    return args


def generate_group_string(length=16):
    """Return a random string of *length* ASCII letters and digits.

    Characters are drawn with ``secrets.choice``.
    """
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))


def dataset_name2datapath(dataset_name, working_dir):
    """Return the path of the data file that belongs to *dataset_name*.

    Args:
        dataset_name: One of the keys of ``_DATASET_FILES``.
        working_dir: First component handed to ``os.path.join``.

    Raises:
        ValueError: If *dataset_name* is not a known dataset.
    """
    filename = _DATASET_FILES.get(dataset_name)
    if filename is None:
        raise ValueError("Dataset not recognized")
    # NOTE(review): ``_DATA_DIR`` is absolute, so on POSIX ``os.path.join``
    # discards ``working_dir`` and the result always lives under
    # ``_DATA_DIR``. Kept as-is; confirm whether ``working_dir`` was meant
    # to matter.
    return os.path.join(working_dir, _DATA_DIR, filename)


def create_callbacks(args, phase, data_type, run_id, datamodule=None):
    """Build the ``[ModelCheckpoint, EarlyStopping]`` pair for a training phase.

    Args:
        args: Configuration object. ``working_dir`` is always read;
            ``patience_geopath`` is read for the ``"geopath"`` phase and
            ``patience`` for every other phase.
        phase: ``"geopath"``, ``"flow"``, ``"growth"`` or ``"joint"``.
        data_type: Used as a sub-directory of the checkpoint path.
        run_id: Used (via ``str``) as a sub-directory of the checkpoint path.
        datamodule: Accepted but not used by this function.

    Returns:
        A list ``[checkpoint_callback, early_stop_callback]``. Checkpoints go
        to ``<working_dir>/checkpoints/<data_type>/<run_id>/<phase>_model``.

    Raises:
        ValueError: If *phase* is not one of the known phases.
    """
    dirpath = os.path.join(
        args.working_dir,
        "checkpoints",
        data_type,
        str(run_id),
        f"{phase}_model",
    )

    settings = _PHASE_SETTINGS.get(phase)
    if settings is None:
        raise ValueError("Unknown phase")
    monitor, patience_attr, checkpoint_monitors = settings

    early_stop_callback = EarlyStopping(
        monitor=monitor,
        patience=getattr(args, patience_attr),
        mode="min",
    )
    # Only the geopath phase ties its checkpoint to the validation metric.
    # The other phases pass no monitor (``None`` is ModelCheckpoint's
    # default), exactly as before.
    # NOTE(review): without a monitor, ``mode="min"`` looks like it has no
    # effect on which checkpoint is kept - confirm this is intended.
    checkpoint_callback = ModelCheckpoint(
        dirpath=dirpath,
        monitor=monitor if checkpoint_monitors else None,
        mode="min",
        save_top_k=1,
    )
    return [checkpoint_callback, early_stop_callback]


class PlottingCallback(Callback):
    """Periodically integrate the flow network and log trajectory images to wandb."""

    # Number of leading samples of ``x0`` that get integrated and plotted.
    NUM_PLOT_SAMPLES = 15

    def __init__(self, plot_interval, datamodule):
        """Store the plotting period (in epochs) and the datamodule to sample from."""
        super().__init__()
        self.plot_interval = plot_interval
        self.datamodule = datamodule

    def on_train_epoch_end(self, trainer, pl_module):
        """Log train/val trajectory figures every ``plot_interval`` epochs.

        Nothing is plotted at epoch 0. ``pl_module.flow_net`` is switched out
        of training mode for the duration of the hook and switched back to
        training mode afterwards, including when plotting raises.
        """
        epoch = trainer.current_epoch
        flow_net = pl_module.flow_net
        flow_net.train(mode=False)
        try:
            if epoch % self.plot_interval == 0 and epoch != 0:
                device = self.datamodule.device
                node = NeuralODE(
                    flow_model_torch_wrapper(flow_net).to(device),
                    solver="tsit5",
                    sensitivity="adjoint",
                    atol=1e-5,
                    rtol=1e-5,
                )
                for mode in ["train", "val"]:
                    # ``trajectory_and_plot`` keeps only the leading
                    # ``NUM_PLOT_SAMPLES`` entries, so no slicing is needed here.
                    x0 = getattr(self.datamodule, f"{mode}_x0")
                    fig = self.trajectory_and_plot(x0, node, self.datamodule)
                    wandb.log({f"Trajectories {mode.capitalize()}": wandb.Image(fig)})
        finally:
            # Always hand the network back in training mode, even on error.
            flow_net.train(mode=True)

    def trajectory_and_plot(self, x0, node, datamodule):
        """Integrate the first ``NUM_PLOT_SAMPLES`` entries of *x0* and plot them.

        Args:
            x0: Indexable batch of starting points; only the leading
                ``NUM_PLOT_SAMPLES`` entries are used.
            node: Object exposing ``trajectory(x, t_span=...)`` (a
                ``NeuralODE`` when called from ``on_train_epoch_end``).
            datamodule: Supplies ``device``, ``dim``, ``vae`` and ``process``.

        Returns:
            Whatever ``plot_images_trajectory`` returns for the trajectory.
        """
        device = datamodule.device
        selected_images = x0[0 : self.NUM_PLOT_SAMPLES]
        with torch.no_grad():
            traj = node.trajectory(
                selected_images.to(device),
                t_span=torch.linspace(0, 1, 100).to(device),
            )

        # Swap the first two axes, then expand everything after them to
        # ``datamodule.dim``.
        # NOTE(review): presumably (time, batch, ...) -> (batch, time, ...);
        # confirm against torchdyn's trajectory layout.
        traj = traj.transpose(0, 1)
        traj = traj.reshape(*traj.shape[0:2], *datamodule.dim)

        fig = plot_images_trajectory(
            traj.to(device),
            datamodule.vae.to(device),
            datamodule.process,
            num_steps=5,
        )
        return fig