import json
from functools import partial
from pathlib import Path

import gradio as gr
import pandas as pd

from defaults import DEFAULTS
from details import ACCURACY, DETAILS, INSTRUCTIONS, LIMITATIONS
from state import Model, Parallelism, Training
from calculator import MemoryCalculation
from dtypes import DType

# Create a Number component for natural numbers (positive integers)
NaturalNumber = partial(gr.Number, minimum=1, step=1, precision=0, interactive=True)

# CSS class (styled in `css` below) that greys out inputs which currently have no effect.
DISABLED_CLASS = "disabled-field"


def _field_state(enabled):
    """Return the gr.update kwargs that enable or visually disable a field."""
    return {
        "interactive": enabled,
        "elem_classes": [] if enabled else [DISABLED_CLASS],
    }


def _toggle_updates(enabled):
    """Checkbox handler: enable/disable the two dependent fields it controls.

    Every toggle in this UI drives exactly two fields, so two independent
    update objects are returned (one per output component).
    """
    return [gr.update(**_field_state(enabled)), gr.update(**_field_state(enabled))]


def create_parallelism_block():
    """Build the "Parallelism" input column.

    Returns:
        (tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy) components.
    """
    with gr.Column():
        gr.Markdown("# Parallelism")
        with gr.Group():
            tp = NaturalNumber(label="Tensor Parallelism", value=1)
            pp = NaturalNumber(label="Pipeline Parallelism", value=1)
            cp = NaturalNumber(label="Context Parallelism", value=1)
            ep = NaturalNumber(label="Expert Parallelism", value=1)
            fsdp_enabled = gr.Checkbox(label="FSDP (Fully Sharded Data Parallel)", value=True)
            fsdp_parallelism = NaturalNumber(label="FSDP Parallelism", value=8)
            fsdp_strategy = gr.Radio(
                choices=["Zero-1", "Zero-2", "Zero-3"], label="FSDP Strategy", value="Zero-3"
            )

        # FSDP size/strategy only matter while FSDP is enabled.
        fsdp_enabled.change(
            fn=_toggle_updates,
            inputs=fsdp_enabled,
            outputs=[fsdp_parallelism, fsdp_strategy],
        )
    return tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy


def create_model_block():
    """Build the "Model Architecture" input column, including the preset dropdown.

    Selecting a preset fills every architecture field from DEFAULTS; editing any
    field so it no longer matches the selected preset switches the dropdown to
    "Custom".

    Returns:
        (layers, vocab, hidden, intermediate, active_experts, total_experts,
         is_moe, presets, weight_tied_embeddings) components.
    """
    with gr.Column():
        gr.Markdown("# Model Architecture")
        layers = NaturalNumber(label="Number of Layers", value=32)
        vocab = NaturalNumber(label="Vocab Size", value=128256)
        hidden = NaturalNumber(label="Hidden Dim", value=4096)
        intermediate = NaturalNumber(label="Intermediate Dim", value=14336)
        is_moe = gr.Checkbox(label="Mixture of Experts (MoE)", value=False)
        active_experts = NaturalNumber(
            label="Active Experts", value=1, interactive=False, elem_classes=DISABLED_CLASS
        )
        total_experts = NaturalNumber(
            label="Total Experts", value=1, interactive=False, elem_classes=DISABLED_CLASS
        )
        weight_tied_embeddings = gr.Checkbox(label="Weight Tied Embeddings", value=True)

        # Expert counts only matter for MoE models.
        is_moe.change(
            fn=_toggle_updates,
            inputs=is_moe,
            outputs=[active_experts, total_experts],
        )

        presets = gr.Dropdown(
            ["Custom"] + list(DEFAULTS.keys()),
            label="Presets",
            value="Llama3 8B",
            interactive=True,
        )

        # Every field a preset controls, in the order populate/compare use them.
        model_inputs = [
            layers, vocab, hidden, intermediate, is_moe,
            active_experts, total_experts, weight_tied_embeddings,
        ]

        def preset_values(model):
            """Field values for a DEFAULTS entry, ordered like `model_inputs`."""
            return (
                model.num_layers, model.vocab_size, model.hidden_dim,
                model.intermediate_size, model.is_moe, model.active_experts,
                model.total_experts, model.weight_tied_embeddings,
            )

        def populate_from_preset(preset_name):
            """Fill every model field from the selected preset; no-op for "Custom"."""
            if not (preset_name and preset_name in DEFAULTS):
                return [gr.update() for _ in model_inputs]
            model = DEFAULTS[preset_name]
            updates = [gr.update(value=v) for v in preset_values(model)]
            # Keep expert fields' enabled state *and* styling in sync with is_moe,
            # exactly as the is_moe toggle handler does.
            for idx in (5, 6):  # active_experts, total_experts
                updates[idx] = gr.update(
                    value=preset_values(model)[idx], **_field_state(model.is_moe)
                )
            return updates

        def switch_to_custom(*values):
            """Switch the dropdown to "Custom" once fields diverge from the preset.

            The last positional value is the current preset name; the others are
            the field values in `model_inputs` order. These change events also fire
            while a preset is being applied, so a full match leaves it untouched.
            """
            *current, current_preset = values
            if current_preset and current_preset in DEFAULTS:
                if tuple(current) == preset_values(DEFAULTS[current_preset]):
                    return gr.update()  # Keep current preset
            return gr.update(value="Custom")

        presets.change(fn=populate_from_preset, inputs=presets, outputs=model_inputs)

        for component in model_inputs:
            component.change(
                fn=switch_to_custom,
                inputs=model_inputs + [presets],
                outputs=presets,
            )

    return (
        layers, vocab, hidden, intermediate, active_experts, total_experts,
        is_moe, presets, weight_tied_embeddings,
    )


def create_training_block():
    """Build the "Training Config" input column.

    Returns:
        (seq_len, batch_size, gradient_checkpointing, grad_accumulation,
         precision, mixed_precision, param_dtype, reduce_dtype) components.
    """
    with gr.Column():
        gr.Markdown("# Training Config")
        seq_len = NaturalNumber(label="Sequence Length", value=4096)
        batch_size = NaturalNumber(
            label="Batch Size",
            info="If you are using gradient accumulation, enter microbatch size",
            value=1,
        )
        with gr.Row():
            gradient_checkpointing = gr.Checkbox(label="Gradient Checkpointing", value=True)
            grad_accumulation = gr.Checkbox(label="Gradient Accumulation", value=False)
        precision = gr.Dropdown(
            DType.values(), label="Precision", value=DType.BF16.value, interactive=True
        )
        mixed_precision = gr.Checkbox(label="Mixed Precision", value=False)
        param_dtype = gr.Dropdown(
            DType.values(), label="Parameter Dtype", value=DType.FP32.value,
            interactive=False, elem_classes=DISABLED_CLASS,
        )
        reduce_dtype = gr.Dropdown(
            DType.values(), label="Reduce Dtype", value=DType.FP32.value,
            interactive=False, elem_classes=DISABLED_CLASS,
        )

        # Separate param/reduce dtypes only matter with mixed precision.
        mixed_precision.change(
            fn=_toggle_updates,
            inputs=mixed_precision,
            outputs=[param_dtype, reduce_dtype],
        )
    return (
        seq_len, batch_size, gradient_checkpointing, grad_accumulation,
        precision, mixed_precision, param_dtype, reduce_dtype,
    )


def _as_int(value, label):
    """Convert a gr.Number value to int, failing with a readable UI error.

    A cleared Number field delivers None; int(None) would otherwise surface as
    an opaque TypeError.
    """
    if value is None:
        raise gr.Error(f"{label} must be set to a positive integer.")
    return int(value)


def calculate(tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy,
              layers, vocab, hidden, intermediate, active_experts, total_experts,
              is_moe, weight_tied_embeddings, seq_len, batch_size,
              gradient_checkpointing, grad_accumulation, precision,
              mixed_precision, param_dtype, reduce_dtype):
    """Compute the memory breakdown for the UI inputs and return a bar plot.

    Builds Model/Parallelism/Training configs, asks MemoryCalculation for the
    parameter, activation, gradient and optimizer components, and renders a
    stacked "Total" bar followed by one bar per component (values in GB,
    rounded to one decimal place).

    Raises:
        gr.Error: if a required numeric field is empty.
    """
    model_config = Model(
        vocab_size=_as_int(vocab, "Vocab Size"),
        num_layers=_as_int(layers, "Number of Layers"),
        hidden_dim=_as_int(hidden, "Hidden Dim"),
        intermediate_size=_as_int(intermediate, "Intermediate Dim"),
        weight_tied_embeddings=weight_tied_embeddings,
        active_experts=_as_int(active_experts, "Active Experts"),
        total_experts=_as_int(total_experts, "Total Experts"),
        is_moe=is_moe,
    )
    parallelism_config = Parallelism(
        tensor_parallelism=_as_int(tp, "Tensor Parallelism"),
        pipeline_parallelism=_as_int(pp, "Pipeline Parallelism"),
        context_parallelism=_as_int(cp, "Context Parallelism"),
        expert_parallelism=_as_int(ep, "Expert Parallelism"),
        fsdp_enabled=fsdp_enabled,
        fsdp_parallelism=_as_int(fsdp_parallelism, "FSDP Parallelism"),
        fsdp_strategy=fsdp_strategy,
    )
    training_config = Training(
        sequence_length=_as_int(seq_len, "Sequence Length"),
        batch_size=_as_int(batch_size, "Batch Size"),
        gradient_checkpointing=gradient_checkpointing,
        grad_accumulation=grad_accumulation,
        precision=DType(precision),
        mixed_precision=mixed_precision,
        param_dtype=DType(param_dtype),
        reduce_dtype=DType(reduce_dtype),
    )

    calc = MemoryCalculation(model_config, parallelism_config, training_config)
    param_memory = calc.calculate_parameter_memory()
    activation_memory = calc.calculate_activation_memory()
    gradient_memory = calc.calculate_gradient_memory()
    optimizer_memory = calc.calculate_optimizer_memory()
    total_memory = param_memory + activation_memory + gradient_memory + optimizer_memory

    # The calculator's results are divided by 1e9 and shown as GB (1 decimal).
    param_gb = round(param_memory / 1e9, 1)
    activation_gb = round(activation_memory / 1e9, 1)
    gradient_gb = round(gradient_memory / 1e9, 1)
    optimizer_gb = round(optimizer_memory / 1e9, 1)
    total_gb = round(total_memory / 1e9, 1)

    total_label = f'Total Memory\n{total_gb} GB'
    # Individual bars in display order: (x-axis label, GB, colour group).
    components = [
        (f'Parameter Memory\n{param_gb} GB', param_gb, 'Parameter'),
        (f'Gradient Memory\n{gradient_gb} GB', gradient_gb, 'Gradient'),
        (f'Optimizer Memory\n{optimizer_gb} GB', optimizer_gb, 'Optimizer'),
        (f'Activation Memory\n{activation_gb} GB', activation_gb, 'Activation'),
    ]

    # Stacked total bar first: one row per memory type sharing the total label,
    # listed Activation -> Parameter; then one row per individual bar.
    rows = [
        {'Component': total_label, 'Memory (GB)': gb_val, 'Type': mem_type}
        for _, gb_val, mem_type in reversed(components)
    ]
    rows += [
        {'Component': label, 'Memory (GB)': gb_val, 'Type': mem_type}
        for label, gb_val, mem_type in components
    ]
    memory_data = pd.DataFrame(rows)

    return gr.BarPlot(
        value=memory_data,
        x="Component",
        y="Memory (GB)",
        color="Type",
        title="LLM Memory Usage Breakdown",
        container=False,
        y_lim=[0, None],
        # Explicit x order: total first, then components as listed above.
        sort=[total_label] + [label for label, _, _ in components],
    )


css = """
/* Style for disabled components to make them visually obvious */
.disabled-field input,
.disabled-field select,
.disabled-field textarea {
    opacity: 0.4 !important;
    background-color: #f5f5f5 !important;
    color: #999 !important;
    cursor: not-allowed !important;
    text-decoration: line-through;
}
.disabled-field label {
    opacity: 0.5 !important;
    color: #999 !important;
}
"""

with gr.Blocks(theme='Default', css=css) as demo:
    with gr.Column():
        gr.Markdown("# LLM Training Memory Visualizer")
        gr.Markdown("🔧 Built by [Ruben Aghayan](https://www.linkedin.com/in/ruben-aghayan-37885690/)")
        gr.Markdown("---")
        gr.Markdown(INSTRUCTIONS)

        with gr.Row(equal_height=True):
            tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy = create_parallelism_block()
            (layers, vocab, hidden, intermediate, active_experts, total_experts,
             is_moe, presets, weight_tied_embeddings) = create_model_block()
            (seq_len, batch_size, gradient_checkpointing, grad_accumulation,
             precision, mixed_precision, param_dtype, reduce_dtype) = create_training_block()

        calculate_button = gr.Button("Calculate")
        output = gr.BarPlot(label="Memory Usage Breakdown")

        # Input order must match calculate()'s positional parameters.
        calculate_button.click(
            fn=calculate,
            inputs=[
                tp, pp, cp, ep, fsdp_enabled, fsdp_parallelism, fsdp_strategy,
                layers, vocab, hidden, intermediate, active_experts, total_experts,
                is_moe, weight_tied_embeddings, seq_len, batch_size,
                gradient_checkpointing, grad_accumulation, precision,
                mixed_precision, param_dtype, reduce_dtype,
            ],
            outputs=output,
        )

        gr.Markdown("# Details")
        with gr.Row():
            gr.Markdown(LIMITATIONS)
            gr.Markdown(DETAILS)

        gr.Markdown("# Validation")
        gr.Markdown(ACCURACY)

# Launch only when run as a script, so importing this module (e.g. from tests or
# Gradio's reload mode, which looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch(share=True)