from math import ceil

from dtypes import DType
from state import Model, Parallelism, Training


class MemoryCalculation:
    def __init__(
        self,
        modelconfig: Model,
        parallelismconfig: Parallelism,
        trainingconfig: Training,
    ):
        self.model = modelconfig
        self.parallelism = parallelismconfig
        self.training = trainingconfig

    def calculate_num_parameters_per_layer(self) -> float:
        # References:
        # https://michaelwornow.net/2024/01/18/counting-params-in-transformer
        # https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=memory_usage_in_transformers
        #
        # For simplicity, biases are not toggled per model: they are included
        # wherever they could appear. They are small relative to the weights,
        # so this yields an upper bound.

        # Unpack config fields for brevity.
        b, s = self.training.batch_size, self.training.sequence_length
        h, i, l, v, e = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.total_experts,
        )
        tp, pp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.pipeline_parallelism,
            self.parallelism.expert_parallelism,
        )

        # Attention block
        layer_norm_attn_in = h  # not tp sharded
        qkv = 3 * h * h / tp
        attn_output_proj = (h * h + h) / tp
        attn = layer_norm_attn_in + qkv + attn_output_proj

        # MLP block
        layer_norm_mlp_in = h  # not tp sharded
        mlp_up_proj = (h * i + i) / tp
        mlp_gate_proj = (h * i + i) / tp
        mlp_down_proj = (i * h + h) / tp
        mlp = layer_norm_mlp_in + mlp_up_proj + mlp_gate_proj + mlp_down_proj

        if self.model.is_moe:
            router = h * e + e  # assumed replicated for simplicity
            expert = mlp_up_proj + mlp_gate_proj + mlp_down_proj
            experts = expert * e / ep
            mlp = layer_norm_mlp_in + router + experts

        layer = attn + mlp
        return layer

    def calculate_unshardeable_parameters(self) -> float:
        b, s = self.training.batch_size, self.training.sequence_length
        h, i, l, v, e = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.total_experts,
        )
        tp, pp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.pipeline_parallelism,
            self.parallelism.expert_parallelism,
        )

        # Embedding layers
        input_embedding = v * h / tp
        unembedding = 0
        if not self.model.weight_tied_embeddings:
            unembedding = h * v / tp
        final_layer_norm = h  # not tp sharded

        total_params = 0
        if pp == 1:
            total_params = input_embedding + unembedding + final_layer_norm
        elif pp > 1:
            # With pipeline parallelism the embedding and unembedding sit on
            # different stages, so one stage holds at most the larger of the two.
            total_params = max(input_embedding, unembedding) + final_layer_norm
        return total_params

    def calculate_fsdp_sharded_parameters(self) -> float:
        if not self.parallelism.fsdp_enabled:
            return self.calculate_num_parameters()
        else:
            # Each rank keeps a 1/fsdp_parallelism shard of the layer weights plus
            # one fully gathered layer for the current forward/backward step;
            # unshardeable parameters (embeddings, final norm) are held in full.
            return (
                self.calculate_num_parameters_per_layer()
                * ceil(
                    (self.model.num_layers - 1)
                    / self.parallelism.pipeline_parallelism
                )
                / self.parallelism.fsdp_parallelism
                + self.calculate_unshardeable_parameters()
                + self.calculate_num_parameters_per_layer()
            )

    def calculate_num_parameters(self) -> float:
        return (
            self.calculate_num_parameters_per_layer()
            * ceil(self.model.num_layers / self.parallelism.pipeline_parallelism)
            + self.calculate_unshardeable_parameters()
        )
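    # Worked sanity check (illustrative only, not used by the calculator): for a
    # dense Llama-2-7B-style config with h=4096, i=11008, l=32, v=32000, untied
    # embeddings, and tp=pp=1, the formulas above give roughly
    #   per layer: (4*h^2 + 2*h) attention + (3*h*i + 2*i + 2*h) MLP ~= 202.4M
    #   total:     202.4M * 32 + 2 * 4096 * 32000 + 4096             ~= 6.74B
    # which lines up with the model's published parameter count.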
    def calculate_activation_parameters(self) -> float:
        # References:
        # https://blog.eleuther.ai/transformer-math/#activations-and-batch-size
        # https://arxiv.org/abs/2205.05198
        #
        # Pipeline parallelism is not considered here, since most pipeline
        # schedules run several concurrent microbatches to shrink the bubble.
        b, s = self.training.batch_size, self.training.sequence_length
        h, i, l, v, e, ae = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.total_experts,
            self.model.active_experts,
        )
        tp, cp, pp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.context_parallelism,
            self.parallelism.pipeline_parallelism,
            self.parallelism.expert_parallelism,
        )
        sp = tp  # sequence parallelism shards across the tensor-parallel group

        if self.training.gradient_checkpointing:
            # Full recomputation: only the initial input to each layer is kept.
            embed = 0
            layer = s * b * h / cp / tp
            layers = layer * l
            final_layer_out = s * b * h / cp / sp
            final_norm = s * b * h / cp / sp
            unembed = s * b * v / cp / tp
            num_params = embed + layers + final_layer_out + final_norm + unembed
            return num_params
        else:
            # Assume flash attention, i.e. selective recomputation, plus tensor
            # parallelism + sequence parallelism as described in
            # https://arxiv.org/abs/2205.05198. Each variable below counts the
            # activation output that must be kept for the backward pass.

            # Attention block
            layer_in = s * b * h / cp / tp
            attn_norm = s * b * h / cp / sp
            # Everything inside attention is recomputed by flash attention.
            flash = s * b * h / cp / tp
            projection = s * b * h / cp / tp
            attn = layer_in + attn_norm + flash + projection

            # MLP block
            mlp_norm = s * b * h / cp / sp
            mlp_up = s * b * i / cp / tp
            mlp_gate = s * b * i / cp / tp
            hadamard_swiglu = s * b * i / cp / tp
            mlp_down = s * b * h / cp / tp
            if self.model.is_moe:
                # sp sharding makes sense here because the mlp_norm output
                # feeding the router is also sp sharded.
                router = s * b * e / cp / sp
                expert = mlp_up + mlp_gate + hadamard_swiglu + mlp_down
                experts = expert * ae / ep
                mlp = mlp_norm + router + experts
            else:
                mlp = mlp_norm + mlp_up + mlp_gate + hadamard_swiglu + mlp_down

            layer = attn + mlp
            # No decrease from pipeline parallelism because schedules increase
            # the number of in-flight microbatches instead.
            layers = layer * l

            # Other
            embed = 0
            # Both sequence and context parallelism apply here.
            final_layer_out = s * b * h / cp / sp
            final_norm = s * b * h / cp / sp
            unembed = s * b * v / cp / tp
            num_params = embed + layers + final_layer_out + final_norm + unembed
            return num_params

    def calculate_parameter_memory(self) -> float:
        if (
            self.parallelism.fsdp_enabled
            and self.parallelism.fsdp_strategy == "Zero-3"
        ):
            params = self.calculate_fsdp_sharded_parameters()
        else:
            params = self.calculate_num_parameters()

        if self.training.mixed_precision:
            # Master copy in the base precision plus a working copy in param_dtype.
            master_copy = params * self.training.precision
            working_copy = params * self.training.param_dtype
            return master_copy + working_copy
        else:
            return params * self.training.precision

    def calculate_gradient_memory(self) -> float:
        # https://blog.eleuther.ai/transformer-math/#gradients
        if self.parallelism.fsdp_enabled and self.parallelism.fsdp_strategy in (
            "Zero-3",
            "Zero-2",
        ):
            params = self.calculate_fsdp_sharded_parameters()
        else:
            params = self.calculate_num_parameters()

        grad_accumulation = 0
        if self.training.grad_accumulation:
            if self.training.mixed_precision:
                grad_accumulation = params * self.training.reduce_dtype
            else:
                grad_accumulation = params * self.training.precision

        if self.training.mixed_precision:
            gradients = params * self.training.param_dtype
        else:
            gradients = params * self.training.precision
        return grad_accumulation + gradients

    def calculate_optimizer_memory(self) -> float:
        # References:
        # https://blog.eleuther.ai/transformer-math/#optimizer-states
        # https://www.determined.ai/blog/act-mem-2
        # (archived: https://web.archive.org/web/20250308172134/https://www.determined.ai/blog/act-mem-2)
        #
        # Adam keeps two states per parameter, assumed to always be FP32.
        if self.parallelism.fsdp_enabled:
            # Unlike params and grads, optimizer states never gather a full layer.
            return (
                2 * self.calculate_num_parameters() * DType.FP32
            ) / self.parallelism.fsdp_parallelism
        else:
            return 2 * self.calculate_num_parameters() * DType.FP32
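    # For scale (assuming DType.FP32 is a size in bytes, i.e. 4): a 7B-parameter
    # dense model needs roughly 2 * 7e9 * 4 B ~= 56 GB of Adam state without
    # FSDP, or ~7 GB per rank with fsdp_parallelism = 8.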
    def calculate_activation_memory(self) -> float:
        # Activations are not impacted by FSDP.
        if self.training.mixed_precision:
            return self.calculate_activation_parameters() * self.training.param_dtype
        else:
            return self.calculate_activation_parameters() * self.training.precision
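

# Minimal usage sketch (illustrative): total per-device training memory is the
# sum of the four components computed above, mirroring the breakdown in
# https://blog.eleuther.ai/transformer-math/. The keyword arguments below are
# guesses based on the attributes this file reads from Model, Parallelism, and
# Training, and DType.BF16 is assumed to exist alongside DType.FP32; adjust
# both to match the real definitions in state and dtypes.
if __name__ == "__main__":
    model = Model(
        hidden_dim=4096,
        intermediate_size=11008,
        num_layers=32,
        vocab_size=32000,
        total_experts=1,
        active_experts=1,
        is_moe=False,
        weight_tied_embeddings=False,
    )
    parallelism = Parallelism(
        tensor_parallelism=1,
        pipeline_parallelism=1,
        expert_parallelism=1,
        context_parallelism=1,
        fsdp_enabled=False,
        fsdp_parallelism=1,
        fsdp_strategy="Zero-3",
    )
    training = Training(
        batch_size=1,
        sequence_length=4096,
        gradient_checkpointing=False,
        grad_accumulation=False,
        mixed_precision=True,
        precision=DType.FP32,
        param_dtype=DType.BF16,  # assumed member name
        reduce_dtype=DType.FP32,
    )
    calc = MemoryCalculation(model, parallelism, training)
    total_bytes = (
        calc.calculate_parameter_memory()
        + calc.calculate_gradient_memory()
        + calc.calculate_optimizer_memory()
        + calc.calculate_activation_memory()
    )
    print(f"Estimated training memory: {total_bytes / 1e9:.1f} GB")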