from math import ceil

from dtypes import DType
from state import Model, Parallelism, Training


class MemoryCalculation:
    def __init__(
        self,
        modelconfig: Model,
        parallelismconfig: Parallelism,
        trainingconfig: Training,
    ):
        self.model = modelconfig
        self.parallelism = parallelismconfig
        self.training = trainingconfig

    def calculate_num_parameters_per_layer(self) -> float:
        # https://michaelwornow.net/2024/01/18/counting-params-in-transformer
        # https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=memory_usage_in_transformers
        # Biases are not toggled per model; for simplicity they are counted
        # wherever they could appear. They are small compared to the weight
        # matrices, so this gives an upper bound.
        # Unpack config values into short local names.
        h, i, e = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.total_experts,
        )
        tp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.expert_parallelism,
        )
        # Attention
        layer_norm_attn_in = h  # not tp sharded
        qkv = 3 * h * h / tp
        attn_output_proj = (h * h + h) / tp
        attn = layer_norm_attn_in + qkv + attn_output_proj
        # MLP
        layer_norm_mlp_in = h  # not tp sharded
        mlp_up_proj = (h * i + i) / tp
        mlp_gate_proj = (h * i + i) / tp
        mlp_down_proj = (i * h + h) / tp
        mlp = layer_norm_mlp_in + mlp_up_proj + mlp_gate_proj + mlp_down_proj
        if self.model.is_moe:
            router = h * e + e  # assumed replicated for simplicity
            expert = mlp_up_proj + mlp_gate_proj + mlp_down_proj
            experts = expert * e / ep
            mlp = layer_norm_mlp_in + router + experts
        layer = attn + mlp
        return layer
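
    # As a rough check, for a hypothetical dense Llama-7B-like config
    # (h=4096, i=11008, tp=1) the formula above gives roughly 67.1M attention
    # parameters and 135.3M MLP parameters, i.e. about 202M per layer, or
    # about 6.5B across 32 layers before embeddings.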

    def calculate_unshardeable_parameters(self) -> float:
        h, v = self.model.hidden_dim, self.model.vocab_size
        tp, pp = (
            self.parallelism.tensor_parallelism,
            self.parallelism.pipeline_parallelism,
        )
        # Embedding layers
        input_embedding = v * h / tp
        unembedding = 0
        if not self.model.weight_tied_embeddings:
            unembedding = h * v / tp
        final_layer_norm = h  # not tp sharded
        if pp > 1:
            # Embedding and unembedding sit on different pipeline stages, so a
            # single rank only holds the larger of the two.
            total_params = max(input_embedding, unembedding) + final_layer_norm
        else:
            total_params = input_embedding + unembedding + final_layer_norm
        return total_params

    def calculate_fsdp_sharded_parameters(self) -> float:
        if not self.parallelism.fsdp_enabled:
            return self.calculate_num_parameters()
        else:
            # All but one layer stay sharded across the FSDP group; one layer
            # is kept fully gathered for compute, plus the parameters that
            # cannot be sharded per layer.
            return (
                self.calculate_num_parameters_per_layer()
                * ceil(
                    (self.model.num_layers - 1) / self.parallelism.pipeline_parallelism
                )
                / self.parallelism.fsdp_parallelism
                + self.calculate_unshardeable_parameters()
                + self.calculate_num_parameters_per_layer()
            )

    def calculate_num_parameters(self) -> float:
        return (
            self.calculate_num_parameters_per_layer()
            * ceil(self.model.num_layers / self.parallelism.pipeline_parallelism)
            + self.calculate_unshardeable_parameters()
        )

    def calculate_activation_parameters(self) -> float:
        # https://blog.eleuther.ai/transformer-math/#activations-and-batch-size
        # https://arxiv.org/abs/2205.05198
        # pp is not considered since most pp schedules run multiple concurrent
        # microbatches to reduce the bubble.
        b, s = self.training.batch_size, self.training.sequence_length
        h, i, l, v, e, ae = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.total_experts,
            self.model.active_experts,
        )
        tp, cp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.context_parallelism,
            self.parallelism.expert_parallelism,
        )
        sp = tp  # sequence parallelism shares the tensor-parallel group
        if self.training.gradient_checkpointing:
            # Full recomputation: only keep the input to each layer.
            layer = s * b * h / cp / tp
            layers = layer * l
            embed = 0
            final_layer_out = s * b * h / cp / sp
            final_norm = s * b * h / cp / sp
            unembed = s * b * v / cp / tp
            num_params = embed + layers + final_layer_out + final_norm + unembed
            return num_params
        else:
            # Assume flash attention, i.e. selective recomputation, and tensor
            # parallelism + sequence parallelism as described in
            # https://arxiv.org/abs/2205.05198.
            # Each variable counts the activation outputs that must be kept.
            # Attention block
            layer_in = s * b * h / cp / tp
            attn_norm = s * b * h / cp / sp
            flash = s * b * h / cp / tp
            # everything else is recomputed by flash attention
            projection = s * b * h / cp / tp
            attn = layer_in + attn_norm + flash + projection
            # MLP block
            mlp_norm = s * b * h / cp / sp
            mlp_up = s * b * i / cp / tp
            mlp_gate = s * b * i / cp / tp
            hadamard_swiglu = s * b * i / cp / tp
            mlp_down = s * b * h / cp / tp
            if self.model.is_moe:
                router = (
                    s * b * e / cp / sp
                )  # sp sharded since the mlp_norm output feeding it is sp sharded
                expert = mlp_up + mlp_gate + hadamard_swiglu + mlp_down
                experts = expert * ae / ep
                mlp = mlp_norm + router + experts
            else:
                mlp = mlp_norm + mlp_up + mlp_gate + hadamard_swiglu + mlp_down
            layer = attn + mlp
            layers = (
                layer * l
            )  # no decrease from pp because schedules will increase microbatches
            # Other
            embed = 0
            final_layer_out = (
                s * b * h / cp / sp
            )  # both sequence and context parallelism
            final_norm = s * b * h / cp / sp
            unembed = s * b * v / cp / tp
            num_params = embed + layers + final_layer_out + final_norm + unembed
            return num_params
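
    # Sanity check: for a dense layer without checkpointing, the terms above
    # sum to (6*s*b*h + 3*s*b*i) / (cp * tp) kept activation elements per
    # layer; with full recomputation only s*b*h / (cp * tp) is kept per layer.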

    def calculate_parameter_memory(self) -> float:
        if self.parallelism.fsdp_enabled and self.parallelism.fsdp_strategy == "Zero-3":
            params = self.calculate_fsdp_sharded_parameters()
        else:
            params = self.calculate_num_parameters()
        if self.training.mixed_precision:
            # Master copy in the base precision plus a working copy in param_dtype.
            master_copy = params * self.training.precision
            working_copy = params * self.training.param_dtype
            return master_copy + working_copy
        else:
            return params * self.training.precision

    def calculate_gradient_memory(self) -> float:
        # https://blog.eleuther.ai/transformer-math/#gradients
        if self.parallelism.fsdp_enabled and self.parallelism.fsdp_strategy in (
            "Zero-3",
            "Zero-2",
        ):
            params = self.calculate_fsdp_sharded_parameters()
        else:
            params = self.calculate_num_parameters()
        grad_accumulation = 0
        if self.training.grad_accumulation:
            if self.training.mixed_precision:
                grad_accumulation = params * self.training.reduce_dtype
            else:
                grad_accumulation = params * self.training.precision
        if self.training.mixed_precision:
            gradients = params * self.training.param_dtype
        else:
            gradients = params * self.training.precision
        return grad_accumulation + gradients

    def calculate_optimizer_memory(self) -> float:
        # https://blog.eleuther.ai/transformer-math/#optimizer-states
        # https://www.determined.ai/blog/act-mem-2, https://web.archive.org/web/20250308172134/https://www.determined.ai/blog/act-mem-2
        # Adam keeps 2 states per parameter; assume they are always fp32.
        if self.parallelism.fsdp_enabled:
            # Optimizer states stay sharded; unlike params and grads, no layer
            # is gathered.
            return (
                2 * self.calculate_num_parameters() * DType.FP32
            ) / self.parallelism.fsdp_parallelism
        else:
            return 2 * self.calculate_num_parameters() * DType.FP32
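
    # As a rough illustration: for ~7e9 parameters, 2 * 7e9 * 4 bytes ≈ 56 GB
    # of Adam state before any FSDP sharding.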

    def calculate_activation_memory(self) -> float:
        if self.training.mixed_precision:
            return self.calculate_activation_parameters() * self.training.param_dtype
        else:
            return (
                self.calculate_activation_parameters() * self.training.precision
            )  # not impacted by fsdp
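

# A minimal usage sketch, not part of the original calculator. It assumes the
# three config objects only need the plain attributes read above, so
# SimpleNamespace stand-ins are enough to exercise the math without the real
# state.Model / state.Parallelism / state.Training classes. Dtype-like fields
# (precision, param_dtype, reduce_dtype) are assumed to be bytes per element,
# mirroring how DType.FP32 is used as a multiplier in calculate_optimizer_memory.
if __name__ == "__main__":
    from types import SimpleNamespace

    model = SimpleNamespace(
        hidden_dim=4096,
        intermediate_size=11008,
        num_layers=32,
        vocab_size=32000,
        total_experts=1,
        active_experts=1,
        is_moe=False,
        weight_tied_embeddings=False,
    )
    parallelism = SimpleNamespace(
        tensor_parallelism=2,
        pipeline_parallelism=1,
        expert_parallelism=1,
        context_parallelism=1,
        fsdp_enabled=True,
        fsdp_parallelism=8,
        fsdp_strategy="Zero-3",
    )
    training = SimpleNamespace(
        batch_size=1,
        sequence_length=4096,
        gradient_checkpointing=False,
        mixed_precision=True,
        precision=4,  # bytes per element for the full-precision master copy
        param_dtype=2,  # bytes per element for the bf16 working copy
        reduce_dtype=4,  # bytes per element for gradient accumulation
        grad_accumulation=True,
    )
    calc = MemoryCalculation(model, parallelism, training)
    total_bytes = (
        calc.calculate_parameter_memory()
        + calc.calculate_gradient_memory()
        + calc.calculate_optimizer_memory()
        + calc.calculate_activation_memory()
    )
    print(f"Estimated per-GPU memory: {total_bytes / 2**30:.1f} GiB")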