File size: 777 Bytes
8abfb97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from dataloader import get_dataloaders
from config import Config
from noise_scheduler import FrequencyAwareNoise
import matplotlib.pyplot as plt

def _to_displayable(img):
    """Convert one (C, H, W) image tensor into an array ``plt.imshow`` accepts.

    Applies ``x * 0.5 + 0.5``, i.e. assumes pixels were normalised to [-1, 1]
    (NOTE(review): confirm against the dataloader's transforms). The result is
    clamped to [0, 1]. Noisy samples routinely fall outside that range, and
    matplotlib would otherwise clip them itself and emit a warning.

    Args:
        img: Tensor of shape (C, H, W). It may live on any device or carry grad.

    Returns:
        A NumPy array of shape (H, W) when C == 1, because matplotlib rejects a
        trailing singleton channel. Otherwise an array of shape (H, W, C).
    """
    # detach/cpu so .numpy() works for GPU tensors or tensors that require grad;
    # float() so integer or half-precision inputs scale correctly.
    img = (img.detach().cpu().float() * 0.5 + 0.5).clamp(0.0, 1.0)
    if img.shape[0] == 1:
        return img[0].numpy()
    return img.permute(1, 2, 0).numpy()


def _plot_panel(position, img, title):
    """Draw ``img`` into panel ``position`` of a 1x2 subplot grid with ``title``."""
    plt.subplot(1, 2, position)
    arr = _to_displayable(img)
    if arr.ndim == 2:
        # Single-channel image: fix the colour range so it is not auto-scaled.
        plt.imshow(arr, cmap="gray", vmin=0.0, vmax=1.0)
    else:
        plt.imshow(arr)
    plt.title(title)


def debug_data(t: int = 500) -> None:
    """Show the first training image next to its noised version at timestep ``t``.

    Loads one batch from the training loader and noises the whole batch with
    ``FrequencyAwareNoise.apply_noise`` at timestep ``t``. It then displays
    sample 0 before and after noising in a single blocking matplotlib window.

    Args:
        t: Timestep applied to every sample in the batch. Defaults to 500.
            NOTE(review): presumably must be below the scheduler's configured
            number of timesteps; this is not checked here.
    """
    config = Config()
    train_loader, _ = get_dataloaders(config)
    # Presumably the loader yields (images, labels); labels are unused here.
    x0, _ = next(iter(train_loader))

    noise_scheduler = FrequencyAwareNoise(config)
    # One int64 timestep per sample, created on the same device as the batch.
    timesteps = torch.full((x0.shape[0],), t, dtype=torch.long, device=x0.device)
    # Visualisation only, so there is no need to record an autograd graph.
    with torch.no_grad():
        xt = noise_scheduler.apply_noise(x0, timesteps)

    plt.figure(figsize=(10, 5))
    _plot_panel(1, x0[0], "Original")
    _plot_panel(2, xt[0], f"Noisy (t={t})")
    plt.show()

if __name__ == "__main__":
    # Script entry point: open the original-vs-noisy comparison window.
    debug_data()