Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| from functools import partial | |
| from defaults import DEFAULTS | |
| from state import Model, Parallelism, Training | |
| from calculator import MemoryCalculation | |
| from dtypes import DType | |
# Shared settings for a "natural number" input: a positive-integer spinbox.
_NATURAL_NUMBER_OPTS = {"minimum": 1, "step": 1, "precision": 0, "interactive": True}

# Factory used by every config block below for positive-integer fields.
NaturalNumber = partial(gr.Number, **_NATURAL_NUMBER_OPTS)
def greet(name, intensity) -> str:
    """Return a greeting for *name* followed by *intensity* exclamation marks."""
    excitement = "!" * int(intensity)
    return f"Hello, {name}{excitement}"
def create_parallelism_block():
    """Build the parallelism column of the UI.

    Returns:
        The (tensor, pipeline, context, expert) parallelism number inputs,
        in that order, each defaulting to 1 (no parallelism).
    """
    labels = (
        "Tensor Parallelism",
        "Pipeline Parallelism",
        "Context Parallelism",
        "Expert Parallelism",
    )
    with gr.Column():
        gr.Markdown("# Parallelism")
        with gr.Group():
            # Creation order fixes the on-screen order: tp, pp, cp, ep.
            fields = [NaturalNumber(label=text, value=1) for text in labels]
    return tuple(fields)
def create_model_block():
    """Build the model-architecture column of the UI.

    Returns:
        (layers, vocab, hidden, intermediate, active_experts, total_experts,
        is_moe, presets). ``presets`` is ``None`` until the preset dropdown
        ships; it is kept in the tuple so the caller's 8-way unpack is stable.
    """
    with gr.Column():
        gr.Markdown("# Model Architecture")
        layers = NaturalNumber(label="Number of Layers", value=32)
        vocab = NaturalNumber(label="Vocab Size", value=32000)
        hidden = NaturalNumber(label="Hidden Dim", value=4096)
        intermediate = NaturalNumber(label="Intermediate Dim", value=11008)
        is_moe = gr.Checkbox(label="Mixture of Experts (MoE)", value=False)
        active_experts = NaturalNumber(label="Active Experts", value=2, visible=False)
        total_experts = NaturalNumber(label="Total Experts", value=8, visible=False)
        # Show the expert-count fields only while the MoE box is checked.
        is_moe.change(
            fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
            inputs=is_moe,
            outputs=[active_experts, total_experts],
        )
        # BUG FIX: the preset dropdown was commented out ("not ready yet")
        # while `presets` was still returned, raising NameError at build time.
        # Bind None to keep the return arity until the feature lands.
        # presets = gr.Dropdown(list(DEFAULTS.keys()), label="Presets", interactive=True)
        presets = None
    return layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets
def create_training_block():
    """Build the training-configuration column of the UI.

    Returns:
        (seq_len, batch_size, gradient_checkpointing, grad_accumulation,
        precision, mixed_precision, param_dtype, reduce_dtype) components.
    """
    with gr.Column():
        gr.Markdown("# Training Config")
        seq_len = NaturalNumber(label="Sequence Length", value=8192)
        batch_size = NaturalNumber(
            label="Batch Size",
            info="If you are using gradient accumulation, enter microbatch size",
            value=8,
        )
        with gr.Row():
            gradient_checkpointing = gr.Checkbox(label="Gradient Checkpointing", value=False)
            grad_accumulation = gr.Checkbox(label="Gradient Accumulation", value=False)
        precision = gr.Dropdown(
            DType.values(), label="Precision", value=DType.FP32.value, interactive=True
        )
        mixed_precision = gr.Checkbox(label="Mixed Precision", value=False)
        param_dtype = gr.Dropdown(
            DType.values(), label="Parameter Dtype", value=DType.FP32.value,
            interactive=True, visible=False,
        )
        reduce_dtype = gr.Dropdown(
            DType.values(), label="Reduce Dtype", value=DType.FP32.value,
            interactive=True, visible=False,
        )

        def _toggle_dtype_fields(enabled):
            # Reveal the param/reduce dtype dropdowns only under mixed precision.
            return [gr.update(visible=enabled), gr.update(visible=enabled)]

        mixed_precision.change(
            fn=_toggle_dtype_fields,
            inputs=mixed_precision,
            outputs=[param_dtype, reduce_dtype],
        )
    return (
        seq_len,
        batch_size,
        gradient_checkpointing,
        grad_accumulation,
        precision,
        mixed_precision,
        param_dtype,
        reduce_dtype,
    )
def calculate(tp, pp, cp, ep, layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype):
    """Run the memory calculator for the given UI values and return a bar plot.

    All numeric widget values arrive as floats from Gradio, so they are cast
    to int before building the config objects. Dropdown values arrive as
    strings and are converted back to DType members.
    """
    model_config = Model(
        vocab_size=int(vocab),
        num_layers=int(layers),
        hidden_dim=int(hidden),
        intermediate_size=int(intermediate),
        weight_tied_embeddings=True,  # Default assumption
        active_experts=int(active_experts),
        total_experts=int(total_experts),
        is_moe=is_moe,
    )
    parallelism_config = Parallelism(
        tensor_parallelism=int(tp),
        pipeline_parallelism=int(pp),
        context_parallelism=int(cp),
        expert_parallelism=int(ep),
    )
    training_config = Training(
        sequence_length=int(seq_len),
        batch_size=int(batch_size),
        gradient_checkpointing=gradient_checkpointing,
        grad_accumulation=grad_accumulation,
        precision=DType(precision),
        mixed_precision=mixed_precision,
        param_dtype=DType(param_dtype),
        reduce_dtype=DType(reduce_dtype),
    )

    calc = MemoryCalculation(model_config, parallelism_config, training_config)

    # Component name -> size in bytes; insertion order fixes the plot order.
    breakdown = {
        'Parameter Memory': calc.calculate_parameter_memory(),
        'Activation Memory': calc.calculate_activation_memory(),
        'Gradient Memory': calc.calculate_gradient_memory(),
        'Optimizer Memory': calc.calculate_optimizer_memory(),
    }
    memory_data = pd.DataFrame({
        'Component': list(breakdown),
        # NOTE(review): divides by 1e9, i.e. decimal GB rather than GiB.
        'Memory (GB)': [size / 1e9 for size in breakdown.values()],
    })
    return gr.BarPlot(
        value=memory_data,
        x="Component",
        y="Memory (GB)",
        title="LLM Memory Usage Breakdown",
        container=False,
        y_lim=[0, None],
    )
# Top-level UI: sidebar title, three config columns, and the result plot.
with gr.Blocks(theme='gstaff/xkcd') as demo:
    with gr.Sidebar():
        # BUG FIX: this heading was a gr.Textbox, which shows the literal
        # "## ..." string; gr.Markdown actually renders it as a title.
        gr.Markdown("## LLM Memory Visualizer")
    with gr.Column():
        with gr.Row(equal_height=True):
            tp, pp, cp, ep = create_parallelism_block()
            layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets = create_model_block()
            seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype = create_training_block()
        calculate_button = gr.Button("Calculate")
        output = gr.BarPlot(label="Memory Usage Breakdown")
        calculate_button.click(
            fn=calculate,
            inputs=[tp, pp, cp, ep, layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype],
            outputs=output,
        )

# Guard the launch so importing this module (e.g. for tests) builds the UI
# without starting a server; running it as a script still launches.
if __name__ == "__main__":
    demo.launch()