rohanjain2312 committed
Commit 398ce88 · 1 Parent(s): 4bf5f6c

Add MNIST GAN generator model with LFS tracking

Files changed (8)
  1. .DS_Store +0 -0
  2. README.md +123 -5
  3. app.py +228 -0
  4. discriminator.pth +3 -0
  5. generator.pth +3 -0
  6. loss_curve.png +0 -0
  7. model.py +43 -0
  8. requirements.txt +5 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
README.md CHANGED
@@ -1,13 +1,131 @@
  ---
- title: Mnist Gan Generator
- emoji: 🏒
- colorFrom: gray
+ title: MNIST GAN - Handwritten Digit Generation
+ colorFrom: purple
  colorTo: blue
  sdk: gradio
- sdk_version: 5.49.1
+ sdk_version: 4.0.0
  app_file: app.py
  pinned: false
  license: mit
  ---
  
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # MNIST GAN - Handwritten Digit Generation
+
+ A Generative Adversarial Network trained on the MNIST dataset to generate realistic handwritten digit images.
+
+ ---
+
+ ## Author
+
+ **Rohan Jain**
+ - [LinkedIn](https://www.linkedin.com/in/jaroh23/)
+ - [GitHub](https://github.com/rohanjain2312)
+
+ ---
+
+ ## Model Summary
+
+ | Property | Details |
+ |-----------|----------|
+ | Architecture | Fully Connected GAN |
+ | Framework | PyTorch 2.0+ |
+ | Dataset | MNIST (60,000 training images) |
+ | Training Epochs | 200 |
+ | Batch Size | 128 |
+ | Image Resolution | 28×28 pixels |
+
+ ---
+
+ ## Technical Architecture
+
+ **Generator:**
+ - Input: 100-dimensional latent vector
+ - Layers: Linear(100→128) → LeakyReLU → Linear(128→256) → LeakyReLU → Linear(256→512) → LeakyReLU → Linear(512→784) → Tanh
+ - Output: 28×28 grayscale image
+
+ **Discriminator:**
+ - Input: 784-dimensional flattened image
+ - Layers: Linear(784→512) → LeakyReLU → Linear(512→256) → LeakyReLU → Linear(256→1) → Sigmoid
+ - Output: Real/Fake probability
+
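The two networks above are the `Generator` and `Discriminator` classes defined in `model.py` (included in this commit). As a quick sanity check of the shapes described, here is a minimal sketch using those classes; the batch size of 8 is an arbitrary illustrative choice:

```python
import torch
from model import Generator, Discriminator

g = Generator(latent_dim=100, img_shape=(1, 28, 28))
d = Discriminator(img_shape=(1, 28, 28))

z = torch.randn(8, 100)   # 8 latent vectors of dimension 100
imgs = g(z)               # shape (8, 1, 28, 28), values in [-1, 1] from Tanh
scores = d(imgs)          # shape (8, 1), probabilities in (0, 1) from Sigmoid
print(imgs.shape, scores.shape)
```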
+ **Training Configuration:**
+ - Loss Function: Binary Cross-Entropy
+ - Optimizer: Adam (lr=0.0002, betas=(0.5, 0.999))
+ - Normalization: [-1, 1] range
+
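The training script itself is not part of this commit; the following is only a minimal sketch of one adversarial update consistent with the configuration listed above (BCE loss, Adam with lr=0.0002 and betas=(0.5, 0.999), inputs scaled to [-1, 1]). The `train_step` helper and batch handling are illustrative assumptions, not the author's actual code:

```python
import torch
import torch.nn as nn
from model import Generator, Discriminator

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator(latent_dim=100, img_shape=(1, 28, 28)).to(device)
discriminator = Discriminator(img_shape=(1, 28, 28)).to(device)

criterion = nn.BCELoss()  # Binary Cross-Entropy, matching the configuration above
opt_g = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

def train_step(real_imgs):
    """One illustrative adversarial update on a batch of MNIST images scaled to [-1, 1]."""
    batch_size = real_imgs.size(0)
    real_labels = torch.ones(batch_size, 1, device=device)
    fake_labels = torch.zeros(batch_size, 1, device=device)

    # Discriminator step: push real images toward 1, generated images toward 0
    z = torch.randn(batch_size, 100, device=device)
    fake_imgs = generator(z)
    d_loss = (criterion(discriminator(real_imgs), real_labels) +
              criterion(discriminator(fake_imgs.detach()), fake_labels))
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator step: try to make the discriminator output 1 on generated images
    g_loss = criterion(discriminator(fake_imgs), real_labels)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()

    return d_loss.item(), g_loss.item()
```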
+ ---
+
+ ## Training Results
+
+ | Epoch | Image Quality |
+ |-------|---------------|
+ | 1 | Random noise |
+ | 50 | Faint digit structures |
+ | 100 | Recognizable digits |
+ | 150 | Clear, defined digits |
+ | 200 | High-quality handwritten digits |
+
+ **Loss Metrics:**
+ - Discriminator Loss (Stabilized): 0.31-0.49
+ - Generator Loss (Stabilized): 1.33-1.99
+
+ ---
+
+ ## Usage
+
+ **Local Installation:**
+ ```bash
+ git clone https://huggingface.co/spaces/rohanjain2312/MNIST-GAN
+ cd MNIST-GAN
+ pip install -r requirements.txt
+ python app.py
+ ```
+
+ **Programmatic Generation:**
+ ```python
+ import torch
+ from model import Generator
+
+ generator = Generator(latent_dim=100, img_shape=(1, 28, 28))
+ generator.load_state_dict(torch.load('generator.pth', map_location='cpu'))
+ generator.eval()
+
+ with torch.no_grad():
+     z = torch.randn(1, 100)
+     img = generator(z)
+ ```
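Because the generator ends in Tanh, `img` holds values in [-1, 1]; to view or save the result it has to be mapped back to 8-bit pixels, mirroring the denormalization used in `app.py`. A small follow-up sketch to the snippet above (the `digit.png` filename is just an example):

```python
import numpy as np
from PIL import Image

# img comes from the snippet above: shape (1, 1, 28, 28), values in [-1, 1]
arr = ((img[0, 0].numpy() + 1) / 2 * 255).astype(np.uint8)
Image.fromarray(arr, mode='L').save('digit.png')
```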
+
+ ---
+
+ ## Project Files
+
+ | File | Description |
+ |------|-------------|
+ | `app.py` | Gradio interface |
+ | `model.py` | Generator and Discriminator architectures |
+ | `generator.pth` | Trained generator weights |
+ | `discriminator.pth` | Trained discriminator weights |
+ | `requirements.txt` | Python dependencies |
+
+ ---
+
+ ## Skills Demonstrated
+
+ - Generative Adversarial Networks (GANs)
+ - PyTorch Implementation
+ - Adversarial Training Dynamics
+ - Model Deployment (Hugging Face Spaces)
+ - Gradio Interface Development
+
+ ---
+
+ ## Acknowledgments
+
+ - Dataset: [MNIST Dataset - Yann LeCun](http://yann.lecun.com/exdb/mnist/)
+ - Paper: [Generative Adversarial Nets - Goodfellow et al., 2014](https://arxiv.org/abs/1406.2661)
+
+ ---
+
+ ## License
+
+ MIT License
app.py ADDED
@@ -0,0 +1,228 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ from PIL import Image
+ import time
+ from model import Generator, Discriminator
+
+ # Configuration
+ LATENT_DIM = 100
+ IMG_SHAPE = (1, 28, 28)
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # Load the trained models
+ generator = Generator(latent_dim=LATENT_DIM, img_shape=IMG_SHAPE).to(DEVICE)
+ generator.load_state_dict(torch.load('generator.pth', map_location=DEVICE))
+ generator.eval()
+
+ discriminator = Discriminator(img_shape=IMG_SHAPE).to(DEVICE)
+ discriminator.load_state_dict(torch.load('discriminator.pth', map_location=DEVICE))
+ discriminator.eval()
+
+ def generate_digits(num_images, seed, show_confidence):
+     """Generate digits and return with optional confidence scores"""
+     start_time = time.time()
+
+     with torch.no_grad():
+         # Set seed for reproducibility if provided (cast to int: gr.Number passes a float)
+         if seed != 0:
+             torch.manual_seed(int(seed))
+             np.random.seed(int(seed))
+
+         # Generate random noise
+         num_images = int(num_images)
+         z = torch.randn(num_images, LATENT_DIM).to(DEVICE)
+
+         # Generate images
+         generated_imgs = generator(z)
+
+         # Get discriminator confidence
+         confidence_scores = discriminator(generated_imgs)
+         avg_confidence = confidence_scores.mean().item() * 100
+
+         # Convert to numpy and denormalize
+         generated_imgs = generated_imgs.cpu().numpy()
+         generated_imgs = ((generated_imgs + 1) / 2 * 255).astype(np.uint8)
+
+         # Create grid
+         grid_size = int(np.ceil(np.sqrt(num_images)))
+         grid_img = Image.new('L', (280 * grid_size, 280 * grid_size), color=255)
+
+         for idx in range(num_images):
+             img_pil = Image.fromarray(generated_imgs[idx][0], mode='L')
+             img_pil = img_pil.resize((280, 280), Image.NEAREST)
+             row = idx // grid_size
+             col = idx % grid_size
+             grid_img.paste(img_pil, (col * 280, row * 280))
+
+     generation_time = time.time() - start_time
+
+     # Build info text
+     info_text = f"Generated {num_images} digit(s) in {generation_time:.3f}s"
+     if show_confidence:
+         info_text += f"\nDiscriminator Confidence: {avg_confidence:.1f}% (how 'real' the digits appear)"
+
+     return grid_img, info_text
+
+ def create_comparison_grid():
+     """Create a comparison showing training progress"""
+     # Create sample images at different seeds
+     seeds = [42, 123, 456, 789]
+     images = []
+
+     with torch.no_grad():
+         for seed in seeds:
+             torch.manual_seed(seed)
+             z = torch.randn(1, LATENT_DIM).to(DEVICE)
+             img = generator(z)
+             img = ((img[0, 0].cpu().numpy() + 1) / 2 * 255).astype(np.uint8)
+             img_pil = Image.fromarray(img, mode='L')
+             img_pil = img_pil.resize((140, 140), Image.NEAREST)
+             images.append(img_pil)
+
+     # Create grid
+     grid = Image.new('L', (280, 280), color=255)
+     for idx, img in enumerate(images):
+         row = idx // 2
+         col = idx % 2
+         grid.paste(img, (col * 140, row * 140))
+
+     return grid
+
+ # Create interface
+ with gr.Blocks(title="MNIST GAN Generator", theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+         # MNIST Handwritten Digit Generator
+         ### Using Generative Adversarial Networks (GAN)
+
+         **Created by Rohan Jain** | [LinkedIn](https://www.linkedin.com/in/jaroh23/) | [GitHub](https://github.com/rohanjain2312)
+         """
+     )
+
+     with gr.Tabs():
+         # Tab 1: Generate Digits
+         with gr.TabItem("Generate Digits"):
+             gr.Markdown("Generate new handwritten digits using the trained GAN model.")
+
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     num_images = gr.Slider(
+                         minimum=1,
+                         maximum=16,
+                         value=4,
+                         step=1,
+                         label="Number of Digits",
+                         info="Generate 1-16 digits at once"
+                     )
+                     seed = gr.Number(
+                         value=0,
+                         label="Random Seed",
+                         info="Use 0 for random, or set a number for reproducible results"
+                     )
+                     show_confidence = gr.Checkbox(
+                         value=True,
+                         label="Show Discriminator Confidence",
+                         info="Display how 'real' the generated digits look to the discriminator"
+                     )
+                     generate_btn = gr.Button("Generate", variant="primary", size="lg")
+
+                 with gr.Column(scale=2):
+                     output_image = gr.Image(label="Generated Digits", type="pil")
+                     output_info = gr.Textbox(label="Generation Info", lines=2)
+
+             generate_btn.click(
+                 fn=generate_digits,
+                 inputs=[num_images, seed, show_confidence],
+                 outputs=[output_image, output_info]
+             )
+
+             gr.Markdown("### Quick Examples")
+             gr.Examples(
+                 examples=[
+                     [4, 42, True],
+                     [9, 123, True],
+                     [16, 456, False],
+                 ],
+                 inputs=[num_images, seed, show_confidence],
+                 outputs=[output_image, output_info],
+                 fn=generate_digits,
+                 cache_examples=True,
+             )
+
+         # Tab 2: Model Information
+         with gr.TabItem("Model Details"):
+             gr.Markdown("### Training Loss Curves")
+             gr.Image(value="loss_curve.png", label="Generator and Discriminator Loss During Training")
+
+             gr.Markdown(
+                 """
+                 **Loss Analysis:**
+                 - The discriminator loss (orange) stabilizes around 0.3-0.4, indicating it effectively distinguishes real from fake
+                 - The generator loss (blue) shows typical adversarial dynamics, settling around 1.5-2.0
+                 - The fluctuating losses indicate a healthy adversarial balance between the two networks
+                 """
+             )
+
+             with gr.Row():
+                 with gr.Column():
+                     gr.Markdown(
+                         """
+                         ### Training Details
+
+                         **Architecture**: Fully Connected GAN
+                         **Dataset**: MNIST (60,000 images)
+                         **Epochs**: 200
+                         **Batch Size**: 128
+
+                         **Generator**:
+                         - Input: 100-dim random vector
+                         - Layers: 128 → 256 → 512 → 784
+                         - Output: 28×28 grayscale image
+
+                         **Discriminator**:
+                         - Input: 28×28 image (784 dims)
+                         - Layers: 512 → 256 → 1
+                         - Output: Real/Fake probability
+                         """
+                     )
+
+                 with gr.Column():
+                     gr.Markdown(
+                         """
+                         ### Training Results
+
+                         **Loss Metrics** (Final):
+                         - Discriminator Loss: ~0.32
+                         - Generator Loss: ~1.99
+
+                         **Training Evolution**:
+                         - Epoch 1: Random noise
+                         - Epoch 50: Faint structures
+                         - Epoch 100: Recognizable digits
+                         - Epoch 200: High-quality digits
+
+                         The adversarial training successfully balanced both networks,
+                         resulting in realistic digit generation.
+                         """
+                     )
+
+             gr.Markdown("### Sample Outputs (Different Seeds)")
+             comparison_img = create_comparison_grid()
+             gr.Image(value=comparison_img, label="Generated Samples", type="pil")
+
+     gr.Markdown(
+         """
+         ---
+         ### About This Project
+
+         This GAN was trained to generate handwritten digits by learning from the MNIST dataset.
+         The generator creates images from random noise, while the discriminator learns to distinguish real from fake images.
+         Through adversarial training, the generator improves until it produces realistic digits.
+
+         **Tech Stack**: PyTorch, Gradio, NumPy | **Training**: Google Colab (GPU) | **Deployment**: Hugging Face Spaces
+         """
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
discriminator.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e7cd72eed1969132cc1e4ab317d1f7c2217988873bc96350e454ce42ba417d44
+ size 2137405
generator.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:71a3b34bfd986087d5606fb993800425662f9941d8bcf730f18608780d0cd473
+ size 2322841
loss_curve.png ADDED
model.py ADDED
@@ -0,0 +1,43 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+ class Generator(nn.Module):
+     def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
+         super().__init__()
+         self.img_shape = img_shape
+
+         self.model = nn.Sequential(
+             nn.Linear(latent_dim, 128),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Linear(128, 256),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Linear(256, 512),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Linear(512, int(np.prod(img_shape))),
+             nn.Tanh()
+         )
+
+     def forward(self, z):
+         img = self.model(z)
+         img = img.view(img.size(0), *self.img_shape)
+         return img
+
+
+ class Discriminator(nn.Module):
+     def __init__(self, img_shape=(1, 28, 28)):
+         super().__init__()
+
+         self.model = nn.Sequential(
+             nn.Linear(int(np.prod(img_shape)), 512),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Linear(512, 256),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Linear(256, 1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, img):
+         img_flat = img.view(img.size(0), -1)
+         validity = self.model(img_flat)
+         return validity
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch>=2.0.0
+ torchvision>=0.15.0
+ gradio>=4.0.0
+ numpy>=1.24.0
+ Pillow>=9.5.0