rubenaghayan committed
Commit b79954f · 1 Parent(s): ddb0136

finish up calcs

Files changed (4)
  1. app.py +14 -8
  2. calculator.py +162 -0
  3. defaults.py +16 -7
  4. state.py +3 -1
app.py CHANGED
@@ -1,9 +1,11 @@
 import gradio as gr
 from defaults import DEFAULTS
 
-def greet(name, intensity)->str:
+
+def greet(name, intensity) -> str:
     return "Hello, " + name + "!" * int(intensity)
 
+
 def create_parallelism_block():
     with gr.Column():
         gr.Markdown("# Parallelism Parameters")
@@ -13,29 +15,33 @@ def create_parallelism_block():
         ep = gr.Number(label="Expert Parallelism", value=1, interactive=True)
     return tp, pp, cp, ep
 
+
 def create_model_block():
     with gr.Column():
         gr.Markdown("# Model Parameters")
         layers = gr.Number(label="Number of Layers", value=32, interactive=True)
         vocab = gr.Number(label="Vocab Size", value=32000, interactive=True)
         hidden = gr.Number(label="Hidden Dim", value=4096, interactive=True)
-        intermediate = gr.Number(label="Intermediate Dim", value=11008, interactive=True)
+        intermediate = gr.Number(
+            label="Intermediate Dim", value=11008, interactive=True
+        )
         presets = gr.Dropdown(list(DEFAULTS.keys()), label="Presets", interactive=True)
     return layers, vocab, hidden, intermediate, presets
 
+
 def create_training_block():
     with gr.Column():
-        gr.Markdown('# Training Parameters')
+        gr.Markdown("# Training Parameters")
         seq_len = gr.Number(label="Sequence Length", value=8192, interactive=True)
         batch_size = gr.Number(label="Batch Size", value=8, interactive=True)
     return seq_len, batch_size
 
-def calculate(*args)->int:
+
+def calculate(*args) -> int:
     out = 1
     for arg in args:
         out *= arg
     return arg
-
 
 
 with gr.Blocks() as demo:
@@ -46,8 +52,8 @@ with gr.Blocks() as demo:
     seq_len, batch_size = create_training_block()
     calculate_button = gr.Button("Calculate")
     output = gr.Number(label="Output")
-
-    calculate_button.click(fn=calculate, inputs=[tp,pp,cp,ep],outputs=output)
-
+
+    calculate_button.click(fn=calculate, inputs=[tp, pp, cp, ep], outputs=output)
+
 
 demo.launch()
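
For reference, a minimal standalone sketch (not part of the commit) of how the click wiring feeds the placeholder callback: gr.Button.click passes the current values of inputs=[tp, pp, cp, ep] positionally, so calculate receives them as *args. The sample values 2, 4, 1, 8 are arbitrary.

# Sketch of the placeholder callback, returning the product of its inputs.
def calculate(*args) -> int:
    out = 1
    for arg in args:
        out *= arg
    return out  # app.py as committed returns `arg` (the last input) rather than `out`


print(calculate(2, 4, 1, 8))  # 64; the committed version would return 8 instead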
calculator.py CHANGED
@@ -0,0 +1,162 @@
+from state import Model as Model, Parallelism, Training
+
+
+class MemoryCalculation:
+    def __init__(self, model: Model, parallelism: Parallelism, training: Training):
+        self.model = model
+        self.parallelism = parallelism
+        self.training = training
+
+    def calculate_num_parameters(self) -> float:
+        # https://michaelwornow.net/2024/01/18/counting-params-in-transformer
+        # https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=memory_usage_in_transformers
+
+        # For simplicity, biases are not toggled on a per-model basis; they are included wherever they could appear.
+        # They are small compared to the weights, so this gives an upper bound.
+
+        b, s = self.training.batch_size, self.training.sequence_length
+        h, i, l, v, e = (
+            self.model.hidden_dim,
+            self.model.intermediate_size,
+            self.model.num_layers,
+            self.model.vocab_size,
+            self.model.experts,
+        )
+        tp, pp, ep = (
+            self.parallelism.tensor_parallelism,
+            self.parallelism.pipeline_parallelism,
+            self.parallelism.expert_parallelism,
+        )
+
+        # Embedding layers
+        input_embedding = v * h / tp
+        unembedding = 0
+        if not self.model.weight_tied_embeddings:
+            unembedding = h * v / tp
+
+        # Attention
+        # weights and biases = *2
+        layer_norm_attn_in = 2 * h  # not tp sharded
+        qkv = 3 * h * h / tp
+        attn_output_proj = h * h + h / tp
+        attn = layer_norm_attn_in + qkv + attn_output_proj
+
+        # MLP
+        layer_norm_mlp_in = 2 * h  # not tp sharded
+        router = h * e + e  # assuming replicated for simplicity
+        mlp_up_proj = h * i + i / tp
+        mlp_gate_proj = h * i + i / tp
+        mlp_down_proj = i * h + h / tp
+        expert = mlp_up_proj + mlp_gate_proj + mlp_down_proj
+        experts = expert * e / ep
+        mlp = layer_norm_mlp_in + router + experts
+
+        layer = attn + mlp
+        layers = layer * l
+
+
+        final_layer_norm = 2 * h  # not tp sharded
+
+        # pp and weight tying make it hard to know which stage holds the embedding,
+        # so assume the "worst" case: it sits on the last stage with the final layer norm,
+        # even though that contribution is pretty small.
+        if pp == 1:
+            total_params = input_embedding + layers + unembedding + final_layer_norm
+        if pp > 1:
+            total_params = max(input_embedding, unembedding) + layers / pp + final_layer_norm
+        return total_params
+
+    def calculate_parameter_memory(self) -> float:
+        return (
+            self.calculate_num_parameters() * 4
+        )  # assuming 4 bytes (32 bits) per parameter
+
+    def calculate_activation_parameters(self) -> float:
+        # https://blog.eleuther.ai/transformer-math/#activations-and-batch-size
+        # https://arxiv.org/abs/2205.05198
+        # pp not considered since most pp schemes run multiple concurrent micro-batches to reduce the bubble
+        b, s = self.training.batch_size, self.training.sequence_length
+        h, i, l, v, e, ae = (
+            self.model.hidden_dim,
+            self.model.intermediate_size,
+            self.model.num_layers,
+            self.model.vocab_size,
+            self.model.experts,
+            self.model.active_experts,
+        )
+        tp, cp, pp, ep = (
+            self.parallelism.tensor_parallelism,
+            self.parallelism.context_parallelism,
+            self.parallelism.pipeline_parallelism,
+            self.parallelism.expert_parallelism,
+        )
+        if self.training.gradient_checkpointing:
+            # full recomputation
+            embed = 0
+            layer = s * b * h / cp / tp  # only keep the initial input to each layer
+            layers = layer * l
+            final_layer_out = (
+                s * b * h / cp / tp
+            )  # both sequence and tensor parallelism
+            final_norm = s * b * h / cp / tp  # both sequence and tensor parallelism
+            unembed = s * b * v / cp / tp
+            logits = s * b * v / cp / tp  # both vocab and tensor parallelism
+            num_params = (
+                embed + layers + final_layer_out + final_norm + unembed + logits
+            )
+            return num_params
+        else:
+            # assume flash attention, i.e. selective recomputation
+            # assume tensor parallel + sequence parallel as described in https://arxiv.org/abs/2205.05198
+            # the variables below count the activation outputs
+            # Attention block
+            layer_in = s * b * h / cp / tp  # both sequence and context parallelism
+            attn_norm = s * b * h / cp / tp  # both sequence and context parallelism
+            flash = s * b * h / cp / tp
+            # everything else is recomputed by flash attention
+            projection = s * b * h / cp / tp
+            attn = layer_in + attn_norm + flash + projection
+            # MLP block
+            mlp_norm = s * b * h / cp / tp  # both sequence and context parallelism
+            router = (
+                s * b * e / cp / tp
+            )  # makes sense to sp-shard this if the mlp_norm output is sp-sharded
+            mlp_up = s * b * i / cp / tp
+            mlp_gate = s * b * i / cp / tp
+            hadamard_swiglu = s * b * i / cp / tp
+            mlp_down = s * b * h / cp / tp
+            expert = mlp_up + mlp_gate + hadamard_swiglu + mlp_down
+            experts = expert * ae
+            mlp = mlp_norm + router + experts
+            layer = attn + mlp
+            layers = layer * l
+            # Other
+            embed = 0
+            final_layer_out = (
+                s * b * h / cp / tp
+            )  # both sequence and context parallelism
+            final_norm = s * b * h / cp / tp  # both sequence and context parallelism
+            unembed = s * b * v / cp / tp
+            logits = s * b * v / cp / tp
+            num_params = (
+                embed + layers + final_layer_out + final_norm + unembed + logits
+            )
+            return num_params
+
+    def calculate_activation_memory(self) -> float:
+        return (
+            self.calculate_activation_parameters() * 4
+        )  # assuming 4 bytes (32 bits) per activation
+
+    def calculate_gradient_memory(self) -> float:
+        # https://blog.eleuther.ai/transformer-math/#gradients
+        return (
+            self.calculate_parameter_memory()
+        )  # gradients are the same size as the parameters
+
+    def calculate_optimizer_memory(self) -> float:
+        # https://blog.eleuther.ai/transformer-math/#optimizer-states
+        # https://www.determined.ai/blog/act-mem-2, https://web.archive.org/web/20250308172134/https://www.determined.ai/blog/act-mem-2
+        return (
+            2 * self.calculate_parameter_memory()
+        )  # Adam keeps two extra states (momentum and variance) per parameter
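
A rough usage sketch for MemoryCalculation, assuming the Model, Parallelism, and Training dataclasses in state.py carry every field calculator.py reads (intermediate_size, experts, active_experts, weight_tied_embeddings, gradient_checkpointing, and so on; the state.py diff below only shows part of the file). The numbers are placeholder values, not measurements.

from calculator import MemoryCalculation
from state import Model, Parallelism, Training

# Field names mirror the attributes calculator.py accesses; values are illustrative.
model = Model(
    vocab_size=32000, num_layers=32, hidden_dim=4096, intermediate_size=11008,
    experts=1, active_experts=1, weight_tied_embeddings=False,
)
parallelism = Parallelism(
    tensor_parallelism=2, pipeline_parallelism=1,
    context_parallelism=1, expert_parallelism=1,
)
training = Training(sequence_length=8192, batch_size=8, gradient_checkpointing=False)

calc = MemoryCalculation(model, parallelism, training)
total_bytes = (
    calc.calculate_parameter_memory()
    + calc.calculate_gradient_memory()
    + calc.calculate_optimizer_memory()
    + calc.calculate_activation_memory()
)
print(f"~{total_bytes / 1e9:.1f} GB per device (fp32, rough upper bound)")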
defaults.py CHANGED
@@ -1,16 +1,25 @@
 from state import ModelState
 
-GEMMA3_270M = ModelState(vocab_size=256000, num_layers=9, hidden_dim=1152, intermediate_size=4608)
-GEMMA3_1B = ModelState(vocab_size=262208, num_layers=26, hidden_dim=2304, intermediate_size=9216)
-GEMMA3_4B = ModelState(vocab_size=262208, num_layers=28, hidden_dim=3072, intermediate_size=12288)
-GEMMA3_12B = ModelState(vocab_size=262208, num_layers=42, hidden_dim=4608, intermediate_size=18432)
-GEMMA3_27B = ModelState(vocab_size=262208, num_layers=46, hidden_dim=6144, intermediate_size=24576)
+GEMMA3_270M = ModelState(
+    vocab_size=256000, num_layers=9, hidden_dim=1152, intermediate_size=4608
+)
+GEMMA3_1B = ModelState(
+    vocab_size=262208, num_layers=26, hidden_dim=2304, intermediate_size=9216
+)
+GEMMA3_4B = ModelState(
+    vocab_size=262208, num_layers=28, hidden_dim=3072, intermediate_size=12288
+)
+GEMMA3_12B = ModelState(
+    vocab_size=262208, num_layers=42, hidden_dim=4608, intermediate_size=18432
+)
+GEMMA3_27B = ModelState(
+    vocab_size=262208, num_layers=46, hidden_dim=6144, intermediate_size=24576
+)
 
 DEFAULTS = {
     "Gemma3 270M": GEMMA3_270M,
     "Gemma3 1B": GEMMA3_1B,
     "Gemma3 4B": GEMMA3_4B,
     "Gemma3 12B": GEMMA3_12B,
-    "Gemma3 27B": GEMMA3_27B
+    "Gemma3 27B": GEMMA3_27B,
 }
-
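
One possible way to consume DEFAULTS from the presets dropdown in app.py, sketched under two assumptions: that ModelState exposes its constructor keywords as attributes (the state.py shown below defines a Model dataclass; ModelState is presumably analogous), and that the event wiring would live inside the gr.Blocks() context. None of this is in the commit.

from defaults import DEFAULTS


# Hypothetical handler: look up the chosen preset and return its fields in the
# order of the model Number components (layers, vocab, hidden, intermediate).
def apply_preset(name):
    m = DEFAULTS[name]
    return m.num_layers, m.vocab_size, m.hidden_dim, m.intermediate_size


# Inside gr.Blocks() in app.py this could be attached as:
# presets.change(fn=apply_preset, inputs=presets,
#                outputs=[layers, vocab, hidden, intermediate])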
 
state.py CHANGED
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 
+
 @dataclass
 class Model:
     vocab_size: int
@@ -16,7 +17,8 @@ class Parallelism:
     context_parallelism: int
     expert_parallelism: int
 
+
 @dataclass
 class Training:
     sequence_length: int
-    batch_size: int
+    batch_size: int
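
For calculator.py to run against these dataclasses, Model and Training would also need the fields it reads, and only part of state.py is visible in this diff. A sketch of the shapes those accesses imply, with the extra fields and their default values marked as assumptions:

from dataclasses import dataclass


@dataclass
class Model:
    vocab_size: int
    num_layers: int
    hidden_dim: int
    intermediate_size: int
    # Assumed fields, inferred from calculator.py's attribute accesses:
    experts: int = 1
    active_experts: int = 1
    weight_tied_embeddings: bool = False


@dataclass
class Training:
    sequence_length: int
    batch_size: int
    gradient_checkpointing: bool = False  # assumed; read by calculate_activation_parameters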