from dataclasses import dataclass

from dtypes import DType


@dataclass
class Model:
    """Plain data container for model-level settings.

    All fields are required and have no defaults; the generated
    ``__init__`` accepts them positionally in the order declared below.
    """

    vocab_size: int
    num_layers: int
    hidden_dim: int
    intermediate_size: int
    weight_tied_embeddings: bool
    # Expert-related fields; presumably only meaningful when ``is_moe`` is
    # True -- TODO confirm against the code that consumes this class.
    active_experts: int
    total_experts: int
    is_moe: bool


@dataclass
class Parallelism:
    """Plain data container for parallelism settings.

    All fields are required and have no defaults; the generated
    ``__init__`` accepts them positionally in the order declared below.
    """

    tensor_parallelism: int
    pipeline_parallelism: int
    context_parallelism: int
    expert_parallelism: int
    fsdp_enabled: bool
    fsdp_parallelism: int
    # Free-form string; the set of accepted values is not defined in this
    # file -- verify against the consumer.
    fsdp_strategy: str


@dataclass
class Training:
    """Plain data container for training-run settings.

    All fields are required and have no defaults; the generated
    ``__init__`` accepts them positionally in the order declared below.
    """

    sequence_length: int
    batch_size: int
    gradient_checkpointing: bool
    # NOTE(review): declared as bool, although the name could suggest a step
    # count -- confirm the intended type with callers before changing it.
    grad_accumulation: bool
    precision: DType
    mixed_precision: bool
    # NOTE(review): presumably consulted only when ``mixed_precision`` is
    # True -- confirm against the consumer.
    param_dtype: DType
    reduce_dtype: DType