from math import ceil

from dtypes import DType
from state import Model, Parallelism, Training


class MemoryCalculation:
    """Estimates per-device training memory for a transformer model
    under the given parallelism and training configuration."""

    def __init__(
        self,
        modelconfig: Model,
        parallelismconfig: Parallelism,
        trainingconfig: Training,
    ):
        self.model = modelconfig
        self.parallelism = parallelismconfig
        self.training = trainingconfig
    def calculate_num_parameters_per_layer(self) -> float:
        # https://michaelwornow.net/2024/01/18/counting-params-in-transformer
        # https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=memory_usage_in_transformers
        # Biases are not included or excluded on a per-model basis; for simplicity they
        # are counted wherever they could appear. They are small relative to the weights,
        # so this gives an upper bound.
        b, s = self.training.batch_size, self.training.sequence_length
        h, i, l, v, e = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.total_experts,
        )
        tp, pp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.pipeline_parallelism,
            self.parallelism.expert_parallelism,
        )
        # Attention
        layer_norm_attn_in = h  # not tp sharded
        qkv = (3 * h * h + 3 * h) / tp  # Q/K/V projections plus their biases
        attn_output_proj = (h * h + h) / tp
        attn = layer_norm_attn_in + qkv + attn_output_proj
        # MLP
        layer_norm_mlp_in = h  # not tp sharded
        mlp_up_proj = (h * i + i) / tp
        mlp_gate_proj = (h * i + i) / tp
        mlp_down_proj = (i * h + h) / tp
        mlp = layer_norm_mlp_in + mlp_up_proj + mlp_gate_proj + mlp_down_proj
        if self.model.is_moe:
            router = h * e + e  # assumed replicated for simplicity
            expert = mlp_up_proj + mlp_gate_proj + mlp_down_proj
            experts = expert * e / ep
            mlp = layer_norm_mlp_in + router + experts
        layer = attn + mlp
        return layer
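    # Sanity check for calculate_num_parameters_per_layer: with tp = 1 and a dense MLP,
    # the count above reduces to 4*h**2 + 3*h*i plus small bias and norm terms. For
    # example, h = 4096 and i = 11008 (illustrative values, not taken from this repo)
    # give roughly 202M parameters per layer.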
    def calculate_unshardeable_parameters(self) -> float:
        b, s = self.training.batch_size, self.training.sequence_length
        h, i, l, v, e = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.total_experts,
        )
        tp, pp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.pipeline_parallelism,
            self.parallelism.expert_parallelism,
        )
        # Embedding layers
        input_embedding = v * h / tp
        unembedding = 0
        if not self.model.weight_tied_embeddings:
            unembedding = h * v / tp
        final_layer_norm = h  # not tp sharded
        total_params = 0  # initialize so total_params is always bound
        if pp == 1:
            total_params = input_embedding + unembedding + final_layer_norm
        elif pp > 1:
            # Input embedding and unembedding sit on different pipeline stages;
            # the larger of the two bounds the per-stage cost.
            total_params = max(input_embedding, unembedding) + final_layer_norm
        return total_params
    def calculate_fsdp_sharded_parameters(self) -> float:
        if not self.parallelism.fsdp_enabled:
            return self.calculate_num_parameters()
        else:
            # All but one layer stay sharded across the FSDP group; one layer is
            # kept fully gathered (the layer currently being computed).
            return (
                self.calculate_num_parameters_per_layer()
                * ceil(
                    (self.model.num_layers - 1) / self.parallelism.pipeline_parallelism
                )
                / self.parallelism.fsdp_parallelism
                + self.calculate_unshardeable_parameters()
                + self.calculate_num_parameters_per_layer()
            )
    def calculate_num_parameters(self) -> float:
        return (
            self.calculate_num_parameters_per_layer()
            * ceil(self.model.num_layers / self.parallelism.pipeline_parallelism)
            + self.calculate_unshardeable_parameters()
        )
    def calculate_activation_parameters(self) -> float:
        # https://blog.eleuther.ai/transformer-math/#activations-and-batch-size
        # https://arxiv.org/abs/2205.05198
        # pp is not considered since most pp schedules run multiple concurrent
        # microbatches to reduce the bubble
        b, s = self.training.batch_size, self.training.sequence_length
        h, i, l, v, e, ae = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.total_experts,
            self.model.active_experts,
        )
        tp, cp, pp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.context_parallelism,
            self.parallelism.pipeline_parallelism,
            self.parallelism.expert_parallelism,
        )
        sp = tp  # sequence parallelism uses the tensor-parallel group
        if self.training.gradient_checkpointing:
            # full recomputation
            embed = 0
            layer = s * b * h / cp / tp  # only keep the initial input to each layer
            layers = layer * l
            final_layer_out = s * b * h / cp / sp
            final_norm = s * b * h / cp / sp
            unembed = s * b * v / cp / tp
            num_params = embed + layers + final_layer_out + final_norm + unembed
            return num_params
        else:
            # assume flash attention, i.e. selective recomputation
            # assume tensor parallel + sequence parallel as described in
            # https://arxiv.org/abs/2205.05198
            # each variable below is the stored activation output of that op
            # Attention block
            layer_in = s * b * h / cp / tp
            attn_norm = s * b * h / cp / sp
            flash = s * b * h / cp / tp
            # everything else inside attention is recomputed by flash attention
            projection = s * b * h / cp / tp
            attn = layer_in + attn_norm + flash + projection
            # MLP block
            mlp_norm = s * b * h / cp / sp
            mlp_up = s * b * i / cp / tp
            mlp_gate = s * b * i / cp / tp
            hadamard_swiglu = s * b * i / cp / tp
            mlp_down = s * b * h / cp / tp
            if self.model.is_moe:
                # makes sense to sp-shard since the mlp_norm output is sp sharded
                router = s * b * e / cp / sp
                expert = mlp_up + mlp_gate + hadamard_swiglu + mlp_down
                experts = expert * ae / ep
                mlp = mlp_norm + router + experts
            else:
                mlp = mlp_norm + mlp_up + mlp_gate + hadamard_swiglu + mlp_down
            layer = attn + mlp
            # no decrease from pp because schedules increase in-flight microbatches
            layers = layer * l
            # Other
            embed = 0
            final_layer_out = s * b * h / cp / sp  # both sequence and context parallel
            final_norm = s * b * h / cp / sp
            unembed = s * b * v / cp / tp
            num_params = embed + layers + final_layer_out + final_norm + unembed
            return num_params
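    # Sanity check for calculate_activation_parameters: for a dense (non-MoE) layer in
    # the non-checkpointed branch above, the stored activations sum to
    # (6*s*b*h + 3*s*b*i) / (cp * tp) values per layer
    # (4 sbh from attention, 2 sbh + 3 sbi from the MLP, with sp = tp as set above).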
    def calculate_parameter_memory(self) -> float:
        if self.parallelism.fsdp_enabled and self.parallelism.fsdp_strategy == "Zero-3":
            params = self.calculate_fsdp_sharded_parameters()
        else:
            params = self.calculate_num_parameters()
        if self.training.mixed_precision:
            # master weights plus the working copy used for compute
            master_copy = params * self.training.precision
            working_copy = params * self.training.param_dtype
            return master_copy + working_copy
        else:
            return params * self.training.precision
    def calculate_gradient_memory(self) -> float:
        # https://blog.eleuther.ai/transformer-math/#gradients
        if self.parallelism.fsdp_enabled and self.parallelism.fsdp_strategy in (
            "Zero-3",
            "Zero-2",
        ):
            params = self.calculate_fsdp_sharded_parameters()
        else:
            params = self.calculate_num_parameters()
        grad_accumulation = 0
        if self.training.grad_accumulation:
            # separate accumulation buffer, kept in the reduction dtype when
            # training in mixed precision
            if self.training.mixed_precision:
                grad_accumulation = params * self.training.reduce_dtype
            else:
                grad_accumulation = params * self.training.precision
        if self.training.mixed_precision:
            gradients = params * self.training.param_dtype
        else:
            gradients = params * self.training.precision
        return grad_accumulation + gradients
    def calculate_optimizer_memory(self) -> float:
        # https://blog.eleuther.ai/transformer-math/#optimizer-states
        # https://www.determined.ai/blog/act-mem-2, https://web.archive.org/web/20250308172134/https://www.determined.ai/blog/act-mem-2
        # Adam keeps 2 states per parameter; assume they are always stored in fp32.
        if self.parallelism.fsdp_enabled:
            # optimizer states stay sharded; unlike params and grads, no layer is gathered
            return (
                2 * self.calculate_num_parameters() * DType.FP32
            ) / self.parallelism.fsdp_parallelism
        else:
            return 2 * self.calculate_num_parameters() * DType.FP32
    def calculate_activation_memory(self) -> float:
        # activations are not impacted by fsdp
        if self.training.mixed_precision:
            return self.calculate_activation_parameters() * self.training.param_dtype
        else:
            return self.calculate_activation_parameters() * self.training.precision
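

# Minimal usage sketch. This is an assumption-heavy example: it presumes that the
# Model, Parallelism, and Training config classes accept the attributes read above
# as keyword arguments, and that the DType values are sizes in bytes. The numeric
# values are illustrative (roughly a Llama-7B-style model), not taken from this repo.
if __name__ == "__main__":
    model = Model(
        hidden_dim=4096,
        intermediate_size=11008,
        num_layers=32,
        vocab_size=32000,
        total_experts=1,
        active_experts=1,
        is_moe=False,
        weight_tied_embeddings=False,
    )
    parallelism = Parallelism(
        tensor_parallelism=2,
        pipeline_parallelism=1,
        expert_parallelism=1,
        context_parallelism=1,
        fsdp_enabled=True,
        fsdp_parallelism=4,
        fsdp_strategy="Zero-3",
    )
    training = Training(
        batch_size=1,
        sequence_length=4096,
        gradient_checkpointing=True,
        mixed_precision=False,
        precision=DType.FP32,
        param_dtype=DType.FP32,
        reduce_dtype=DType.FP32,
        grad_accumulation=False,
    )
    calc = MemoryCalculation(model, parallelism, training)
    total_bytes = (
        calc.calculate_parameter_memory()
        + calc.calculate_gradient_memory()
        + calc.calculate_optimizer_memory()
        + calc.calculate_activation_memory()
    )
    print(f"Estimated memory per device: {total_bytes / 2**30:.2f} GiB")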