|
|
import argparse |
|
|
import os |
|
|
import yaml |
|
|
import glob |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
import torch.utils.data as data |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
import numpy as np |
|
|
from models.unet import DiffusionUNet |
|
|
from diff2flow import dict2namespace |
|
|
import utils.logging |
|
|
|
|
|
|
|
|
class ReflowDataset(data.Dataset):
    """Map-style dataset over pre-generated reflow pairs stored as ``*.pth`` files.

    Every ``*.pth`` file directly inside ``data_dir`` is one item; files are
    ordered by sorted path so indices are stable across runs. Each item is
    whatever object ``torch.load`` returns for that file (the training loop in
    this script reads the keys ``"x_data"``, ``"x_noise"`` and ``"x_cond"``).
    """

    def __init__(self, data_dir):
        """Index all ``*.pth`` files in ``data_dir`` (non-recursive)."""
        super().__init__()
        self.files = sorted(glob.glob(os.path.join(data_dir, "*.pth")))
        print(f"Found {len(self.files)} files in {data_dir}")

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.files)

    def __getitem__(self, index):
        """Load and return the object stored in the ``index``-th file.

        Tensors are mapped to CPU on load: this method runs inside DataLoader
        worker processes, where files saved from a GPU would otherwise try to
        initialize CUDA in a forked subprocess and fail. The training loop
        moves tensors to the target device itself.
        """
        path = self.files[index]
        return torch.load(path, map_location="cpu")
|
|
|
|
|
|
|
|
def _load_pretrained_weights(model, path, device):
    """Load weights from ``path`` into ``model`` with strict key matching.

    The file may hold either a raw state dict or a dict that stores the
    weights under ``"state_dict"``. A leading ``"module."`` on any key is
    stripped (presumably left by an ``nn.DataParallel``/DDP wrapper at save
    time) so the keys match the bare model.
    """
    print(f"Loading pretrained weights from {path}")
    checkpoint = torch.load(path, map_location=device)
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

    prefix = "module."
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(cleaned, strict=True)


def _reflow_loss(model, batch_dict, device, num_timesteps):
    """Return the velocity-matching MSE loss for one stored reflow batch.

    Samples ``t ~ U[0, 1)`` per example and forms the straight-line
    interpolant ``x_t = (1 - t) * x_0 + t * x_1`` between ``x_data`` (x_0)
    and ``x_noise`` (x_1). It then regresses the model output onto the
    constant velocity ``x_1 - x_0``. The model input is ``x_cond``
    concatenated with ``x_t`` along dim 1. The timestep argument is ``t``
    rescaled to ``[0, num_timesteps - 1)``.
    """
    # The DataLoader uses batch_size=1 over files that each appear to hold a
    # whole batch already, so drop the extra leading dim it adds.
    x_0 = batch_dict["x_data"].squeeze(0).to(device)
    x_1 = batch_dict["x_noise"].squeeze(0).to(device)
    x_cond = batch_dict["x_cond"].squeeze(0).to(device)

    batch_size = x_0.shape[0]
    t = torch.rand(batch_size, device=device)
    # view(B, 1, 1, 1) assumes 4-D (B, C, H, W) tensors — TODO confirm for all data.
    t_expand = t.view(batch_size, 1, 1, 1)
    x_t = (1 - t_expand) * x_0 + t_expand * x_1
    v_target = x_1 - x_0

    # Presumably matches the discrete timestep scale the UNet was pretrained on.
    t_input = t * (num_timesteps - 1)
    v_pred = model(torch.cat([x_cond, x_t], dim=1), t_input)
    return torch.mean((v_pred - v_target) ** 2)


def train_reflow(args, config):
    """Fine-tune a DiffusionUNet on stored reflow pairs with a velocity MSE loss.

    Args:
        args: parsed CLI namespace; reads ``output``, ``resume``,
            ``data_dir_reflow`` and ``epochs``.
        config: config namespace; reads ``device``, ``optim.lr`` and
            ``diffusion.num_diffusion_timesteps`` and is passed to
            ``DiffusionUNet``.

    Side effects:
        Writes TensorBoard scalars to ``<output>/logs``. Saves model state
        dicts to ``<output>/reflow_ckpt_<epoch>.pth`` after the first epoch,
        after every 5th epoch, and after the final epoch.

    Raises:
        ValueError: if ``args.data_dir_reflow`` contains no ``*.pth`` files.
    """
    device = config.device
    writer = SummaryWriter(log_dir=os.path.join(args.output, "logs"))
    # Ensure the event file is flushed/closed even if training raises.
    try:
        print("Loading model...")
        model = DiffusionUNet(config)
        model.to(device)

        if args.resume:
            _load_pretrained_weights(model, args.resume, device)

        optimizer = optim.Adam(model.parameters(), lr=config.optim.lr)

        dataset = ReflowDataset(args.data_dir_reflow)
        if len(dataset) == 0:
            # Fail with a clear message instead of RandomSampler's opaque
            # "num_samples should be a positive integer value" error.
            raise ValueError(f"No *.pth files found in {args.data_dir_reflow}")
        loader = data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

        model.train()
        print("Starting training...")

        step = 0
        num_timesteps = config.diffusion.num_diffusion_timesteps
        last_epoch = args.epochs - 1

        for epoch in range(args.epochs):
            for batch_dict in loader:
                loss = _reflow_loss(model, batch_dict, device, num_timesteps)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if step % 10 == 0:
                    # .item() forces a device sync, so call it once and only when logging.
                    loss_value = loss.item()
                    print(f"Epoch {epoch}, Step {step}, Loss: {loss_value:.6f}")
                    writer.add_scalar("Loss/train", loss_value, step)

                step += 1

            # Save after the first epoch, every 5th epoch, and always after the
            # final epoch so end-of-training weights are not lost when
            # ``epochs`` is not a multiple of 5.
            if epoch == 0 or (epoch + 1) % 5 == 0 or epoch == last_epoch:
                save_path = os.path.join(args.output, f"reflow_ckpt_{epoch}.pth")
                torch.save(model.state_dict(), save_path)
                print(f"Saved checkpoint to {save_path}")
    finally:
        writer.close()
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point for reflow fine-tuning.

    Parses CLI flags and loads ``configs/<--config>`` as YAML into a
    namespace. It applies the ``--lr`` override, selects CUDA when
    available, seeds torch and NumPy, and creates the output directory.
    Finally it hands off to ``train_reflow``.
    """
    cli = argparse.ArgumentParser()
    flag_specs = (
        ("--config", dict(type=str, required=True)),
        ("--resume", dict(type=str, default="")),
        ("--data_dir_reflow", dict(type=str, required=True)),
        ("--epochs", dict(type=int, default=10)),
        ("--output", dict(type=str, default="results/reflow_train")),
        ("--seed", dict(type=int, default=61)),
        ("--lr", dict(type=float, default=1e-5)),
    )
    for flag, options in flag_specs:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    config_path = os.path.join("configs", args.config)
    with open(config_path, "r") as handle:
        raw_config = yaml.safe_load(handle)
    config = dict2namespace(raw_config)

    # NOTE(review): --lr defaults to 1e-5, which is truthy, so this override
    # always applies. config.optim.lr from the YAML is only used when
    # `--lr 0` is passed explicitly. Confirm this is intended.
    if args.lr:
        config.optim.lr = args.lr

    config.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    os.makedirs(args.output, exist_ok=True)

    train_reflow(args, config)
|
|
|
|
|
|
|
|
# Start training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|