| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import random |
| import matplotlib.pyplot as plt |
|
|
|
|
| |
def add_noise(x_0, noise, t, max_timestep=1000.0):
    """Return ``x_0`` corrupted by ``noise`` scaled linearly with the timestep.

    Computes ``x_0 + noise * (t / max_timestep)``. Unlike a variance-preserving
    diffusion forward process, ``x_0`` itself is not attenuated; only the noise
    magnitude grows, reaching ``noise`` at full strength when ``t == max_timestep``.

    Args:
        x_0: Clean input tensor.
        noise: Noise tensor broadcastable against ``x_0``.
        t: Timestep (Python number or tensor broadcastable against ``x_0``).
        max_timestep: Timestep at which the noise is applied at full scale.
            Defaults to 1000, the original hard-coded value.

    Returns:
        The noised tensor.
    """
    return x_0 + noise * (t / max_timestep)
|
|
|
|
def plot_data(mu, sigma, color, title):
    """Plot a per-epoch curve ``mu`` with a shaded ``mu +/- sigma`` band, then show it.

    Args:
        mu: Sequence of per-epoch values (one point per epoch).
        sigma: Sequence of spread values, same length as ``mu``. Entries may be
            Python floats or 0-d tensors; each is converted with ``float()``.
        color: Any matplotlib color spec (e.g. ``'b'`` or ``'tab:blue'``).
        title: Figure title.
    """
    mu_values = np.asarray([float(m) for m in mu])
    sigma_values = np.asarray([float(s) for s in sigma])
    epochs = np.arange(len(mu_values))

    # Pass the color by keyword so any matplotlib color spec works, not only the
    # single-letter codes a format string such as f'{color}-' can express.
    plt.plot(epochs, mu_values, '-', color=color)
    plt.fill_between(epochs, mu_values - sigma_values, mu_values + sigma_values, color=color, alpha=0.2)
    # The shaded band spans +/- sigma (a standard deviation), not the variance.
    plt.legend(['Mean Loss', '±1 Std Dev of Loss'])
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(title)
    plt.show()
|
|
|
|
def _mean_and_std(values):
    """Return ``(mean, sample standard deviation)`` of ``values`` as Python floats.

    A statistic is NaN when it is undefined: both are NaN for an empty list
    (e.g. every sample in an epoch was skipped), and the standard deviation is
    NaN for a single value.
    """
    if not values:
        return float("nan"), float("nan")
    stats = torch.tensor(values, dtype=torch.float64)
    std = stats.std().item() if len(values) > 1 else float("nan")
    return stats.mean().item(), std


def train(model, conditioner, dataset, epochs=10, *, lr=1e-4, max_grad_norm=1.0, detect_anomaly=True):
    """Jointly train ``model`` and ``conditioner`` on ``dataset``, then plot the loss curves.

    For each sample, a random timestep in [1, 1000] is drawn and Gaussian noise is
    added with ``add_noise``. The model is trained with the sum of two losses:
    an MSE loss on its noise prediction and a cross-entropy loss on its bond
    logits against ``edge_attr``. Samples whose ``edge_attr`` falls outside
    [0, 4) or whose node features contain NaN are skipped.

    Args:
        model: Called as ``model(x_t, pos, edge_index, t, cond_embed)``; must return
            ``(pred_noise, bond_logits)``.
        conditioner: Called as ``conditioner(labels)``; returns the conditioning embedding.
        dataset: Iterable of samples exposing ``x``, ``pos``, ``edge_index``,
            ``edge_attr`` and ``y`` attributes.
        epochs: Number of passes over ``dataset``.
        lr: Adam learning rate.
        max_grad_norm: Gradient-norm clipping threshold, applied jointly to all
            parameters the optimizer updates (model and conditioner).
        detect_anomaly: Whether autograd anomaly detection is active while
            training. The previous global setting is restored afterwards.

    Returns:
        The trained ``(model, conditioner)`` pair.
    """
    model.train()
    conditioner.train()
    # Clip exactly the parameters the optimizer updates, not only the model's.
    params = list(model.parameters()) + list(conditioner.parameters())
    optimizer = torch.optim.Adam(params, lr=lr)
    ce_loss = nn.CrossEntropyLoss()

    # Per-epoch mean and standard deviation of each per-sample loss.
    mean_bond, mean_noise, mean_total = [], [], []
    std_bond, std_noise, std_total = [], [], []

    # Used as a context manager, set_detect_anomaly restores the previous global
    # setting on exit instead of leaving anomaly mode (and its slowdown) enabled.
    with torch.autograd.set_detect_anomaly(detect_anomaly):
        for epoch in range(epochs):
            bond_losses, noise_losses, losses = [], [], []

            for data in dataset:
                x_0, pos, edge_index, labels = data.x, data.pos, data.edge_index, data.y
                edge_attr = data.edge_attr.reshape(-1)
                # 4 is presumably the number of bond classes bond_logits covers
                # (CrossEntropyLoss needs targets in [0, C)) — TODO confirm.
                if torch.any(edge_attr >= 4) or torch.any(edge_attr < 0) or torch.any(torch.isnan(x_0)):
                    continue
                t = torch.tensor([random.randint(1, 1000)])
                noise = torch.randn_like(x_0)
                x_t = add_noise(x_0, noise, t)
                cond_embed = conditioner(labels)
                pred_noise, bond_logits = model(x_t, pos, edge_index, t, cond_embed)
                loss_noise = F.mse_loss(pred_noise, noise)
                loss_bond = ce_loss(bond_logits, edge_attr)
                loss = loss_noise + loss_bond

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(params, max_norm=max_grad_norm)
                optimizer.step()

                bond_losses.append(loss_bond.item())
                noise_losses.append(loss_noise.item())
                losses.append(loss.item())

            for values, means, stds in (
                (bond_losses, mean_bond, std_bond),
                (noise_losses, mean_noise, std_noise),
                (losses, mean_total, std_total),
            ):
                mean, std = _mean_and_std(values)
                means.append(mean)
                stds.append(std)

            total_bond_loss = sum(bond_losses)
            total_noise_loss = sum(noise_losses)
            total_loss = sum(losses)
            print(f"Epoch {epoch}: Loss = {total_loss:.4f}, Noise Loss = {total_noise_loss:.4f}, Bond Loss = {total_bond_loss:.4f}")

    # Plot per-epoch means with a +/- std band: means and per-sample spreads share
    # a scale, unlike per-epoch sums.
    plot_data(mu=mean_bond, sigma=std_bond, color='b', title="Bond Loss")
    plot_data(mu=mean_noise, sigma=std_noise, color='r', title="Noise Loss")
    plot_data(mu=mean_total, sigma=std_total, color='g', title="Total Loss")

    plt.plot(mean_bond)
    plt.plot(mean_noise)
    plt.plot(mean_total)
    plt.legend(['Bond Loss', 'Noise Loss', 'Total Loss'])
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Epochs')
    plt.show()
    return model, conditioner
|
|
|
|
| |
def temperature_scaled_softmax(logits, temperature=1.0, dim=0):
    """Return ``softmax(logits / temperature)`` normalised along ``dim``.

    Temperatures above 1 flatten the resulting distribution; temperatures
    below 1 sharpen it.

    Args:
        logits: Tensor of unnormalised scores.
        temperature: Strictly positive scaling factor.
        dim: Dimension to normalise over. Defaults to 0 (the original
            behaviour); pass ``-1`` for batched ``(N, C)`` logits.

    Raises:
        ValueError: If ``temperature`` is not positive. Dividing by a
            non-positive value would yield inf/NaN probabilities or an
            inverted distribution.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    return torch.softmax(logits / temperature, dim=dim)
|
|