import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr
# Define the VAE model
class ConvVAE(nn.Module):
    def __init__(self, input_channels=3, latent_dim=32):
        super(ConvVAE, self).__init__()
        self.latent_dim = latent_dim
        # Encoder: two stride-2 convolutions, one stride-1 convolution, then
        # fully connected layers producing the latent mean and log-variance
        self.enc_conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1)
        self.enc_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.enc_conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.fc_mu = nn.Linear(256 * 4 * 10, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 10, latent_dim)
        # Decoder: mirror of the encoder, mapping the latent vector back to a
        # 3-channel (tile-class) 15x40 grid
        self.fc_decode = nn.Linear(latent_dim, 256 * 4 * 10)
        self.dec_conv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.dec_conv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dec_conv3 = nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=1, output_padding=(0, 1))
    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        x = F.relu(self.enc_conv1(x))
        x = F.relu(self.enc_conv2(x))
        x = F.relu(self.enc_conv3(x))
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        z = self.reparameterize(mu, logvar)
        out = self.decode(z)
        return out, mu, logvar

    def decode(self, z):
        x = F.relu(self.fc_decode(z))
        x = x.view(x.size(0), 256, 4, 10)
        x = F.relu(self.dec_conv1(x))
        x = F.relu(self.dec_conv2(x))
        x = self.dec_conv3(x)
        # Per-pixel probabilities over the tile classes (channel dimension)
        return F.softmax(x, dim=1)
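# How the checkpoint loaded below was trained is not part of this file. As a hedged
# sketch, a ConvVAE like this is typically optimized with a per-tile reconstruction
# term plus a KL-divergence term; the function below illustrates that objective,
# assuming levels are encoded as 3x15x40 one-hot tensors. The names `vae_loss`,
# `target_ids`, and `beta` are illustrative and not taken from the original training code.
def vae_loss(recon_probs, target_ids, mu, logvar, beta=1.0):
    # recon_probs: (B, 3, 15, 40) class probabilities returned by decode()
    # target_ids:  (B, 15, 40) integer tile ids (0=air, 1=ground, 2=lava)
    recon = F.nll_loss(torch.log(recon_probs + 1e-8), target_ids, reduction="sum")
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + beta * kld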
# Load trained model
model = ConvVAE()
model.load_state_dict(torch.load("vae_supertux.pth", map_location=torch.device("cpu")))
model.eval()
# Sampling
def sample_with_temperature(probs, temperature=1.2):
    # Convert probabilities back to logits, rescale by temperature, renormalize,
    # then draw one tile id per pixel from the resulting categorical distribution
    logits = torch.log(probs + 1e-8) / temperature
    scaled_probs = torch.softmax(logits, dim=1)
    batch, channels, height, width = scaled_probs.shape
    scaled_probs = scaled_probs.permute(0, 2, 3, 1).contiguous().view(-1, channels)
    sampled = torch.multinomial(scaled_probs, num_samples=1)
    sampled = sampled.view(batch, height, width)
    return sampled
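# Quick sanity check of the temperature effect (illustrative numbers, not taken from
# the trained model): a per-pixel distribution [0.7, 0.2, 0.1] rescaled with
# temperature 3 flattens to roughly [0.46, 0.30, 0.24], so rarer tiles such as lava
# get sampled more often.
# >>> p = torch.tensor([[[[0.7]], [[0.2]], [[0.1]]]])          # shape (1, 3, 1, 1)
# >>> torch.softmax(torch.log(p + 1e-8) / 3.0, dim=1).flatten()
# tensor([0.4584, 0.3019, 0.2396])  (approximately)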
def generate_map(seed: int = 0):
    model.eval()
    seed = int(seed)  # gr.Number passes a float; torch expects an integer seed
    if seed == 0:
        seed = torch.randint(10000, (1,)).item()
    torch.manual_seed(seed)
    z = torch.randn(1, model.latent_dim).to("cpu")
    with torch.no_grad():
        output = model.decode(z)
    output = sample_with_temperature(output, temperature=3)[0].cpu().numpy()
    # Pad five rows of air on top of the 15x40 sample to get a 20x40 grid
    grid = np.pad(output, ((5, 0), (0, 0)), mode='constant', constant_values=0)
    # Post-processing rule: drop each column's non-air tiles to the bottom,
    # collapsing any air gaps inside the column
    for j in range(len(grid[0])):
        non_air_blocks = [grid[i, j] for i in range(len(grid)) if grid[i, j] != 0]
        k = len(non_air_blocks)
        if k > 0:
            grid[20 - k:20, j] = non_air_blocks
            grid[0:20 - k, j] = 0
    return ["".join(map(str, row)) for row in grid]  # Convert each row to a string
gr.Interface(
    fn=generate_map,
    inputs=gr.Number(label="Seed (set to 0 for random generation)", value=0, precision=0),
    outputs=gr.JSON(label="Generated Map Grid"),
    title="VAE Level Generator",
    description="Returns a 20x40 grid as a list of strings where 0=air, 1=ground, 2=lava",
).launch()
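Once the Space is running, the endpoint can also be queried programmatically. The sketch below uses gradio_client against the default /predict endpoint that gr.Interface exposes; the Space id "username/vae-level-generator" is a placeholder, not the actual repository name.

from gradio_client import Client

client = Client("username/vae-level-generator")  # placeholder Space id
rows = client.predict(42, api_name="/predict")   # seed 42; returns the 20 row strings
for row in rows:
    print(row)  # one 40-character string per grid row, e.g. "000...0111"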