rubenaghayan committed
Commit 97e312a · 1 Parent(s): b79954f

added support for precision
.gitignore ADDED
@@ -0,0 +1,26 @@
+ # Python
+ __pycache__/
+ *.pyc
+ *.pyo
+
+ # Virtual environments
+ .env
+ venv/
+ .venv/
+
+ # Linting/Formatting
+ .ruff_cache/
+
+ # macOS
+ .DS_Store
+
+ # IDE
+ .vscode/
+ .idea/
+
+ # Claude Code
+ claude/
+
+ # Gradio
+ gradio_cached_examples/
+ flagged/
__pycache__/defaults.cpython-311.pyc DELETED
Binary file (805 Bytes)
 
__pycache__/state.cpython-311.pyc DELETED
Binary file (639 Bytes)
 
app.py CHANGED
@@ -1,5 +1,13 @@
  import gradio as gr
+ import pandas as pd
+ from functools import partial
  from defaults import DEFAULTS
+ from state import Model, Parallelism, Training
+ from calculator import MemoryCalculation
+ from dtypes import DType
+
+ # Create a Number component for natural numbers (positive integers)
+ NaturalNumber = partial(gr.Number, minimum=1, step=1, precision=0, interactive=True)


  def greet(name, intensity) -> str:
@@ -8,52 +16,143 @@ def greet(name, intensity) -> str:

  def create_parallelism_block():
      with gr.Column():
-         gr.Markdown("# Parallelism Parameters")
-         tp = gr.Number(label="Tensor Parallelism", value=1, interactive=True)
-         pp = gr.Number(label="Pipeline Parallelism", value=1, interactive=True)
-         cp = gr.Number(label="Context Parallelism", value=1, interactive=True)
-         ep = gr.Number(label="Expert Parallelism", value=1, interactive=True)
-         return tp, pp, cp, ep
+         gr.Markdown("# Parallelism")
+         with gr.Group():
+             tp = NaturalNumber(label="Tensor Parallelism", value=1)
+             pp = NaturalNumber(label="Pipeline Parallelism", value=1)
+             cp = NaturalNumber(label="Context Parallelism", value=1)
+             ep = NaturalNumber(label="Expert Parallelism", value=1)
+         return tp, pp, cp, ep


  def create_model_block():
      with gr.Column():
-         gr.Markdown("# Model Parameters")
-         layers = gr.Number(label="Number of Layers", value=32, interactive=True)
-         vocab = gr.Number(label="Vocab Size", value=32000, interactive=True)
-         hidden = gr.Number(label="Hidden Dim", value=4096, interactive=True)
-         intermediate = gr.Number(
-             label="Intermediate Dim", value=11008, interactive=True
+         gr.Markdown("# Model Architecture")
+         layers = NaturalNumber(label="Number of Layers", value=32)
+         vocab = NaturalNumber(label="Vocab Size", value=32000)
+         hidden = NaturalNumber(label="Hidden Dim", value=4096)
+         intermediate = NaturalNumber(label="Intermediate Dim", value=11008)
+         is_moe = gr.Checkbox(label="Mixture of Experts (MoE)", value=False)
+         active_experts = NaturalNumber(label="Active Experts", value=2, visible=False)
+         total_experts = NaturalNumber(label="Total Experts", value=8, visible=False)
+
+         # Toggle expert fields visibility based on MoE checkbox
+         is_moe.change(
+             fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
+             inputs=is_moe,
+             outputs=[active_experts, total_experts]
          )
-         presets = gr.Dropdown(list(DEFAULTS.keys()), label="Presets", interactive=True)
-         return layers, vocab, hidden, intermediate, presets
+
+         # not ready yet
+         # presets = gr.Dropdown(list(DEFAULTS.keys()), label="Presets", interactive=True)
+         return layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets


  def create_training_block():
      with gr.Column():
-         gr.Markdown("# Training Parameters")
-         seq_len = gr.Number(label="Sequence Length", value=8192, interactive=True)
-         batch_size = gr.Number(label="Batch Size", value=8, interactive=True)
-         return seq_len, batch_size
+         gr.Markdown("# Training Config")
+         seq_len = NaturalNumber(label="Sequence Length", value=8192)
+         batch_size = NaturalNumber(label="Batch Size", info="If you are using gradient accumulation, enter microbatch size", value=8)
+         with gr.Row():
+             gradient_checkpointing = gr.Checkbox(label="Gradient Checkpointing", value=False)
+             grad_accumulation = gr.Checkbox(label="Gradient Accumulation", value=False)
+         precision = gr.Dropdown(DType.values(), label="Precision", value=DType.FP32.value, interactive=True)
+         mixed_precision = gr.Checkbox(label="Mixed Precision", value=False)
+         param_dtype = gr.Dropdown(DType.values(), label="Parameter Dtype", value=DType.FP32.value, interactive=True, visible=False)
+         reduce_dtype = gr.Dropdown(DType.values(), label="Reduce Dtype", value=DType.FP32.value, interactive=True, visible=False)
+
+         # Toggle dtype fields visibility based on mixed precision checkbox
+         mixed_precision.change(
+             fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
+             inputs=mixed_precision,
+             outputs=[param_dtype, reduce_dtype]
+         )
+
+         return seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype
+

+ def calculate(tp, pp, cp, ep, layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype):
+     # Create state objects
+     model_config = Model(
+         vocab_size=int(vocab),
+         num_layers=int(layers),
+         hidden_dim=int(hidden),
+         intermediate_size=int(intermediate),
+         weight_tied_embeddings=True,  # Default assumption
+         active_experts=int(active_experts),
+         total_experts=int(total_experts),
+         is_moe=is_moe
+     )

- def calculate(*args) -> int:
-     out = 1
-     for arg in args:
-         out *= arg
-     return arg
+     parallelism_config = Parallelism(
+         tensor_parallelism=int(tp),
+         pipeline_parallelism=int(pp),
+         context_parallelism=int(cp),
+         expert_parallelism=int(ep)
+     )

+     training_config = Training(
+         sequence_length=int(seq_len),
+         batch_size=int(batch_size),
+         gradient_checkpointing=gradient_checkpointing,
+         grad_accumulation=grad_accumulation,
+         precision=DType(precision),
+         mixed_precision=mixed_precision,
+         param_dtype=DType(param_dtype),
+         reduce_dtype=DType(reduce_dtype)
+     )

- with gr.Blocks() as demo:
+     # Calculate different memory components
+     calc = MemoryCalculation(model_config, parallelism_config, training_config)
+
+     # Get all memory calculations
+     param_memory = calc.calculate_parameter_memory()
+     activation_memory = calc.calculate_activation_memory()
+     gradient_memory = calc.calculate_gradient_memory()
+     optimizer_memory = calc.calculate_optimizer_memory()
+
+     # Create DataFrame for bar plot
+     memory_data = pd.DataFrame({
+         'Component': [
+             'Parameter Memory',
+             'Activation Memory',
+             'Gradient Memory',
+             'Optimizer Memory'
+         ],
+         'Memory (GB)': [
+             param_memory / 1e9,
+             activation_memory / 1e9,
+             gradient_memory / 1e9,
+             optimizer_memory / 1e9
+         ]
+     })
+
+     return gr.BarPlot(
+         value=memory_data,
+         x="Component",
+         y="Memory (GB)",
+         title="LLM Memory Usage Breakdown",
+         container=False,
+         y_lim=[0, None]
+     )
+
+
+ with gr.Blocks(theme='gstaff/xkcd') as demo:
+     with gr.Sidebar():
+         gr.Textbox("## LLM Memory Visualizer")
      with gr.Column():
-         with gr.Row():
+         with gr.Row(equal_height=True):
              tp, pp, cp, ep = create_parallelism_block()
-             layers, vocab, hidden, intermediate, presets = create_model_block()
-             seq_len, batch_size = create_training_block()
+             layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets = create_model_block()
+             seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype = create_training_block()
          calculate_button = gr.Button("Calculate")
-         output = gr.Number(label="Output")
+         output = gr.BarPlot(label="Memory Usage Breakdown")

-     calculate_button.click(fn=calculate, inputs=[tp, pp, cp, ep], outputs=output)
+     calculate_button.click(
+         fn=calculate,
+         inputs=[tp, pp, cp, ep, layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype],
+         outputs=output
+     )


  demo.launch()
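
A quick note on the `NaturalNumber` helper introduced above: it is plain `functools.partial` pre-binding the keyword defaults shared by every integer field, so each call can still add or override keywords. A minimal sketch of that behavior (the `minimum=2` override is a hypothetical example, not taken from the app):

```python
from functools import partial
import gradio as gr

# Pre-bind the keyword arguments shared by every integer field once.
NaturalNumber = partial(gr.Number, minimum=1, step=1, precision=0, interactive=True)

# Each call still accepts extra keywords, and later keywords win over the pre-bound ones.
tp = NaturalNumber(label="Tensor Parallelism", value=1)
layers = NaturalNumber(label="Number of Layers", value=32, minimum=2)  # hypothetical override
```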
calculator.py CHANGED
@@ -1,11 +1,12 @@
  from state import Model as Model, Parallelism, Training
+ from dtypes import DType


  class MemoryCalculation:
-     def __init__(self, model: Model, parallelism: Parallelism, training: Training):
-         self.model = model
-         self.parallelism = parallelism
-         self.training = training
+     def __init__(self, modelconfig: Model, parallelismconfig: Parallelism, trainingconfig: Training):
+         self.model = modelconfig
+         self.parallelism = parallelismconfig
+         self.training = trainingconfig

      def calculate_num_parameters(self) -> float:
          # https://michaelwornow.net/2024/01/18/counting-params-in-transformer
@@ -14,13 +15,14 @@ class MemoryCalculation:
          # Biases are not added/omitted on a per-model basis for simplicity.
          # Just include them where they could appear. They're small in comparison to weights anyway and it forms an upper bound.

+         #self tax
          b, s = self.training.batch_size, self.training.sequence_length
          h, i, l, v, e = (
              self.model.hidden_dim,
              self.model.intermediate_size,
              self.model.num_layers,
              self.model.vocab_size,
-             self.model.experts,
+             self.model.total_experts,
          )
          tp, pp, ep = (
              self.parallelism.tensor_parallelism,
@@ -59,18 +61,14 @@ class MemoryCalculation:

          # pp and weight tying makes knowing where to embed layer challenging
          # going to assume "worst" case and it's at the end with final layer norm
-         # even though that's pretty smalle
+         # even though that's pretty small
+         total_params = 0
          if pp == 1:
              total_params = input_embedding + layers + unembedding + final_layer_norm
          if pp > 1:
              total_params = max(input_embedding, unembedding) + layers/pp + final_layer_norm
          return total_params

-     def calculate_parameter_memory(self) -> float:
-         return (
-             self.calculate_num_parameters() * 4
-         )  # assuming 4 bytes (32 bits) per parameter
-
      def calculate_activation_parameters(self) -> float:
          # https://blog.eleuther.ai/transformer-math/#activations-and-batch-size
          # https://arxiv.org/abs/2205.05198
@@ -81,6 +79,7 @@ class MemoryCalculation:
              self.model.intermediate_size,
              self.model.num_layers,
              self.model.vocab_size,
+             self.model.total_experts,
              self.model.active_experts,
          )
          tp, cp, pp, ep = (
@@ -89,6 +88,7 @@ class MemoryCalculation:
              self.parallelism.pipeline_parallelism,
              self.parallelism.expert_parallelism,
          )
+         sp = tp
          if self.training.gradient_checkpointing:
              # full recomputation
              embed = 0
@@ -96,11 +96,11 @@ class MemoryCalculation:
              layers = layer * l
              embed = 0
              final_layer_out = (
-                 s * b * h / cp / tp
-             )  # both sequence and tensor parallelism
-             final_norm = s * b * h / cp / tp  # both sequence and tensor parallelism
+                 s * b * h / cp / sp
+             )
+             final_norm = s * b * h / cp / sp
              unembed = s * b * v / cp / tp
-             logits = s * b * v / cp / tp  # both vocab and tensor parallelism
+             logits = s * b * v / cp / sp  # come back to this
              num_params = (
                  embed + layers + final_layer_out + final_norm + unembed + logits
              )
@@ -110,53 +110,68 @@ class MemoryCalculation:
              # assume tensor parallel + sequence parallel as described in https://arxiv.org/abs/2205.05198
              # the variables calculate the activation outputs
              # Attention Block
-             layer_in = s * b * h / cp / tp  # both sequence and context parallelism
-             attn_norm = s * b * h / cp / tp  # both sequence and context parallelism
+             layer_in = s * b * h / cp / tp
+             attn_norm = s * b * h / cp / sp
              flash = s * b * h / cp / tp
              # everything else is recalculated by flash attention
              projection = s * b * h / cp / tp
              attn = layer_in + attn_norm + flash + projection
              # MLP Block
-             mlp_norm = s * b * h / cp / tp  # both sequence and context parallelism
-             router = (
-                 s * b * e / cp / tp
-             )  # makes sense to sp shard if mlp_norm out is sp sharded
+             mlp_norm = s * b * h / cp / sp
+
              mlp_up = s * b * i / cp / tp
              mlp_gate = s * b * i / cp / tp
              hadamard_swiglu = s * b * i / cp / tp
              mlp_down = s * b * h / cp / tp
-             expert = mlp_up + mlp_gate + hadamard_swiglu + mlp_down
-             experts = expert * ae
-             mlp = mlp_norm + router + experts
+             if self.model.is_moe:
+                 router = (
+                     s * b * e / cp / sp)  # makes sense to sp shard if mlp_norm out is sp sharded
+                 expert = mlp_up + mlp_gate + hadamard_swiglu + mlp_down
+                 experts = expert * ae
+                 mlp = mlp_norm + router + experts
+             else:
+                 mlp = mlp_norm + mlp_up + mlp_gate + hadamard_swiglu + mlp_down
              layer = attn + mlp
-             layers = layer * l
+             layers = layer * l  # no decrease from PP because schedules will increase microbatches
              # Other
              embed = 0
              final_layer_out = (
                  s * b * h / cp / tp
              )  # both sequence and context parallelism
-             final_norm = s * b * h / cp / tp  # both sequence and context parallelism
+             final_norm = s * b * h / cp / sp
              unembed = s * b * v / cp / tp
              logits = s * b * v / cp / tp
              num_params = (
                  embed + layers + final_layer_out + final_norm + unembed + logits
              )
          return num_params
-
-     def calculate_activation_memory(self) -> float:
-         return (
-             self.calculate_activation_parameters() * 4
-         )  # assuming 4 bytes (32 bits) per activation
-
+
+     def calculate_parameter_memory(self) -> float:
+         if self.training.mixed_precision:
+             master_copy = self.calculate_num_parameters() * self.training.precision
+             working_copy = self.calculate_num_parameters() * self.training.param_dtype
+             return master_copy + working_copy
+         else:
+             return self.calculate_num_parameters() * self.training.precision
+
      def calculate_gradient_memory(self) -> float:
          # https://blog.eleuther.ai/transformer-math/#gradients
          return (
-             self.calculate_parameter_memory()
+             self.calculate_num_parameters() * 4
          )  # gradients are same size as parameters

      def calculate_optimizer_memory(self) -> float:
          # https://blog.eleuther.ai/transformer-math/#optimizer-states
          # https://www.determined.ai/blog/act-mem-2, https://web.archive.org/web/20250308172134/https://www.determined.ai/blog/act-mem-2
          return (
-             2 * self.calculate_parameter_memory()
-         )  # Adam optimizer with 3 states per parameter
+             2 * self.calculate_num_parameters() * DType.FP32
+         )  # Adam optimizer with 2 states per parameter, assume always fp32
+
+     def calculate_activation_memory(self) -> float:
+         if self.training.mixed_precision:
+             return self.calculate_activation_parameters() * self.training.param_dtype
+         else:
+             return (
+                 self.calculate_activation_parameters() * self.training.precision
+             )
+
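
As a rough sanity check of the new precision-aware formulas, here is a minimal worked example that mirrors `calculate_parameter_memory`, `calculate_gradient_memory`, and `calculate_optimizer_memory` above; the 1B parameter count and the fp32/bf16 dtype choice are hypothetical values used only for illustration:

```python
from dtypes import DType

num_params = 1_000_000_000  # hypothetical parameter count, illustration only

# Mixed precision: fp32 master copy + bf16 working copy (assumed dtypes)
master_copy = num_params * DType.FP32      # 4.0 GB
working_copy = num_params * DType.BF16     # 2.0 GB
param_memory = master_copy + working_copy  # 6.0 GB

# Gradients are counted at 4 bytes per parameter
gradient_memory = num_params * 4           # 4.0 GB

# Adam: two extra fp32 states per parameter
optimizer_memory = 2 * num_params * DType.FP32  # 8.0 GB

print(param_memory / 1e9, gradient_memory / 1e9, optimizer_memory / 1e9)  # 6.0 4.0 8.0
```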
defaults.py CHANGED
@@ -1,19 +1,24 @@
- from state import ModelState
+ from state import Model

- GEMMA3_270M = ModelState(
-     vocab_size=256000, num_layers=9, hidden_dim=1152, intermediate_size=4608
+ GEMMA3_270M = Model(
+     vocab_size=256000, num_layers=9, hidden_dim=1152, intermediate_size=4608,
+     weight_tied_embeddings=True, active_experts=2, total_experts=8, is_moe=False
  )
- GEMMA3_1B = ModelState(
-     vocab_size=262208, num_layers=26, hidden_dim=2304, intermediate_size=9216
+ GEMMA3_1B = Model(
+     vocab_size=262208, num_layers=26, hidden_dim=2304, intermediate_size=9216,
+     weight_tied_embeddings=True, active_experts=2, total_experts=8, is_moe=False
  )
- GEMMA3_4B = ModelState(
-     vocab_size=262208, num_layers=28, hidden_dim=3072, intermediate_size=12288
+ GEMMA3_4B = Model(
+     vocab_size=262208, num_layers=28, hidden_dim=3072, intermediate_size=12288,
+     weight_tied_embeddings=True, active_experts=2, total_experts=8, is_moe=False
  )
- GEMMA3_12B = ModelState(
-     vocab_size=262208, num_layers=42, hidden_dim=4608, intermediate_size=18432
+ GEMMA3_12B = Model(
+     vocab_size=262208, num_layers=42, hidden_dim=4608, intermediate_size=18432,
+     weight_tied_embeddings=True, active_experts=2, total_experts=8, is_moe=False
  )
- GEMMA3_27B = ModelState(
-     vocab_size=262208, num_layers=46, hidden_dim=6144, intermediate_size=24576
+ GEMMA3_27B = Model(
+     vocab_size=262208, num_layers=46, hidden_dim=6144, intermediate_size=24576,
+     weight_tied_embeddings=True, active_experts=2, total_experts=8, is_moe=False
  )

  DEFAULTS = {
dtypes.py ADDED
@@ -0,0 +1,34 @@
+ from enum import Enum
+
+
+ class DType(Enum):
+     FP32 = "fp32"
+     FP16 = "fp16"
+     BF16 = "bf16"
+     FP8 = "fp8"
+
+     @classmethod
+     def values(cls):
+         """Return a list of all enum values"""
+         return [dtype.value for dtype in cls]
+
+     def bytes_per_element(self):
+         """Return the number of bytes per element for this dtype"""
+         if self == DType.FP32:
+             return 4
+         elif self == DType.FP16:
+             return 2
+         elif self == DType.BF16:
+             return 2
+         elif self == DType.FP8:
+             return 1
+         else:
+             raise ValueError(f"Unknown dtype: {self}")
+
+     def __mul__(self, other):
+         """Multiply dtype by a number to get total bytes"""
+         return self.bytes_per_element() * other
+
+     def __rmul__(self, other):
+         """Multiply number by dtype to get total bytes"""
+         return other * self.bytes_per_element()
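
For reference, a short usage sketch of the new `DType` enum as it is consumed elsewhere in this commit: `values()` feeds the Gradio dropdowns, and the `__mul__`/`__rmul__` overloads turn a parameter count directly into bytes:

```python
from dtypes import DType

print(DType.values())                  # ['fp32', 'fp16', 'bf16', 'fp8']
print(DType.BF16.bytes_per_element())  # 2

# Either operand order works thanks to __mul__/__rmul__
print(DType.FP16 * 10)  # 20 (bytes)
print(10 * DType.FP8)   # 10 (bytes)

# Dropdown strings round-trip back into the enum
print(DType("bf16") is DType.BF16)     # True
```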
state.py CHANGED
@@ -1,4 +1,5 @@
  from dataclasses import dataclass
+ from dtypes import DType


  @dataclass
@@ -8,6 +9,9 @@ class Model:
      hidden_dim: int
      intermediate_size: int
      weight_tied_embeddings: bool
+     active_experts: int
+     total_experts: int
+     is_moe: bool


  @dataclass
@@ -22,3 +26,9 @@ class Parallelism:
  class Training:
      sequence_length: int
      batch_size: int
+     gradient_checkpointing: bool
+     grad_accumulation: bool
+     precision: DType
+     mixed_precision: bool
+     param_dtype: DType
+     reduce_dtype: DType
+ reduce_dtype: DType