Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import time | |
| from model import Generator, Discriminator | |
# Configuration
LATENT_DIM = 100  # length of the noise vector fed to the generator
IMG_SHAPE = (1, 28, 28)  # image shape passed to both models (single-channel 28x28)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_weights(model, path):
    """Load the state dict stored at ``path`` into ``model`` and put it in eval mode.

    ``weights_only=True`` limits unpickling to tensors and plain containers,
    which is all a state dict contains, instead of allowing arbitrary pickled
    objects to be reconstructed.

    Returns:
        The same ``model`` instance, for convenient assignment.
    """
    state_dict = torch.load(path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Load the trained models
generator = _load_weights(
    Generator(latent_dim=LATENT_DIM, img_shape=IMG_SHAPE).to(DEVICE),
    'generator.pth',
)
discriminator = _load_weights(
    Discriminator(img_shape=IMG_SHAPE).to(DEVICE),
    'discriminator.pth',
)
def generate_digits(num_images, seed, show_confidence):
    """Generate digits with the GAN and return them tiled into one image.

    Args:
        num_images: How many digits to generate. Converted with ``int()``
            (Gradio sliders may deliver floats); values below 1 become 1.
        seed: Seed for the latent noise. ``0`` or an empty field (``None``)
            means "random"; any other number gives reproducible output.
            Seeding uses a private ``torch.Generator``, so the process-wide
            torch RNG used by other requests is left untouched.
        show_confidence: If true, append the mean discriminator score
            (as a percentage) to the info text.

    Returns:
        ``(grid_img, info_text)``: a PIL ``'L'`` image with each digit
        upscaled to 280x280 on a near-square grid with a white background,
        and a human-readable summary string.
    """
    start_time = time.perf_counter()
    num_images = max(1, int(num_images))

    # A dedicated CPU generator yields the same numbers torch.manual_seed(seed)
    # would, without reseeding global state shared across requests.
    rng = None
    if seed:  # 0 and None both mean "random"
        rng = torch.Generator().manual_seed(int(seed))

    with torch.no_grad():
        z = torch.randn(num_images, LATENT_DIM, generator=rng).to(DEVICE)
        generated_imgs = generator(z)
        # Mean discriminator output over the batch, scaled to a percentage.
        # NOTE(review): assumes the discriminator ends in a sigmoid (scores in [0, 1]).
        avg_confidence = discriminator(generated_imgs).mean().item() * 100
        imgs = generated_imgs.cpu().numpy()

    # Map [-1, 1] to [0, 255]; clip first so out-of-range values saturate
    # instead of wrapping around in the uint8 cast.
    imgs = ((np.clip(imgs, -1.0, 1.0) + 1) / 2 * 255).astype(np.uint8)
    grid_img = _tile_digits(imgs, tile=280)

    generation_time = time.perf_counter() - start_time
    info_text = f"Generated {num_images} digit(s) in {generation_time:.3f}s"
    if show_confidence:
        info_text += f"\nDiscriminator Confidence: {avg_confidence:.1f}% (how 'real' the digits appear)"
    return grid_img, info_text


def _tile_digits(imgs, tile):
    """Paste each ``imgs[i][0]``, upscaled to ``tile`` px, row-major onto a square white grid."""
    grid_size = int(np.ceil(np.sqrt(len(imgs))))
    grid_img = Image.new('L', (tile * grid_size, tile * grid_size), color=255)
    for idx, arr in enumerate(imgs):
        # A 2-D uint8 array is inferred as mode 'L' (the mode= argument is deprecated).
        digit = Image.fromarray(arr[0]).resize((tile, tile), Image.NEAREST)
        row, col = divmod(idx, grid_size)
        grid_img.paste(digit, (col * tile, row * tile))
    return grid_img
def create_comparison_grid(seeds=(42, 123, 456, 789), tile=140):
    """Render one generated digit per seed, tiled into a square grid.

    Each sample is drawn from its own seeded ``torch.Generator``. This keeps
    the samples reproducible without reseeding the global torch RNG; reseeding
    it here (this runs while the UI is built) would make later "random"
    generations repeat after every restart.

    Args:
        seeds: One seed per sample; the default four give a 2x2 grid.
        tile: Edge length in pixels that each generated digit is upscaled to.

    Returns:
        A PIL ``'L'`` image with ``ceil(sqrt(len(seeds)))`` columns, white
        background, samples placed row-major (280x280 with the defaults).
    """
    grid_size = max(1, int(np.ceil(np.sqrt(len(seeds)))))
    grid = Image.new('L', (tile * grid_size, tile * grid_size), color=255)
    with torch.no_grad():
        for idx, seed in enumerate(seeds):
            rng = torch.Generator().manual_seed(int(seed))
            z = torch.randn(1, LATENT_DIM, generator=rng).to(DEVICE)
            arr = generator(z)[0, 0].cpu().numpy()
            # [-1, 1] -> [0, 255]; clip so out-of-range values saturate rather than wrap.
            arr = ((np.clip(arr, -1.0, 1.0) + 1) / 2 * 255).astype(np.uint8)
            img = Image.fromarray(arr).resize((tile, tile), Image.NEAREST)
            row, col = divmod(idx, grid_size)
            grid.paste(img, (col * tile, row * tile))
    return grid
# Create interface. `demo` is defined at module level and is only launched
# when this file is executed as a script (see the __main__ guard below).
with gr.Blocks(title="MNIST GAN Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Handwritten Digit Generator
        ### Using Generative Adversarial Networks (GAN)

        ### About This Project
        This GAN was trained to generate handwritten digits by learning from the MNIST dataset.
        The generator creates images from random noise, while the discriminator learns to distinguish real from fake images.
        Through adversarial training, the generator improves until it produces realistic digits.

        **Created by Rohan Jain** | [LinkedIn](https://www.linkedin.com/in/jaroh23/) | [GitHub](https://github.com/rohanjain2312)
        """
    )

    with gr.Tabs():
        # Tab 1: Generate Digits
        with gr.TabItem("Generate Digits"):
            gr.Markdown("Generate new handwritten digits using the trained GAN model.")
            with gr.Row():
                with gr.Column(scale=1):
                    num_images = gr.Slider(
                        minimum=1,
                        maximum=16,
                        value=4,
                        step=1,
                        label="Number of Digits",
                        info="Generate 1-16 digits at once",
                    )
                    seed = gr.Number(
                        value=0,
                        label="Random Seed",
                        info="Use 0 for random, or set a number for reproducible results",
                    )
                    show_confidence = gr.Checkbox(
                        value=True,
                        label="Show Discriminator Confidence",
                        info="Display how 'real' the generator fooled the discriminator",
                    )
                    generate_btn = gr.Button("Generate", variant="primary", size="lg")
                with gr.Column(scale=2):
                    output_image = gr.Image(label="Generated Digits", type="pil")
                    output_info = gr.Textbox(label="Generation Info", lines=2)

            generate_btn.click(
                fn=generate_digits,
                inputs=[num_images, seed, show_confidence],
                outputs=[output_image, output_info],
            )

            gr.Markdown("### Quick Examples")
            gr.Examples(
                examples=[
                    [4, 42, True],
                    [9, 123, True],
                    [16, 456, False],
                ],
                inputs=[num_images, seed, show_confidence],
                outputs=[output_image, output_info],
                fn=generate_digits,
                cache_examples=True,
            )

        # Tab 2: Model Information
        with gr.TabItem("Model Details"):
            gr.Markdown("### Training Loss Curves")
            gr.Image(value="loss_curve.png", label="Generator and Discriminator Loss During Training")
            gr.Markdown(
                """
                **Loss Analysis:**
                - The discriminator loss (orange) stabilizes around 0.3-0.4, indicating it effectively distinguishes real from fake
                - The generator loss (blue) shows typical adversarial dynamics, settling around 1.5-2.0
                - The fluctuating losses indicate healthy adversarial balance between the two networks
                """
            )

            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """
                        ### Training Details

                        **Architecture**: Fully Connected GAN

                        **Dataset**: MNIST (60,000 images)

                        **Epochs**: 200

                        **Batch Size**: 128

                        **Generator**:
                        - Input: 100-dim random vector
                        - Layers: 128 → 256 → 512 → 784
                        - Output: 28×28 grayscale image

                        **Discriminator**:
                        - Input: 28×28 image (784 dims)
                        - Layers: 512 → 256 → 1
                        - Output: Real/Fake probability
                        """
                    )
                with gr.Column():
                    gr.Markdown(
                        """
                        ### Training Results

                        **Loss Metrics** (Final):
                        - Discriminator Loss: ~0.32
                        - Generator Loss: ~1.99

                        **Training Evolution**:
                        - Epoch 1: Random noise
                        - Epoch 50: Faint structures
                        - Epoch 100: Recognizable digits
                        - Epoch 200: High-quality digits

                        The adversarial training successfully balanced both networks,
                        resulting in realistic digit generation.
                        """
                    )

            gr.Markdown("### Sample Outputs (Different Seeds)")
            # Rendered once while the UI is being built, not per request.
            comparison_img = create_comparison_grid()
            gr.Image(value=comparison_img, label="Generated Samples", type="pil")

    gr.Markdown(
        """
        ---
        **Tech Stack**: PyTorch, Gradio, NumPy | **Training**: Google Colab (GPU) | **Deployment**: Hugging Face Spaces
        """
    )

if __name__ == "__main__":
    demo.launch()