rubenaghayan committed on
Commit f45427d · 1 parent: 97e312a

fsdp + bugfixes

Files changed (8)
  1. .gradio/certificate.pem +31 -0
  2. app.py +195 -43
  3. calculator.py +119 -62
  4. defaults.py +107 -10
  5. details.py +21 -0
  6. dtypes.py +1 -1
  7. limitations.py +27 -0
  8. state.py +3 -0
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
app.py CHANGED
@@ -2,16 +2,19 @@ import gradio as gr
2
  import pandas as pd
3
  from functools import partial
4
  from defaults import DEFAULTS
 
5
  from state import Model, Parallelism, Training
6
  from calculator import MemoryCalculation
7
  from dtypes import DType
 
 
8
 
9
  # Create a Number component for natural numbers (positive integers)
10
  NaturalNumber = partial(gr.Number, minimum=1, step=1, precision=0, interactive=True)
11
 
 
12
 
13
- def greet(name, intensity) -> str:
14
- return "Hello, " + name + "!" * int(intensity)
15
 
16
 
17
  def create_parallelism_block():
@@ -22,30 +25,103 @@ def create_parallelism_block():
22
  pp = NaturalNumber(label="Pipeline Parallelism", value=1)
23
  cp = NaturalNumber(label="Context Parallelism", value=1)
24
  ep = NaturalNumber(label="Expert Parallelism", value=1)
25
- return tp, pp, cp, ep
26
 
27
 
28
  def create_model_block():
29
  with gr.Column():
30
  gr.Markdown("# Model Architecture")
31
- layers = NaturalNumber(label="Number of Layers", value=32)
32
- vocab = NaturalNumber(label="Vocab Size", value=32000)
33
- hidden = NaturalNumber(label="Hidden Dim", value=4096)
34
- intermediate = NaturalNumber(label="Intermediate Dim", value=11008)
35
  is_moe = gr.Checkbox(label="Mixture of Experts (MoE)", value=False)
36
- active_experts = NaturalNumber(label="Active Experts", value=2, visible=False)
37
- total_experts = NaturalNumber(label="Total Experts", value=8, visible=False)
 
38
 
39
- # Toggle expert fields visibility based on MoE checkbox
40
  is_moe.change(
41
- fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
42
  inputs=is_moe,
43
  outputs=[active_experts, total_experts]
44
  )
45
 
46
- # not ready yet
47
- # presets = gr.Dropdown(list(DEFAULTS.keys()), label="Presets", interactive=True)
48
- return layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets
49
 
50
 
51
  def create_training_block():
@@ -58,12 +134,15 @@ def create_training_block():
58
  grad_accumulation = gr.Checkbox(label="Gradient Accumulation", value=False)
59
  precision = gr.Dropdown(DType.values(), label="Precision", value=DType.FP32.value, interactive=True)
60
  mixed_precision = gr.Checkbox(label="Mixed Precision", value=False)
61
- param_dtype = gr.Dropdown(DType.values(), label="Parameter Dtype", value=DType.FP32.value, interactive=True, visible=False)
62
- reduce_dtype = gr.Dropdown(DType.values(), label="Reduce Dtype", value=DType.FP32.value, interactive=True, visible=False)
63
 
64
- # Toggle dtype fields visibility based on mixed precision checkbox
65
  mixed_precision.change(
66
- fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
67
  inputs=mixed_precision,
68
  outputs=[param_dtype, reduce_dtype]
69
  )
@@ -71,14 +150,14 @@ def create_training_block():
71
  return seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype
72
 
73
 
74
- def calculate(tp, pp, cp, ep, layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype):
75
  # Create state objects
76
  model_config = Model(
77
  vocab_size=int(vocab),
78
  num_layers=int(layers),
79
  hidden_dim=int(hidden),
80
  intermediate_size=int(intermediate),
81
- weight_tied_embeddings=True, # Default assumption
82
  active_experts=int(active_experts),
83
  total_experts=int(total_experts),
84
  is_moe=is_moe
@@ -88,7 +167,10 @@ def calculate(tp, pp, cp, ep, layers, vocab, hidden, intermediate, active_expert
88
  tensor_parallelism=int(tp),
89
  pipeline_parallelism=int(pp),
90
  context_parallelism=int(cp),
91
- expert_parallelism=int(ep)
92
  )
93
 
94
  training_config = Training(
@@ -111,48 +193,118 @@ def calculate(tp, pp, cp, ep, layers, vocab, hidden, intermediate, active_expert
111
  gradient_memory = calc.calculate_gradient_memory()
112
  optimizer_memory = calc.calculate_optimizer_memory()
113
 
114
- # Create DataFrame for bar plot
115
- memory_data = pd.DataFrame({
116
- 'Component': [
117
- 'Parameter Memory',
118
- 'Activation Memory',
119
- 'Gradient Memory',
120
- 'Optimizer Memory'
121
- ],
122
- 'Memory (GB)': [
123
- param_memory / 1e9,
124
- activation_memory / 1e9,
125
- gradient_memory / 1e9,
126
- optimizer_memory / 1e9
127
- ]
128
- })
129
 
130
  return gr.BarPlot(
131
  value=memory_data,
132
  x="Component",
133
  y="Memory (GB)",
 
 
134
  title="LLM Memory Usage Breakdown",
135
  container=False,
136
- y_lim=[0, None]
137
  )
138
 
139
 
140
- with gr.Blocks(theme='gstaff/xkcd') as demo:
141
- with gr.Sidebar():
142
- gr.Textbox("## LLM Memory Visualizer")
 
143
  with gr.Column():
144
  with gr.Row(equal_height=True):
145
- tp, pp, cp, ep = create_parallelism_block()
146
- layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets = create_model_block()
147
  seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype = create_training_block()
148
  calculate_button = gr.Button("Calculate")
149
  output = gr.BarPlot(label="Memory Usage Breakdown")
150
 
151
  calculate_button.click(
152
  fn=calculate,
153
- inputs=[tp, pp, cp, ep, layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype],
154
  outputs=output
155
  )
156
 
157
 
158
  demo.launch()
 
2
  import pandas as pd
3
  from functools import partial
4
  from defaults import DEFAULTS
5
+ from details import DETAILS
6
  from state import Model, Parallelism, Training
7
  from calculator import MemoryCalculation
8
  from dtypes import DType
9
+ from gradio.themes import ThemeClass as Theme
10
+ from limitations import LIMITATIONS
11
 
12
  # Create a Number component for natural numbers (positive integers)
13
  NaturalNumber = partial(gr.Number, minimum=1, step=1, precision=0, interactive=True)
14
 
15
+ colors = {
16
 
17
+ }
 
18
 
19
 
20
  def create_parallelism_block():
 
25
  pp = NaturalNumber(label="Pipeline Parallelism", value=1)
26
  cp = NaturalNumber(label="Context Parallelism", value=1)
27
  ep = NaturalNumber(label="Expert Parallelism", value=1)
28
+
29
+ fsdp_enabled = gr.Checkbox(label="FSDP (Fully Sharded Data Parallel)", value=False)
30
+ fsdp_parallelism = NaturalNumber(label="FSDP Parallelism", value=1, interactive=False, elem_classes="disabled-field")
31
+ fsdp_strategy = gr.Radio(
32
+ choices=["Zero-1", "Zero-2", "Zero-3"],
33
+ label="FSDP Strategy",
34
+ value="Zero-1",
35
+ interactive=False,
36
+ elem_classes="disabled-field"
37
+ )
38
+
39
+ # Toggle FSDP fields interactivity based on FSDP checkbox
40
+ fsdp_enabled.change(
41
+ fn=lambda x: [
42
+ gr.update(interactive=x, elem_classes=[] if x else ["disabled-field"]),
43
+ gr.update(interactive=x, elem_classes=[] if x else ["disabled-field"])
44
+ ],
45
+ inputs=fsdp_enabled,
46
+ outputs=[fsdp_parallelism, fsdp_strategy]
47
+ )
48
+
49
+ return tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy
50
 
51
 
52
  def create_model_block():
53
  with gr.Column():
54
  gr.Markdown("# Model Architecture")
55
+ layers = NaturalNumber(label="Number of Layers", value=48)
56
+ vocab = NaturalNumber(label="Vocab Size", value=262144)
57
+ hidden = NaturalNumber(label="Hidden Dim", value=3840)
58
+ intermediate = NaturalNumber(label="Intermediate Dim", value=15360)
59
  is_moe = gr.Checkbox(label="Mixture of Experts (MoE)", value=False)
60
+ active_experts = NaturalNumber(label="Active Experts", value=2, interactive=False, elem_classes="disabled-field")
61
+ total_experts = NaturalNumber(label="Total Experts", value=8, interactive=False, elem_classes="disabled-field")
62
+ weight_tied_embeddings = gr.Checkbox(label="Weight Tied Embeddings", value=True)
63
 
64
+ # Toggle expert fields interactivity based on MoE checkbox
65
  is_moe.change(
66
+ fn=lambda x: [
67
+ gr.update(interactive=x, elem_classes=[] if x else ["disabled-field"]),
68
+ gr.update(interactive=x, elem_classes=[] if x else ["disabled-field"])
69
+ ],
70
  inputs=is_moe,
71
  outputs=[active_experts, total_experts]
72
  )
73
 
74
+ presets = gr.Dropdown(["Custom"] + list(DEFAULTS.keys()), label="Presets", value="Gemma3 12B", interactive=True)
75
+
76
+ # Populate model parameters when preset is selected
77
+ def populate_from_preset(preset_name):
78
+ if preset_name and preset_name in DEFAULTS:
79
+ model = DEFAULTS[preset_name]
80
+ return [
81
+ gr.update(value=model.num_layers),
82
+ gr.update(value=model.vocab_size),
83
+ gr.update(value=model.hidden_dim),
84
+ gr.update(value=model.intermediate_size),
85
+ gr.update(value=model.is_moe),
86
+ gr.update(value=model.active_experts, interactive=model.is_moe),
87
+ gr.update(value=model.total_experts, interactive=model.is_moe),
88
+ gr.update(value=model.weight_tied_embeddings)
89
+ ]
90
+ return [gr.update() for _ in range(8)]
91
+
92
+ # Switch to "Custom" when user manually edits values
93
+ def switch_to_custom(layers_val, vocab_val, hidden_val, intermediate_val, is_moe_val, active_experts_val, total_experts_val, weight_tied_val, current_preset):
94
+ # Don't switch to custom if a preset is being applied
95
+ if current_preset and current_preset in DEFAULTS:
96
+ model = DEFAULTS[current_preset]
97
+ # Check if current values match the preset exactly
98
+ if (layers_val == model.num_layers and
99
+ vocab_val == model.vocab_size and
100
+ hidden_val == model.hidden_dim and
101
+ intermediate_val == model.intermediate_size and
102
+ is_moe_val == model.is_moe and
103
+ active_experts_val == model.active_experts and
104
+ total_experts_val == model.total_experts and
105
+ weight_tied_val == model.weight_tied_embeddings):
106
+ return gr.update() # Keep current preset
107
+
108
+ return gr.update(value="Custom")
109
+
110
+ presets.change(
111
+ fn=populate_from_preset,
112
+ inputs=presets,
113
+ outputs=[layers, vocab, hidden, intermediate, is_moe, active_experts, total_experts, weight_tied_embeddings]
114
+ )
115
+
116
+ # Add change listeners to all model parameter inputs
117
+ for input_component in [layers, vocab, hidden, intermediate, is_moe, active_experts, total_experts, weight_tied_embeddings]:
118
+ input_component.change(
119
+ fn=switch_to_custom,
120
+ inputs=[layers, vocab, hidden, intermediate, is_moe, active_experts, total_experts, weight_tied_embeddings, presets],
121
+ outputs=presets
122
+ )
123
+
124
+ return layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets, weight_tied_embeddings
125
 
126
 
127
  def create_training_block():
 
134
  grad_accumulation = gr.Checkbox(label="Gradient Accumulation", value=False)
135
  precision = gr.Dropdown(DType.values(), label="Precision", value=DType.FP32.value, interactive=True)
136
  mixed_precision = gr.Checkbox(label="Mixed Precision", value=False)
137
+ param_dtype = gr.Dropdown(DType.values(), label="Parameter Dtype", value=DType.FP32.value, interactive=False, elem_classes="disabled-field")
138
+ reduce_dtype = gr.Dropdown(DType.values(), label="Reduce Dtype", value=DType.FP32.value, interactive=False, elem_classes="disabled-field")
139
 
140
+ # Toggle dtype fields interactivity based on mixed precision checkbox
141
  mixed_precision.change(
142
+ fn=lambda x: [
143
+ gr.update(interactive=x, elem_classes=[] if x else ["disabled-field"]),
144
+ gr.update(interactive=x, elem_classes=[] if x else ["disabled-field"])
145
+ ],
146
  inputs=mixed_precision,
147
  outputs=[param_dtype, reduce_dtype]
148
  )
 
150
  return seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype
151
 
152
 
153
+ def calculate(tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy, layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, weight_tied_embeddings, seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype):
154
  # Create state objects
155
  model_config = Model(
156
  vocab_size=int(vocab),
157
  num_layers=int(layers),
158
  hidden_dim=int(hidden),
159
  intermediate_size=int(intermediate),
160
+ weight_tied_embeddings=weight_tied_embeddings,
161
  active_experts=int(active_experts),
162
  total_experts=int(total_experts),
163
  is_moe=is_moe
 
167
  tensor_parallelism=int(tp),
168
  pipeline_parallelism=int(pp),
169
  context_parallelism=int(cp),
170
+ expert_parallelism=int(ep),
171
+ fsdp_enabled=fsdp_enabled,
172
+ fsdp_parallelism=int(fsdp_parallelism),
173
+ fsdp_strategy=fsdp_strategy
174
  )
175
 
176
  training_config = Training(
 
193
  gradient_memory = calc.calculate_gradient_memory()
194
  optimizer_memory = calc.calculate_optimizer_memory()
195
 
196
+ # Calculate total memory
197
+ total_memory = param_memory + activation_memory + gradient_memory + optimizer_memory
198
+
199
+ # Round to 1 decimal place for display
200
+ param_gb = round(param_memory / 1e9, 1)
201
+ activation_gb = round(activation_memory / 1e9, 1)
202
+ gradient_gb = round(gradient_memory / 1e9, 1)
203
+ optimizer_gb = round(optimizer_memory / 1e9, 1)
204
+ total_gb = round(total_memory / 1e9, 1)
205
+
206
+ # Create DataFrame for stacked bar plot
207
+ # Start with stacked total bar, then add individual bars
208
+ individual_data = []
209
+
210
+ # Stacked total bar first - create separate rows for each component within total
211
+ for mem_type, gb_val in [
212
+ ('Activation', activation_gb),
213
+ ('Optimizer', optimizer_gb),
214
+ ('Gradient', gradient_gb),
215
+ ('Parameter', param_gb)
216
+ ]:
217
+ individual_data.append({
218
+ 'Component': f'Total Memory\n{total_gb} GB',
219
+ 'Memory (GB)': gb_val,
220
+ 'Type': mem_type
221
+ })
222
+
223
+ # Individual component bars
224
+ for component, gb_val, mem_type in [
225
+ (f'Parameter Memory\n{param_gb} GB', param_gb, 'Parameter'),
226
+ (f'Gradient Memory\n{gradient_gb} GB', gradient_gb, 'Gradient'),
227
+ (f'Optimizer Memory\n{optimizer_gb} GB', optimizer_gb, 'Optimizer'),
228
+ (f'Activation Memory\n{activation_gb} GB', activation_gb, 'Activation')
229
+ ]:
230
+ individual_data.append({
231
+ 'Component': component,
232
+ 'Memory (GB)': gb_val,
233
+ 'Type': mem_type
234
+ })
235
+
236
+ memory_data = pd.DataFrame(individual_data)
237
+
238
+ # Define pastel color map
239
+ color_map = {
240
+ 'Parameter': '#B6E5D8', # Light Mint
241
+ 'Gradient': '#FFB6C1', # Light Pink
242
+ 'Optimizer': '#C7B3FF', # Light Purple
243
+ 'Activation': '#FFD1A9', # Light Peach
244
+ }
245
 
246
  return gr.BarPlot(
247
  value=memory_data,
248
  x="Component",
249
  y="Memory (GB)",
250
+ color="Type",
251
+ color_map=color_map,
252
  title="LLM Memory Usage Breakdown",
253
  container=False,
254
+ y_lim=[0, None],
255
+ sort=[
256
+ f'Total Memory\n{total_gb} GB',
257
+ f'Parameter Memory\n{param_gb} GB',
258
+ f'Gradient Memory\n{gradient_gb} GB',
259
+ f'Optimizer Memory\n{optimizer_gb} GB',
260
+ f'Activation Memory\n{activation_gb} GB'
261
+ ]
262
  )
263
 
264
+ css = """
265
+ /* Style for disabled components to make them visually obvious */
266
+ .disabled-field input,
267
+ .disabled-field select,
268
+ .disabled-field textarea {
269
+ opacity: 0.4 !important;
270
+ background-color: #f5f5f5 !important;
271
+ color: #999 !important;
272
+ cursor: not-allowed !important;
273
+ text-decoration: line-through;
274
+ }
275
+
276
+ .disabled-field label {
277
+ opacity: 0.5 !important;
278
+ color: #999 !important;
279
+ }
280
+ """
281
 
282
+ theme = Theme.from_hub("gstaff/xkcd")
283
+ # otherwise invisible in light mode
284
+ theme.checkbox_label_text_color = theme.block_label_text_color
285
+ with gr.Blocks(theme=theme, css=css) as demo:
286
  with gr.Column():
287
  with gr.Row(equal_height=True):
288
+ tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy = create_parallelism_block()
289
+ layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets, weight_tied_embeddings = create_model_block()
290
  seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype = create_training_block()
291
  calculate_button = gr.Button("Calculate")
292
  output = gr.BarPlot(label="Memory Usage Breakdown")
293
 
294
  calculate_button.click(
295
  fn=calculate,
296
+ inputs=[tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy, layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, weight_tied_embeddings, seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype],
297
  outputs=output
298
  )
299
 
300
+ # Limitations and Comments section
301
+ with gr.Row():
302
+ with gr.Column():
303
+ gr.Markdown("# Limitations")
304
+ gr.Markdown(LIMITATIONS)
305
+ with gr.Column():
306
+ gr.Markdown("# Comments and Details")
307
+ gr.Markdown(DETAILS)
308
+
309
 
310
  demo.launch()
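The recurring change in app.py above is the switch from hiding dependent fields (`gr.update(visible=x)`) to keeping them visible but toggling `interactive` together with a `disabled-field` CSS class. A minimal standalone sketch of just that pattern, using illustrative component names rather than the app's real fields:

```python
# Minimal sketch of the enable/disable pattern used in app.py above: the
# dependent field stays visible but is greyed out and non-interactive until
# the checkbox is ticked. Component names here are illustrative only.
import gradio as gr

CSS = """
.disabled-field input { opacity: 0.4; cursor: not-allowed; }
.disabled-field label { opacity: 0.5; }
"""

with gr.Blocks(css=CSS) as demo:
    enabled = gr.Checkbox(label="Enable sharding", value=False)
    degree = gr.Number(label="Sharding degree", value=1,
                       interactive=False, elem_classes="disabled-field")

    # Flip interactivity and the styling class together whenever the checkbox changes.
    enabled.change(
        fn=lambda on: gr.update(interactive=on,
                                elem_classes=[] if on else ["disabled-field"]),
        inputs=enabled,
        outputs=degree,
    )

if __name__ == "__main__":
    demo.launch()
```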
calculator.py CHANGED
@@ -1,21 +1,27 @@
1
  from state import Model as Model, Parallelism, Training
2
  from dtypes import DType
 
3
 
4
 
5
  class MemoryCalculation:
6
- def __init__(self, modelconfig: Model, parallelismconfig: Parallelism, trainingconfig: Training):
7
  self.model = modelconfig
8
  self.parallelism = parallelismconfig
9
  self.training = trainingconfig
10
 
11
- def calculate_num_parameters(self) -> float:
12
  # https://michaelwornow.net/2024/01/18/counting-params-in-transformer
13
  # https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=memory_usage_in_transformers
14
 
15
  # Biases are not added/omitted on a per-model basis for simplicity.
16
  # Just include them where they could appear. They're small in comparison to weights anyway and it forms an upper bound.
17
 
18
- #self tax
19
  b, s = self.training.batch_size, self.training.sequence_length
20
  h, i, l, v, e = (
21
  self.model.hidden_dim,
@@ -30,45 +36,77 @@ class MemoryCalculation:
30
  self.parallelism.expert_parallelism,
31
  )
32
 
33
- # Embedding layers
34
- input_embedding = v * h / tp
35
- unembedding = 0
36
- if not self.model.weight_tied_embeddings:
37
- unembedding = h * v / tp
38
-
39
  # Attention
40
  # weights and biases = *2
41
- layer_norm_attn_in = 2 * h # not tp sharded
42
  qkv = 3 * h * h / tp
43
- attn_output_proj = h * h + h / tp
44
  attn = layer_norm_attn_in + qkv + attn_output_proj
45
 
46
  # MLP
47
- layer_norm_mlp_in = 2 * h # not tp sharded
48
- router = h * e + e # assuming replicated for simplicity
49
- mlp_up_proj = h * i + i / tp
50
- mlp_gate_proj = h * i + i / tp
51
- mlp_down_proj = i * h + h / tp
52
- expert = mlp_up_proj + mlp_gate_proj + mlp_down_proj
53
- experts = expert * e / ep
54
- mlp = layer_norm_mlp_in + router + experts
 
 
55
 
56
  layer = attn + mlp
57
- layers = layer * l
58
 
59
-
60
- final_layer_norm = 2 * h # not tp sharded
61
-
62
- # pp and weight tying makes knowing where to embed layer challenging
63
- # going to assume "worst" case and it's at the end with final layer norm
64
- # even though that's pretty small
65
  total_params = 0
66
  if pp == 1:
67
- total_params = input_embedding + layers + unembedding + final_layer_norm
68
- if pp > 1:
69
- total_params = max(input_embedding, unembedding) + layers/pp + final_layer_norm
70
  return total_params
71
 
 
72
  def calculate_activation_parameters(self) -> float:
73
  # https://blog.eleuther.ai/transformer-math/#activations-and-batch-size
74
  # https://arxiv.org/abs/2205.05198
@@ -95,29 +133,24 @@ class MemoryCalculation:
95
  layer = s * b * h / cp / tp # only keep initial input to layer
96
  layers = layer * l
97
  embed = 0
98
- final_layer_out = (
99
- s * b * h / cp / sp
100
- )
101
- final_norm = s * b * h / cp / sp
102
  unembed = s * b * v / cp / tp
103
- logits = s * b * v / cp / sp # come back to this
104
- num_params = (
105
- embed + layers + final_layer_out + final_norm + unembed + logits
106
- )
107
  return num_params
108
  else:
109
  # assume flash attention ie do selective recomputation
110
  # assume tensor parallel + sequence parallel as described in https://arxiv.org/abs/2205.05198
111
  # the variables calculate the activation outputs
112
  # Attention Block
113
- layer_in = s * b * h / cp / tp
114
- attn_norm = s * b * h / cp / sp
115
  flash = s * b * h / cp / tp
116
  # everything else is recalculated by flash attention
117
  projection = s * b * h / cp / tp
118
  attn = layer_in + attn_norm + flash + projection
119
  # MLP Block
120
- mlp_norm = s * b * h / cp / sp
121
 
122
  mlp_up = s * b * i / cp / tp
123
  mlp_gate = s * b * i / cp / tp
@@ -125,53 +158,77 @@ class MemoryCalculation:
125
  mlp_down = s * b * h / cp / tp
126
  if self.model.is_moe:
127
  router = (
128
- s * b * e / cp / sp) # makes sense to sp shard if mlp_norm out is sp sharded
 
129
  expert = mlp_up + mlp_gate + hadamard_swiglu + mlp_down
130
- experts = expert * ae
131
  mlp = mlp_norm + router + experts
132
  else:
133
  mlp = mlp_norm + mlp_up + mlp_gate + hadamard_swiglu + mlp_down
134
  layer = attn + mlp
135
- layers = layer * l # no decrease from PP because schedules will increase microbatches
 
 
136
  # Other
137
  embed = 0
138
  final_layer_out = (
139
  s * b * h / cp / tp
140
  ) # both sequence and context parallelism
141
- final_norm = s * b * h / cp / sp
142
  unembed = s * b * v / cp / tp
143
- logits = s * b * v / cp / tp
144
- num_params = (
145
- embed + layers + final_layer_out + final_norm + unembed + logits
146
- )
147
  return num_params
148
-
149
  def calculate_parameter_memory(self) -> float:
150
  if self.training.mixed_precision:
151
- master_copy = self.calculate_num_parameters() * self.training.precision
152
- working_copy = self.calculate_num_parameters() * self.training.param_dtype
153
  return master_copy + working_copy
154
  else:
155
- return self.calculate_num_parameters() * self.training.precision
156
-
157
  def calculate_gradient_memory(self) -> float:
158
  # https://blog.eleuther.ai/transformer-math/#gradients
159
- return (
160
- self.calculate_num_parameters() * 4
161
- ) # gradients are same size as parameters
162
 
163
  def calculate_optimizer_memory(self) -> float:
164
  # https://blog.eleuther.ai/transformer-math/#optimizer-states
165
  # https://www.determined.ai/blog/act-mem-2, https://web.archive.org/web/20250308172134/https://www.determined.ai/blog/act-mem-2
166
- return (
167
- 2 * self.calculate_num_parameters() * DType.FP32
168
- ) # Adam optimizer with 2 states per parameter, assume always fp32
169
-
170
  def calculate_activation_memory(self) -> float:
171
  if self.training.mixed_precision:
172
  return self.calculate_activation_parameters() * self.training.param_dtype
173
  else:
174
  return (
175
  self.calculate_activation_parameters() * self.training.precision
176
- )
177
-
 
1
  from state import Model as Model, Parallelism, Training
2
  from dtypes import DType
3
+ from math import ceil
4
 
5
 
6
  class MemoryCalculation:
7
+ def __init__(
8
+ self,
9
+ modelconfig: Model,
10
+ parallelismconfig: Parallelism,
11
+ trainingconfig: Training,
12
+ ):
13
  self.model = modelconfig
14
  self.parallelism = parallelismconfig
15
  self.training = trainingconfig
16
 
17
+ def calculate_num_parameters_per_layer(self) -> float:
18
  # https://michaelwornow.net/2024/01/18/counting-params-in-transformer
19
  # https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=memory_usage_in_transformers
20
 
21
  # Biases are not added/omitted on a per-model basis for simplicity.
22
  # Just include them where they could appear. They're small in comparison to weights anyway and it forms an upper bound.
23
 
24
+ # "self tax": unpack config fields into short local names
25
  b, s = self.training.batch_size, self.training.sequence_length
26
  h, i, l, v, e = (
27
  self.model.hidden_dim,
 
36
  self.parallelism.expert_parallelism,
37
  )
38
 
 
39
  # Attention
40
  # weights and biases = *2
41
+ layer_norm_attn_in = 2 * h # not tp sharded
42
  qkv = 3 * h * h / tp
43
+ attn_output_proj = (h * h + h) / tp
44
  attn = layer_norm_attn_in + qkv + attn_output_proj
45
 
46
  # MLP
47
+ layer_norm_mlp_in = 2 * h # not tp sharded
48
+ mlp_up_proj = (h * i + i) / tp
49
+ mlp_gate_proj = (h * i + i) / tp
50
+ mlp_down_proj = (i * h + h) / tp
51
+ mlp = layer_norm_mlp_in + mlp_up_proj + mlp_gate_proj + mlp_down_proj
52
+ if self.model.is_moe:
53
+ router = h * e + e # assuming replicated for simplicity
54
+ expert = mlp_up_proj + mlp_gate_proj + mlp_down_proj
55
+ experts = expert * e / ep
56
+ mlp = layer_norm_mlp_in + router + experts
57
 
58
  layer = attn + mlp
59
+ return layer
60
 
61
+ def calculate_unshardeable_parameters(self) -> float:
62
+ b, s = self.training.batch_size, self.training.sequence_length
63
+ h, i, l, v, e = (
64
+ self.model.hidden_dim,
65
+ self.model.intermediate_size,
66
+ self.model.num_layers,
67
+ self.model.vocab_size,
68
+ self.model.total_experts,
69
+ )
70
+ tp, pp, ep = (
71
+ self.parallelism.tensor_parallelism,
72
+ self.parallelism.pipeline_parallelism,
73
+ self.parallelism.expert_parallelism,
74
+ )
75
+ # Embedding layers
76
+ input_embedding = v * h / tp
77
+ unembedding = 0
78
+ if not self.model.weight_tied_embeddings:
79
+ unembedding = h * v / tp
80
+ final_layer_norm = 2 * h # not tp sharded
81
+ # hush linter
82
  total_params = 0
83
  if pp == 1:
84
+ total_params = input_embedding + unembedding + final_layer_norm
85
+ elif pp > 1:
86
+ total_params = max(input_embedding, unembedding) + final_layer_norm
87
  return total_params
88
 
89
+ def calculate_fsdp_sharded_parameters(self) -> float:
90
+ if not self.parallelism.fsdp_enabled:
91
+ return self.calculate_num_parameters()
92
+ else:
93
+ return (
94
+ self.calculate_num_parameters_per_layer()
95
+ * ceil(
96
+ (self.model.num_layers - 1) / self.parallelism.pipeline_parallelism
97
+ )
98
+ / self.parallelism.fsdp_parallelism
99
+ + self.calculate_unshardeable_parameters()
100
+ + self.calculate_num_parameters_per_layer()
101
+ )
102
+
103
+ def calculate_num_parameters(self) -> float:
104
+ return (
105
+ self.calculate_num_parameters_per_layer()
106
+ * ceil(self.model.num_layers / self.parallelism.pipeline_parallelism)
107
+ + self.calculate_unshardeable_parameters()
108
+ )
109
+
110
  def calculate_activation_parameters(self) -> float:
111
  # https://blog.eleuther.ai/transformer-math/#activations-and-batch-size
112
  # https://arxiv.org/abs/2205.05198
 
133
  layer = s * b * h / cp / tp # only keep initial input to layer
134
  layers = layer * l
135
  embed = 0
136
+ final_layer_out = s * b * h / cp / sp
137
+ final_norm = s * b * h / cp / sp
 
 
138
  unembed = s * b * v / cp / tp
139
+ num_params = embed + layers + final_layer_out + final_norm + unembed
 
 
 
140
  return num_params
141
  else:
142
  # assume flash attention ie do selective recomputation
143
  # assume tensor parallel + sequence parallel as described in https://arxiv.org/abs/2205.05198
144
  # the variables calculate the activation outputs
145
  # Attention Block
146
+ layer_in = s * b * h / cp / tp
147
+ attn_norm = s * b * h / cp / sp
148
  flash = s * b * h / cp / tp
149
  # everything else is recalculated by flash attention
150
  projection = s * b * h / cp / tp
151
  attn = layer_in + attn_norm + flash + projection
152
  # MLP Block
153
+ mlp_norm = s * b * h / cp / sp
154
 
155
  mlp_up = s * b * i / cp / tp
156
  mlp_gate = s * b * i / cp / tp
 
158
  mlp_down = s * b * h / cp / tp
159
  if self.model.is_moe:
160
  router = (
161
+ s * b * e / cp / sp
162
+ ) # makes sense to sp shard if mlp_norm out is sp sharded
163
  expert = mlp_up + mlp_gate + hadamard_swiglu + mlp_down
164
+ experts = expert * ae / ep
165
  mlp = mlp_norm + router + experts
166
  else:
167
  mlp = mlp_norm + mlp_up + mlp_gate + hadamard_swiglu + mlp_down
168
  layer = attn + mlp
169
+ layers = (
170
+ layer * l
171
+ ) # no decrease from PP because schedules will increase microbatches
172
  # Other
173
  embed = 0
174
  final_layer_out = (
175
  s * b * h / cp / tp
176
  ) # both sequence and context parallelism
177
+ final_norm = s * b * h / cp / sp
178
  unembed = s * b * v / cp / tp
179
+ num_params = embed + layers + final_layer_out + final_norm + unembed
 
 
 
180
  return num_params
181
+
182
  def calculate_parameter_memory(self) -> float:
183
+ if self.parallelism.fsdp_enabled and self.parallelism.fsdp_strategy == "Zero-3":
184
+ params = self.calculate_fsdp_sharded_parameters()
185
+ else:
186
+ params = self.calculate_num_parameters()
187
  if self.training.mixed_precision:
188
+ master_copy = params * self.training.precision
189
+ working_copy = params * self.training.param_dtype
190
  return master_copy + working_copy
191
  else:
192
+ return params * self.training.precision
193
+
194
  def calculate_gradient_memory(self) -> float:
195
  # https://blog.eleuther.ai/transformer-math/#gradients
196
+ if self.parallelism.fsdp_enabled and self.parallelism.fsdp_strategy in ("Zero-3", "Zero-2"):
197
+ params = self.calculate_fsdp_sharded_parameters()
198
+ else:
199
+ params = self.calculate_num_parameters()
200
+ grad_accumulation = 0
201
+ if self.training.grad_accumulation:
202
+ if self.training.mixed_precision:
203
+ grad_accumulation = (
204
+ params * self.training.reduce_dtype
205
+ )
206
+ else:
207
+ grad_accumulation = (
208
+ params * self.training.precision
209
+ )
210
+ if self.training.mixed_precision:
211
+ gradients = params * self.training.param_dtype
212
+ else:
213
+ gradients = params * self.training.precision
214
+ return grad_accumulation + gradients
215
 
216
  def calculate_optimizer_memory(self) -> float:
217
  # https://blog.eleuther.ai/transformer-math/#optimizer-states
218
  # https://www.determined.ai/blog/act-mem-2, https://web.archive.org/web/20250308172134/https://www.determined.ai/blog/act-mem-2
219
+ if self.parallelism.fsdp_enabled:
220
+ return (
221
+ 2 * self.calculate_num_parameters() * DType.FP32
222
+ ) / self.parallelism.fsdp_parallelism # don't gather a layer unlike params and grads
223
+ else:
224
+ return (
225
+ 2 * self.calculate_num_parameters() * DType.FP32
226
+ ) # Adam optimizer with 2 states per parameter, assume always fp32
227
+
228
  def calculate_activation_memory(self) -> float:
229
  if self.training.mixed_precision:
230
  return self.calculate_activation_parameters() * self.training.param_dtype
231
  else:
232
  return (
233
  self.calculate_activation_parameters() * self.training.precision
234
+ ) # not impacted by fsdp
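The ZeRO-stage gating added to calculator.py above decides which of the three big terms gets divided by the FSDP world size: optimizer states are always sharded once FSDP is enabled, gradients from Zero-2 up, and parameters only under Zero-3 (where `calculate_fsdp_sharded_parameters` also keeps one gathered layer plus the unshardeable embeddings resident). A back-of-the-envelope sketch of just that gating, with made-up numbers and ignoring the gathered-layer and unshardeable terms:

```python
# Back-of-the-envelope sketch of the ZeRO-stage rules in calculator.py above.
# Numbers are made up (1B dense parameters, fp32 everywhere, FSDP over 8 ranks);
# the gathered layer and unshardeable embeddings that the real
# calculate_fsdp_sharded_parameters adds back are ignored here for brevity.
PARAMS = 1e9
FSDP = 8
FP32 = 4  # bytes per element

def zero_stage_memory_gb(stage: str) -> dict:
    shard = lambda n: n / FSDP
    param_elems = shard(PARAMS) if stage == "Zero-3" else PARAMS
    grad_elems = shard(PARAMS) if stage in ("Zero-2", "Zero-3") else PARAMS
    opt_elems = shard(2 * PARAMS)  # two Adam states per parameter, always sharded with FSDP
    return {
        "parameters": param_elems * FP32 / 1e9,
        "gradients": grad_elems * FP32 / 1e9,
        "optimizer": opt_elems * FP32 / 1e9,
    }

for stage in ("Zero-1", "Zero-2", "Zero-3"):
    print(stage, zero_stage_memory_gb(stage))
# Zero-1: parameters 4.0, gradients 4.0, optimizer 1.0 (GB per rank)
# Zero-2: parameters 4.0, gradients 0.5, optimizer 1.0
# Zero-3: parameters 0.5, gradients 0.5, optimizer 1.0
```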
 
defaults.py CHANGED
@@ -1,24 +1,116 @@
1
  from state import Model
2
 
 
3
  GEMMA3_270M = Model(
4
- vocab_size=256000, num_layers=9, hidden_dim=1152, intermediate_size=4608,
5
- weight_tied_embeddings=True, active_experts=2, total_experts=8, is_moe=False
6
  )
7
  GEMMA3_1B = Model(
8
- vocab_size=262208, num_layers=26, hidden_dim=2304, intermediate_size=9216,
9
- weight_tied_embeddings=True, active_experts=2, total_experts=8, is_moe=False
10
  )
11
  GEMMA3_4B = Model(
12
- vocab_size=262208, num_layers=28, hidden_dim=3072, intermediate_size=12288,
13
- weight_tied_embeddings=True, active_experts=2, total_experts=8, is_moe=False
14
  )
15
  GEMMA3_12B = Model(
16
- vocab_size=262208, num_layers=42, hidden_dim=4608, intermediate_size=18432,
17
- weight_tied_embeddings=True, active_experts=2, total_experts=8, is_moe=False
18
  )
19
  GEMMA3_27B = Model(
20
- vocab_size=262208, num_layers=46, hidden_dim=6144, intermediate_size=24576,
21
- weight_tied_embeddings=True, active_experts=2, total_experts=8, is_moe=False
22
  )
23
 
24
  DEFAULTS = {
@@ -27,4 +119,9 @@ DEFAULTS = {
27
  "Gemma3 4B": GEMMA3_4B,
28
  "Gemma3 12B": GEMMA3_12B,
29
  "Gemma3 27B": GEMMA3_27B,
30
  }
 
1
  from state import Model
2
 
3
+ # https://huggingface.co/google/gemma-3-270m/blob/main/config.json
4
  GEMMA3_270M = Model(
5
+ vocab_size=262144,
6
+ num_layers=18,
7
+ hidden_dim=640,
8
+ intermediate_size=2048,
9
+ weight_tied_embeddings=True,
10
+ active_experts=1,
11
+ total_experts=1,
12
+ is_moe=False,
13
  )
14
  GEMMA3_1B = Model(
15
+ vocab_size=262144,
16
+ num_layers=26,
17
+ hidden_dim=1152,
18
+ intermediate_size=6912,
19
+ weight_tied_embeddings=True,
20
+ active_experts=1,
21
+ total_experts=1,
22
+ is_moe=False,
23
  )
24
  GEMMA3_4B = Model(
25
+ vocab_size=262144,
26
+ num_layers=34,
27
+ hidden_dim=2560,
28
+ intermediate_size=10240,
29
+ weight_tied_embeddings=True,
30
+ active_experts=1,
31
+ total_experts=1,
32
+ is_moe=False,
33
  )
34
  GEMMA3_12B = Model(
35
+ vocab_size=262144,
36
+ num_layers=48,
37
+ hidden_dim=3840,
38
+ intermediate_size=15360,
39
+ weight_tied_embeddings=True,
40
+ active_experts=1,
41
+ total_experts=1,
42
+ is_moe=False,
43
  )
44
  GEMMA3_27B = Model(
45
+ vocab_size=262144,
46
+ num_layers=62,
47
+ hidden_dim=5376,
48
+ intermediate_size=21504,
49
+ weight_tied_embeddings=True,
50
+ active_experts=1,
51
+ total_experts=1,
52
+ is_moe=False,
53
+ )
54
+ # No Maverick: non-homogeneous layers aren't supported yet
55
+
56
+ # https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/config.json
57
+ LLAMA4_SCOUT = Model(
58
+ vocab_size=202048,
59
+ num_layers=48,
60
+ hidden_dim=5120,
61
+ intermediate_size=8192,
62
+ weight_tied_embeddings=True,
63
+ active_experts=2,
64
+ total_experts=17,
65
+ is_moe=True,
66
+ )
67
+
68
+ # https://huggingface.co/unsloth/Llama-3.2-1B-Instruct/blob/main/config.json
69
+ LLAMA3_1B = Model(
70
+ vocab_size=128256,
71
+ num_layers=16,
72
+ hidden_dim=2048,
73
+ intermediate_size=8192,
74
+ weight_tied_embeddings=True,
75
+ active_experts=1,
76
+ total_experts=1,
77
+ is_moe=False,
78
+ )
79
+
80
+ # https://huggingface.co/unsloth/Llama-3.2-3B-Instruct/blob/main/config.json
81
+ LLAMA3_3B = Model(
82
+ vocab_size=128256,
83
+ num_layers=28,
84
+ hidden_dim=3072,
85
+ intermediate_size=8192,
86
+ weight_tied_embeddings=True,
87
+ active_experts=1,
88
+ total_experts=1,
89
+ is_moe=False,
90
+ )
91
+
92
+ # https://huggingface.co/unsloth/llama-3-8b-Instruct/blob/main/config.json
93
+ LLAMA3_8B = Model(
94
+ vocab_size=128256,
95
+ num_layers=32,
96
+ hidden_dim=4096,
97
+ intermediate_size=14336,
98
+ weight_tied_embeddings=True,
99
+ active_experts=1,
100
+ total_experts=1,
101
+ is_moe=False,
102
+ )
103
+
104
+ # https://huggingface.co/unsloth/Llama-3.3-70B-Instruct/blob/main/config.json
105
+ LLAMA3_70B = Model(
106
+ vocab_size=128256,
107
+ num_layers=80,
108
+ hidden_dim=8192,
109
+ intermediate_size=28672,
110
+ weight_tied_embeddings=True,
111
+ active_experts=1,
112
+ total_experts=1,
113
+ is_moe=False,
114
  )
115
 
116
  DEFAULTS = {
 
119
  "Gemma3 4B": GEMMA3_4B,
120
  "Gemma3 12B": GEMMA3_12B,
121
  "Gemma3 27B": GEMMA3_27B,
122
+ "Llama3 1B": LLAMA3_1B,
123
+ "Llama3 3B": LLAMA3_3B,
124
+ "Llama3 8B": LLAMA3_8B,
125
+ "Llama3 70B": LLAMA3_70B,
126
+ "Llama4 Scout": LLAMA4_SCOUT,
127
  }
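As a sanity check on one of the presets above, the per-layer formula from calculator.py can be evaluated by hand. For the Llama3 8B config it lands near the nominal 8B; a sketch of that arithmetic, mirroring the calculator's simplified counting (no GQA, biases included everywhere, tied embeddings as in the preset):

```python
# Rough sanity check of the Llama3 8B preset above, following
# calculate_num_parameters_per_layer / calculate_unshardeable_parameters
# from calculator.py with tp = pp = 1. This mirrors the calculator's
# simplified formula, so it overshoots the nominal 8B slightly.
h, i, l, v = 4096, 14336, 32, 128256

attn = 2 * h + 3 * h * h + (h * h + h)        # norm + QKV + output projection
mlp = 2 * h + 2 * (h * i + i) + (i * h + h)   # norm + up/gate + down projections
per_layer = attn + mlp

unshardeable = v * h + 2 * h                  # input embedding + final norm (tied)
total = l * per_layer + unshardeable
print(f"{total / 1e9:.2f}B parameters")       # ≈ 8.31B, in the right ballpark
```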
details.py ADDED
@@ -0,0 +1,21 @@
1
+ DETAILS = """
2
+ ### Resources I found helpful while building this tool:
3
+ - [The Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook)
4
+ - [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/abs/2205.05198)
5
+ - [Transformer Math - Michael Wornow](https://michaelwornow.net/2024/01/18/counting-params-in-transformer)
6
+ - [Transformer Math 101](https://blog.eleuther.ai/transformer-math/)
7
+
8
+
9
+ ### Why this tool?
10
+ While there are some good tools out there already:
11
+ - [Hugging Face Model Memory Estimator](https://huggingface.co/spaces/hf-accelerate/model-memory-usage)
12
+ - [DeepSpeed Model Memory Calculator](https://huggingface.co/spaces/andstor/deepspeed-model-memory-usage)
13
+ - [DeepSpeed Native Utility](https://deepspeed.readthedocs.io/en/latest/memory.html)
14
+
15
+ None of them had all the features I wanted in one place. I wanted a tool that could:
16
+ - Accept arbitrary model configurations
17
+ - Support FSDP
18
+ - Support 5D parallelism
19
+ - Be interactive and break down memory usage by category, to better inform configurations.
20
+
21
+ """
dtypes.py CHANGED
@@ -31,4 +31,4 @@ class DType(Enum):
31
 
32
  def __rmul__(self, other):
33
  """Multiply number by dtype to get total bytes"""
34
- return other * self.bytes_per_element()
 
31
 
32
  def __rmul__(self, other):
33
  """Multiply number by dtype to get total bytes"""
34
+ return other * self.bytes_per_element()
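The `__rmul__` restored here is what lets the calculator express byte counts as `count * DType.X`. A small hedged usage sketch (assuming `DType.FP32` is 4 bytes per element):

```python
# Hedged usage sketch: __rmul__ turns "element count * DType" into bytes,
# which is how calculator.py writes e.g. Adam state memory as
# 2 * num_params * DType.FP32. Assumes FP32 is 4 bytes per element.
from dtypes import DType

num_params = 1_000_000
optimizer_bytes = 2 * num_params * DType.FP32  # two fp32 states per parameter
print(optimizer_bytes / 1e6, "MB")             # 8.0 MB
```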
limitations.py ADDED
@@ -0,0 +1,27 @@
1
+ LIMITATIONS = """
2
+ This calculator makes a number of simplifying assumptions and has several limitations.
3
+ ### Assumptions:
4
+ - Your implementation of tensor parallel also incorporates sequence parallel
5
+ - You are doing selective recomputation with flash attention if not doing gradient checkpointing
6
+ - You keep a master copy of the model weights for mixed precision
7
+ - May not be true for some implementations which cast on the fly
8
+ - You're using Adam optimizer
9
+ - If using PP, you're using a schedule that keeps the number of activations roughly the same
10
+ - EP is the number of PPxTP units that share each expert
11
+ - SwiGLU activation function
12
+ - Rotary embeddings
13
+
14
+ ### Limitations:
15
+ - Does not support non-homogeneous layers
16
+ - e.g. Llama4 Maverick with alternating dense and sparse layers, iRoPE
17
+ - Does not include memory for kernel or framework overhead
18
+ - Does not include memory for intermediates
19
+ - Does not include vision layers for multi-modal models
20
+ - Models shared experts as another routed expert per token
21
+ - Does not support different dtypes for different parts of the model
22
+ - e.g. MXFP4 for GPT-OSS 20B and 120B
23
+ - Have not validated EP/FSDP interaction
24
+ - Doesn't model biases on a per-model basis
25
+
26
+ Note that this is not an exhaustive list, just some of the main ones.
27
+ """
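Taken together, these assumptions imply a familiar per-parameter byte budget: under bf16 mixed precision with Adam and no gradient-accumulation buffer, the rules in calculator.py work out to roughly 16 bytes per (unsharded) parameter, before activations. A worked sketch:

```python
# Worked per-parameter byte budget implied by the assumptions above
# (bf16 mixed precision, Adam, no gradient accumulation, no sharding).
# A sketch that mirrors calculator.py's rules, not a measurement.
master_weights = 4   # fp32 master copy ("precision")
working_weights = 2  # bf16 working copy ("param_dtype")
gradients = 2        # gradients kept in the working dtype
adam_states = 2 * 4  # two fp32 optimizer states per parameter

print(master_weights + working_weights + gradients + adam_states,
      "bytes per parameter")  # 16
```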
state.py CHANGED
@@ -20,6 +20,9 @@ class Parallelism:
20
  pipeline_parallelism: int
21
  context_parallelism: int
22
  expert_parallelism: int
 
 
 
23
 
24
 
25
  @dataclass
 
20
  pipeline_parallelism: int
21
  context_parallelism: int
22
  expert_parallelism: int
23
+ fsdp_enabled: bool
24
+ fsdp_parallelism: int
25
+ fsdp_strategy: str
26
 
27
 
28
  @dataclass
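For reference, a hedged construction sketch showing the new FSDP fields riding along on `Parallelism` (values illustrative; app.py builds this object from the UI inputs via keyword arguments):

```python
# Hedged sketch: constructing the extended Parallelism config added above.
# Values are illustrative; app.py fills these in from the Gradio inputs.
from state import Parallelism

parallelism = Parallelism(
    tensor_parallelism=2,
    pipeline_parallelism=1,
    context_parallelism=1,
    expert_parallelism=1,
    fsdp_enabled=True,
    fsdp_parallelism=8,
    fsdp_strategy="Zero-3",
)
```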