# llm_memory_visualizer/calculator.py
from math import ceil

from dtypes import DType
from state import Model, Parallelism, Training
class MemoryCalculation:
    """Estimates per-GPU training memory for a transformer model, accounting for
    TP/PP/EP/CP, FSDP/ZeRO sharding, mixed precision, and gradient checkpointing."""

    def __init__(
        self,
        modelconfig: Model,
        parallelismconfig: Parallelism,
        trainingconfig: Training,
    ):
        self.model = modelconfig
        self.parallelism = parallelismconfig
        self.training = trainingconfig
    def calculate_num_parameters_per_layer(self) -> float:
        # https://michaelwornow.net/2024/01/18/counting-params-in-transformer
        # https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=memory_usage_in_transformers
        # Biases are not toggled per model for simplicity; they are counted wherever
        # they could appear. They are tiny compared to the weight matrices, so this
        # gives a slight upper bound.
        h, i, e = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.total_experts,
        )
        tp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.expert_parallelism,
        )
# Attention
layer_norm_attn_in = h # not tp sharded
        qkv = 3 * (h * h + h) / tp  # Q, K, V weight matrices plus optional biases
attn_output_proj = (h * h + h) / tp
attn = layer_norm_attn_in + qkv + attn_output_proj
# MLP
layer_norm_mlp_in = h # not tp sharded
mlp_up_proj = (h * i + i) / tp
mlp_gate_proj = (h * i + i) / tp
mlp_down_proj = (i * h + h) / tp
mlp = layer_norm_mlp_in + mlp_up_proj + mlp_gate_proj + mlp_down_proj
if self.model.is_moe:
router = h * e + e # assuming replicated for simplicity
expert = mlp_up_proj + mlp_gate_proj + mlp_down_proj
experts = expert * e / ep
mlp = layer_norm_mlp_in + router + experts
layer = attn + mlp
return layer
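    # Sanity check for calculate_num_parameters_per_layer (hedged, assumed values):
    # for a dense Llama-like config with h=4096, i=14336 and tp=1, attention
    # contributes ~6.7e7 parameters and the MLP ~1.76e8, i.e. roughly 2.43e8 per
    # layer. Note the calculator assumes full multi-head attention (3*h*h for QKV),
    # so it slightly overcounts models that use grouped-query attention.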
def calculate_unshardeable_parameters(self) -> float:
        h, v = self.model.hidden_dim, self.model.vocab_size
        tp, pp = (
            self.parallelism.tensor_parallelism,
            self.parallelism.pipeline_parallelism,
        )
# Embedding layers
input_embedding = v * h / tp
unembedding = 0
if not self.model.weight_tied_embeddings:
unembedding = h * v / tp
final_layer_norm = h # not tp sharded
        if pp == 1:
            total_params = input_embedding + unembedding + final_layer_norm
        else:
            # with pipeline parallelism, the embedding and unembedding live on
            # different stages, so a single rank holds at most the larger of the two
            total_params = max(input_embedding, unembedding) + final_layer_norm
        return total_params
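    # Example for calculate_unshardeable_parameters (hedged, assumed values): with
    # v=128256, h=4096, tp=1 and untied embeddings, input_embedding = unembedding
    # ~ 5.25e8. With pp == 1 a single rank holds both (~1.05e9), while with pp > 1
    # the embedding and unembedding sit on different stages, so the per-rank worst
    # case is max(...) ~ 5.25e8.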
    def calculate_fsdp_sharded_parameters(self) -> float:
        if not self.parallelism.fsdp_enabled:
            return self.calculate_num_parameters()
        else:
            # Each rank holds 1/fsdp_parallelism of every sharded layer on its
            # pipeline stage, plus one layer kept fully gathered (e.g. the one
            # currently being computed), plus the unshardeable parameters.
            return (
                self.calculate_num_parameters_per_layer()
                * ceil(
                    (self.model.num_layers - 1) / self.parallelism.pipeline_parallelism
                )
                / self.parallelism.fsdp_parallelism
                + self.calculate_unshardeable_parameters()
                + self.calculate_num_parameters_per_layer()
            )
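    # Example for calculate_fsdp_sharded_parameters (hedged, assumed values):
    # continuing the ~2.43e8-parameter layer above with 32 layers, pp=1,
    # fsdp_parallelism=8 and untied embeddings, each rank holds
    # 2.43e8 * 31 / 8 (sharded layers) + ~1.05e9 (unshardeable) + 2.43e8 (one fully
    # gathered layer) ~ 2.24e9 parameters.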
def calculate_num_parameters(self) -> float:
return (
self.calculate_num_parameters_per_layer()
* ceil(self.model.num_layers / self.parallelism.pipeline_parallelism)
+ self.calculate_unshardeable_parameters()
)
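    # Example for calculate_num_parameters (hedged, assumed values): with the same
    # ~2.43e8-parameter layer, 32 layers, pp=1 and untied embeddings, the unsharded
    # total is 2.43e8 * 32 + ~1.05e9 ~ 8.8e9 parameters.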
def calculate_activation_parameters(self) -> float:
# https://blog.eleuther.ai/transformer-math/#activations-and-batch-size
# https://arxiv.org/abs/2205.05198
# pp not considered since most pp schemes will run multiple concurrent batches to reduce the bubble
b, s = self.training.batch_size, self.training.sequence_length
h, i, l, v, e, ae = (
self.model.hidden_dim,
self.model.intermediate_size,
self.model.num_layers,
self.model.vocab_size,
self.model.total_experts,
self.model.active_experts,
)
        tp, cp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.context_parallelism,
            self.parallelism.expert_parallelism,
        )
        # sequence parallelism uses the tensor-parallel group, so sp == tp
        sp = tp
        if self.training.gradient_checkpointing:
            # full recomputation: only the input to each layer is stored
            embed = 0
            layer = s * b * h / cp / tp  # only keep the initial input to each layer
            layers = layer * l
            final_layer_out = s * b * h / cp / sp
            final_norm = s * b * h / cp / sp
            unembed = s * b * v / cp / tp
            num_params = embed + layers + final_layer_out + final_norm + unembed
            return num_params
        else:
            # assume flash attention, i.e. selective recomputation of attention internals
            # assume tensor parallel + sequence parallel as described in https://arxiv.org/abs/2205.05198
            # each variable below counts the stored activation outputs of one op
            # Attention Block
layer_in = s * b * h / cp / tp
attn_norm = s * b * h / cp / sp
flash = s * b * h / cp / tp
# everything else is recalculated by flash attention
projection = s * b * h / cp / tp
attn = layer_in + attn_norm + flash + projection
# MLP Block
mlp_norm = s * b * h / cp / sp
mlp_up = s * b * i / cp / tp
mlp_gate = s * b * i / cp / tp
hadamard_swiglu = s * b * i / cp / tp
mlp_down = s * b * h / cp / tp
if self.model.is_moe:
router = (
s * b * e / cp / sp
) # makes sense to sp shard if mlp_norm out is sp sharded
expert = mlp_up + mlp_gate + hadamard_swiglu + mlp_down
experts = expert * ae / ep
mlp = mlp_norm + router + experts
else:
mlp = mlp_norm + mlp_up + mlp_gate + hadamard_swiglu + mlp_down
layer = attn + mlp
layers = (
layer * l
) # no decrease from PP because schedules will increase microbatches
# Other
embed = 0
final_layer_out = (
s * b * h / cp / sp
) # both sequence and context parallelism
final_norm = s * b * h / cp / sp
unembed = s * b * v / cp / tp
num_params = embed + layers + final_layer_out + final_norm + unembed
return num_params
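    # Example for calculate_activation_parameters (hedged, assumed values): with
    # b=1, s=4096, h=4096, i=14336, tp=cp=1 and no gradient checkpointing, a dense
    # layer stores 4*s*b*h (attention) + 2*s*b*h + 3*s*b*i (MLP) ~ 2.77e8 activation
    # values, i.e. roughly 0.55 GB per layer in bf16.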
def calculate_parameter_memory(self) -> float:
if self.parallelism.fsdp_enabled and self.parallelism.fsdp_strategy == "Zero-3":
params = self.calculate_fsdp_sharded_parameters()
else:
params = self.calculate_num_parameters()
        if self.training.mixed_precision:
            # mixed precision keeps two copies of the weights: a master copy in the
            # training precision and a working copy in param_dtype
            master_copy = params * self.training.precision
working_copy = params * self.training.param_dtype
return master_copy + working_copy
else:
return params * self.training.precision
def calculate_gradient_memory(self) -> float:
# https://blog.eleuther.ai/transformer-math/#gradients
if self.parallelism.fsdp_enabled and self.parallelism.fsdp_strategy in ("Zero-3", "Zero-2"):
params = self.calculate_fsdp_sharded_parameters()
else:
params = self.calculate_num_parameters()
        grad_accumulation = 0
        if self.training.grad_accumulation:
            if self.training.mixed_precision:
                grad_accumulation = params * self.training.reduce_dtype
            else:
                grad_accumulation = params * self.training.precision
if self.training.mixed_precision:
gradients = params * self.training.param_dtype
else:
gradients = params * self.training.precision
return grad_accumulation + gradients
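    # Example for calculate_gradient_memory (hedged, assuming param_dtype=bf16 and
    # reduce_dtype=fp32): for ~8e9 unsharded parameters, bf16 gradients take ~16 GB;
    # enabling gradient accumulation adds another ~32 GB of fp32 accumulators.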
def calculate_optimizer_memory(self) -> float:
# https://blog.eleuther.ai/transformer-math/#optimizer-states
# https://www.determined.ai/blog/act-mem-2, https://web.archive.org/web/20250308172134/https://www.determined.ai/blog/act-mem-2
        if self.parallelism.fsdp_enabled:
            # optimizer states stay fully sharded; unlike params and grads,
            # no layer is ever gathered
            return (
                2 * self.calculate_num_parameters() * DType.FP32
            ) / self.parallelism.fsdp_parallelism
else:
return (
2 * self.calculate_num_parameters() * DType.FP32
) # Adam optimizer with 2 states per parameter, assume always fp32
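    # Example for calculate_optimizer_memory (hedged): Adam keeps two fp32 states per
    # parameter, i.e. 8 bytes/param, so an ~8e9-parameter model needs ~64 GB of
    # optimizer state without FSDP, or ~8 GB per rank with fsdp_parallelism=8.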
def calculate_activation_memory(self) -> float:
if self.training.mixed_precision:
return self.calculate_activation_parameters() * self.training.param_dtype
else:
return (
self.calculate_activation_parameters() * self.training.precision
) # not impacted by fsdp
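

# Hedged usage sketch (illustrative only): this assumes the Model, Parallelism and
# Training classes in state.py accept the fields referenced above as keyword
# arguments, that DType members (e.g. the assumed DType.BF16) evaluate to bytes per
# element, and therefore that the calculate_*_memory methods return bytes. Adjust to
# the real constructors in state.py / dtypes.py if they differ.
if __name__ == "__main__":
    model = Model(
        hidden_dim=4096,
        intermediate_size=14336,
        num_layers=32,
        vocab_size=128256,
        total_experts=1,
        active_experts=1,
        is_moe=False,
        weight_tied_embeddings=False,
    )
    parallelism = Parallelism(
        tensor_parallelism=1,
        pipeline_parallelism=1,
        expert_parallelism=1,
        context_parallelism=1,
        fsdp_enabled=True,
        fsdp_parallelism=8,
        fsdp_strategy="Zero-3",
    )
    training = Training(
        batch_size=1,
        sequence_length=4096,
        gradient_checkpointing=True,
        mixed_precision=True,
        precision=DType.FP32,
        param_dtype=DType.BF16,  # assumed member of DType
        reduce_dtype=DType.FP32,
        grad_accumulation=False,
    )
    calc = MemoryCalculation(model, parallelism, training)
    gib = 1024**3
    print(f"parameters:  {calc.calculate_parameter_memory() / gib:.1f} GiB")
    print(f"gradients:   {calc.calculate_gradient_memory() / gib:.1f} GiB")
    print(f"optimizer:   {calc.calculate_optimizer_memory() / gib:.1f} GiB")
    print(f"activations: {calc.calculate_activation_memory() / gib:.1f} GiB")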