| # utils/data_loader.py | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
class CustomDataset(Dataset):
    """Map-style dataset exposing an in-memory sequence to a DataLoader.

    Items are handed back exactly as stored; no transform or copy is applied.

    Attributes:
        data: The wrapped sequence (the very object passed to the constructor).
    """

    def __init__(self, data):
        # Store a reference, not a copy: later mutations by the caller show up here.
        self.data = data

    def __getitem__(self, idx):
        # Indexing is delegated verbatim, so whatever index forms the wrapped
        # sequence accepts (negative ints, slices, ...) behave the same here.
        item = self.data[idx]
        return item

    def __len__(self) -> int:
        return len(self.data)
def load_data(
    batch_size=32,
    *,
    num_samples=1000,
    feature_dim=10,
    shuffle=True,
    seed=None,
):
    """Build a DataLoader over randomly generated dummy feature vectors.

    Args:
        batch_size: Number of samples per batch. The final batch may be
            smaller, since ``drop_last`` is left at its default (False).
        num_samples: How many dummy samples to generate (default 1000).
        feature_dim: Length of each sample vector (default 10).
        shuffle: Whether the loader shuffles sample order (default True).
        seed: Optional integer seed. When given, the generated data and the
            shuffle order are both drawn from a private ``torch.Generator``
            seeded with it. Repeated calls with the same seed therefore yield
            the same data and the same sequence of shuffle orders. When
            ``None`` (the default), torch's global RNG is used, exactly as
            before.

    Returns:
        A ``DataLoader`` over a ``CustomDataset`` holding ``num_samples``
        tensors, each of shape ``(feature_dim,)`` and sampled from a standard
        normal distribution via ``torch.randn``.
    """
    # A private generator makes seeded runs reproducible without reading or
    # mutating the global torch RNG; None falls back to the global RNG.
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    # Dummy data: one standard-normal vector per sample.
    data = [
        torch.randn(feature_dim, generator=generator)
        for _ in range(num_samples)
    ]
    dataset = CustomDataset(data)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        generator=generator,
    )