| | |
| | |
| | |
| |
|
| | """ |
| | Unified Language Model with GPAS + LNS Integration + xIELU Activation + CoLA (Linear Only) + LaX + Weight Tying + Canon Layers (A+C Only) |
| | MIGRATED TO HUGGINGFACE TRANSFORMERS - FINAL VERSION WITH ALL FIXES + CORRECTED LaX IMPLEMENTATION |
| | UPDATED: Standard Transformer with advanced variance control, parameter efficiency, Canon horizontal information flow, and WORKING LaX Inter-Layer |
| | Combines advanced Transformer architecture with CORRECTED variance control mechanisms, |
| | advanced variance control via GPAS and LNS, xIELU activation function, FIXED LaX integration, and Canon Layers (A+C only) |
| | Based on LLaMA 3 architecture with 30M parameters |
| | |
| | MIGRATION TO HUGGINGFACE - FINAL FIXED VERSION + LaX CORRECTION: |
| | ============================================================== |
| | |
1. **HUGGINGFACE INTEGRATION**: Migrated from PyTorch Lightning to Transformers v4.53.3
2. **UPDATED API**: processing_class instead of the deprecated tokenizer argument
3. **UPDATED COMPUTE_LOSS**: Method updated with the num_items_in_batch parameter
4. **FIXED LOGGING**: Corrected self.log() syntax per the official HF documentation
5. **RESTORED PAD HANDLING**: pad_token_id → -100 conversion for CrossEntropyLoss (from original code)
6. **NATIVE TORCH COMPILE**: Moved to TrainingArguments (torch_compile=True)
7. **FIXED WEIGHT TYING**: Corrected _tied_weights_keys as class attribute (HF standard)
8. **VALIDATION DIAGNOSTIC**: Added simple method to diagnose validation loss issues
9. **CUSTOM CONFIGURATION**: Custom PretrainedConfig holding all parameters
10. **PRETRAINED MODEL**: Inherits from PreTrainedModel for full compatibility
11. **MAINTAINED OPTIMIZERS**: Hybrid Muon + AdamW setup preserved
12. **MAINTAINED PRECISION**: bf16-true preserved
13. **MAINTAINED TRAINING**: Custom Trainer with all metrics and logging
14. **MAINTAINED ARCHITECTURE**: All custom architecture components preserved
15. **AUTO TOKENIZER**: Full integration with dynamic AutoTokenizer
16. **AUTOCLASS SUPPORT**: Full registration for AutoConfig and AutoModel
17. **✅ FIXED LaX**: Correct Inter-Layer implementation with a working linear gate
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.checkpoint import checkpoint |
| | from transformers import ( |
| | AutoTokenizer, |
| | AutoConfig, |
| | AutoModel, |
| | AutoModelForCausalLM, |
| | PreTrainedModel, |
| | ) |
| | import math |
| | import os |
| | from typing import Optional, Tuple, Dict, Any, cast, List |
| | from flash_attn import flash_attn_func |
| | import numpy as np |
| |
|
| | |
| | from configuration_unified import UnifiedModelConfig |
| |
|
| | |
# Process-wide settings applied at import time:
# - disable the Hugging Face `tokenizers` library's internal parallelism
#   (commonly done to avoid fork-related warnings with DataLoader workers);
# - let float32 matmuls use faster reduced-precision internal math
#   (see torch.set_float32_matmul_precision, level 'high').
os.environ.update(TOKENIZERS_PARALLELISM="false")
torch.set_float32_matmul_precision("high")
| |
|
def init_cola_components(A: nn.Linear, B: nn.Linear):
    """Initialise the two factors of a CoLA low-rank projection in place.

    Args:
        A: Down-projection factor; gets Kaiming-normal init (fan_in, ReLU gain).
        B: Up-projection factor; gets Xavier-normal init with gain 0.8, and its
            bias (if any) is zeroed.
    """
    down_weight, up_weight, up_bias = A.weight, B.weight, B.bias
    nn.init.kaiming_normal_(down_weight, mode='fan_in', nonlinearity='relu')
    nn.init.xavier_normal_(up_weight, gain=0.8)
    if up_bias is None:
        return
    nn.init.zeros_(up_bias)
| |
|
def init_embedding(embedding: nn.Embedding):
    """Re-draw every row of ``embedding``'s table in place from N(mean=0, std=0.02).

    NOTE(review): no caller of this helper is visible in this module.
    """
    table = embedding.weight
    nn.init.normal_(table, mean=0.0, std=0.02)
| |
|
class CanonLayer(nn.Module):
    """Canon layer: a depthwise causal 1-D convolution added residually to its input.

    Each channel is convolved on its own (``groups=hidden_dim``) along the sequence
    axis. Only left padding is used, so output position ``t`` mixes input positions
    ``t - kernel_size + 1 .. t``. Weight and bias start at zero, so a freshly built
    layer is an exact identity map.

    Args:
        hidden_dim: Size of the input's last (channel) dimension.
        kernel_size: Number of causal taps.
    """

    def __init__(self, hidden_dim: int, kernel_size: int = 4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size

        # No built-in padding: causal (left-only) padding is applied in forward().
        self.causal_conv1d = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            kernel_size=kernel_size,
            groups=hidden_dim,
            padding=0,
            bias=True,
        )

        # Zero init => the residual branch contributes nothing at the start of training.
        for tensor in (self.causal_conv1d.weight, self.causal_conv1d.bias):
            nn.init.zeros_(tensor)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Return ``h + causal_conv(h)``.

        ``h`` is laid out as (batch, seq, hidden_dim). Channels are moved to dim 1
        for ``Conv1d`` and moved back afterwards.
        """
        channels_first = h.transpose(1, 2)
        left_padded = F.pad(channels_first, (self.kernel_size - 1, 0))
        mixed = self.causal_conv1d(left_padded).transpose(1, 2)
        return h + mixed
| |
|
class CoLA_Linear(nn.Module):
    """Low-rank factorised linear layer: ``y = B(activation(A(x)))``.

    ``A`` maps ``in_features`` down to ``rank`` (no bias). The activation is applied
    in that latent space, and ``B`` maps back up to ``out_features``. The forward pass
    can also blend in a latent from an earlier layer (LaX inter-layer mixing).

    Args:
        in_features: Input width.
        out_features: Output width.
        rank: Latent width; defaults to ``in_features // 4``.
        activation: Callable applied to the latent (default ``F.gelu``).
        bias: Whether ``B`` has a bias.
    """

    def __init__(self, in_features: int, out_features: int, rank: Optional[int] = None, activation=F.gelu, bias: bool = True):
        super().__init__()
        self.rank = in_features // 4 if rank is None else rank
        self.activation = activation

        self.A = nn.Linear(in_features, self.rank, bias=False)
        self.B = nn.Linear(self.rank, out_features, bias=bias)

        init_cola_components(self.A, self.B)

    def forward(self, x: torch.Tensor, prev_latent: Optional[torch.Tensor] = None, lax_beta: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project ``x``, optionally mixing in a previous layer's latent (LaX).

        Args:
            x: Input whose last dimension is ``in_features``.
            prev_latent: Activated latent of the matching projection in the previous
                layer. It is mixed in only when its shape equals the current latent's;
                otherwise it is silently ignored.
            lax_beta: Gate multiplying ``prev_latent`` before it is added.

        Returns:
            ``(output, latent)``. ``latent`` is the (possibly mixed) activated latent
            that was fed to ``B``, so the next layer can reuse it.
        """
        latent = self.activation(self.A(x))

        can_mix = (
            prev_latent is not None
            and lax_beta is not None
            and prev_latent.shape == latent.shape
        )
        if can_mix:
            latent = latent + lax_beta * prev_latent

        return self.B(latent), latent
| |
|
class LayerNormScaling(nn.Module):
    """LayerNorm Scaling: multiply an already-normalised tensor by ``1 / sqrt(layer_depth)``.

    Args:
        layer_depth: 1-based depth of the layer; must be >= 1.

    Raises:
        ValueError: if ``layer_depth < 1``.
    """

    def __init__(self, layer_depth: int):
        super().__init__()

        if layer_depth < 1:
            raise ValueError(f"layer_depth debe ser ≥ 1, got {layer_depth}")

        self.layer_depth = layer_depth
        # A constant Python float, not a parameter or buffer.
        self.scaling_factor = 1.0 / math.sqrt(float(layer_depth))

    def forward(self, normalized_input: torch.Tensor) -> torch.Tensor:
        """Return ``normalized_input`` scaled by the fixed depth factor."""
        scaled = normalized_input * self.scaling_factor
        return scaled
| |
|
class GPAS(nn.Module):
    """Scale activations in the forward pass while leaving the gradient w.r.t. ``x`` untouched.

    The forward value is ``x - silu(alpha) * x`` (i.e. ``(1 - silu(alpha)) * x``). The
    subtracted term uses a detached copy of ``x``, so ``d out / d x`` is the identity,
    and only ``alpha`` receives a gradient through the scaling. ``alpha`` starts at 0,
    and ``silu(0) == 0``, so the module is initially an identity.

    Args:
        d_model: Stored for reference; the single ``alpha`` scalar is shared across features.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.alpha)
        return x - gate * x.detach()
| |
|
class RotaryEmbedding(nn.Module):
    """Produce rotary position embedding (RoPE) ``cos``/``sin`` tables.

    The tables cover positions ``0 .. seq_len - 1`` and are computed on every call.

    Args:
        dim: Rotary dimension (the attention head size in this module).
        max_position_embeddings: Stored for reference only; it is not enforced here.
        base: RoPE frequency base (theta).
    """

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        inv_freq = self._compute_inv_freq()
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _compute_inv_freq(self, device=None) -> torch.Tensor:
        """Return ``1 / base ** (2i / dim)`` for ``i in [0, dim / 2)`` as float32."""
        exponents = torch.arange(0, self.dim, 2, device=device).float() / self.dim
        return 1.0 / (self.base ** exponents)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each of shape ``(seq_len, dim)`` and cast to ``x.dtype``.

        Args:
            x: Used only for its device and dtype, and (when ``seq_len`` is None) for
                ``x.shape[-2]``.
            seq_len: Number of positions to generate.
        """
        if seq_len is None:
            seq_len = x.shape[-2]

        inv_freq = self.inv_freq
        # ``Module.to(torch.bfloat16)`` also down-casts this buffer. Half-precision
        # frequencies, and positions built in that dtype (bf16 cannot represent
        # integers above 256 exactly), corrupt the rotation angles at long context.
        # Rebuild the table in float32 in that case.
        if inv_freq.dtype in (torch.float16, torch.bfloat16):
            inv_freq = self._compute_inv_freq(device=x.device)

        t = torch.arange(seq_len, device=x.device, dtype=inv_freq.dtype)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
| |
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Rotate ``q`` and ``k`` by the RoPE tables ``cos``/``sin``.

    ``cos``/``sin`` must broadcast against ``q``/``k`` on the trailing dimensions.
    ``position_ids`` is accepted for interface compatibility but is NOT used: the
    tables are applied position-by-position as given.

    Returns:
        ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """

    def _rotate_half(t):
        # Split at floor(d / 2) and return cat(-second, first).
        half = t.shape[-1] // 2
        first, second = t[..., :half], t[..., half:]
        return torch.cat((-second, first), dim=-1)

    rotated_q = _rotate_half(q) * sin
    rotated_k = _rotate_half(k) * sin
    return (q * cos) + rotated_q, (k * cos) + rotated_k
| |
|
class XIELU(nn.Module):
    """xIELU activation with trainable positive and negative curvature.

        f(x) = alpha_p * x**2 + beta * x                            if x > 0
        f(x) = alpha_n * (expm1(min(x, eps)) - x) + beta * x         otherwise

    with ``alpha_p = softplus(p)`` and ``alpha_n = beta + softplus(n)``. The raw
    parameters ``p`` and ``n`` are stored as inverse-softplus values, so the effective
    coefficients start at ``alpha_p_init`` and ``alpha_n_init``.

    Args:
        alpha_p_init: Initial positive-branch coefficient; must be > 0.
        alpha_n_init: Initial negative-branch coefficient; must be > ``beta``.
        beta: Fixed linear coefficient.

    Raises:
        ValueError: if an initial coefficient lies outside the inverse-softplus domain.
            That would otherwise produce NaN/-inf parameters.
    """

    def __init__(self, alpha_p_init: float = 0.8, alpha_n_init: float = 0.8, beta: float = 0.5):
        super().__init__()

        if alpha_p_init <= 0:
            raise ValueError(f"alpha_p_init must be > 0, got {alpha_p_init}")
        if alpha_n_init <= beta:
            raise ValueError(f"alpha_n_init must be > beta ({beta}), got {alpha_n_init}")

        self.beta = beta

        # Inverse softplus, so that softplus() in forward() recovers the initial values.
        self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init)) - 1))
        self.alpha_n = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_n_init - self.beta)) - 1))

        self.register_buffer('eps', torch.tensor(-1e-6))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        alpha_p = F.softplus(self.alpha_p)
        alpha_n = self.beta + F.softplus(self.alpha_n)

        # The negative branch needs min(x, eps), i.e. an upper clamp. An upper clamp
        # keeps the exponential term for x < 0. It also keeps expm1 finite for large
        # positive x: torch.where evaluates both branches, and an inf there turns into
        # a NaN gradient.
        x_neg = torch.clamp(x, max=self.eps)

        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            alpha_n * torch.expm1(x_neg) - alpha_n * x + self.beta * x,
        )
| |
|
class StandardMLP(nn.Module):
    """Two-projection MLP: CoLA up-projection -> xIELU -> dropout -> CoLA down-projection.

    When ``config.lax_enabled`` is true, each projection owns a trainable scalar gate
    (``lax_beta_up`` / ``lax_beta_down``, initialised to 0.2). The gate mixes in the
    latent of the same projection from layer ``layer_idx - 1``. Latents are exchanged
    through a per-forward ``lax_buffer`` dict keyed by ``(name, layer_idx)``.

    Args:
        hidden_size: Model width (input and output of the MLP).
        intermediate_size: Width after the up-projection.
        dropout: Dropout rate applied after the activation (``0`` -> Identity).
        config: Optional config supplying the xIELU init values and ``lax_enabled``.
        layer_idx: Index of the owning decoder layer (used for LaX buffer keys).
    """

    def __init__(self, hidden_size: int, intermediate_size: int, dropout: float = 0.0, config=None, layer_idx: int = 0):
        super().__init__()

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.config = config
        self.layer_idx = layer_idx

        self.up_proj = CoLA_Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = CoLA_Linear(intermediate_size, hidden_size, bias=False)

        if config is not None:
            self.activation = XIELU(
                alpha_p_init=config.xielu_alpha_p_init,
                alpha_n_init=config.xielu_alpha_n_init,
                beta=config.xielu_beta
            )
        else:
            self.activation = XIELU(alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # LaX inter-layer gates (scalar, trainable).
        if config is not None and config.lax_enabled:
            self.lax_beta_up = nn.Parameter(torch.full((1,), 0.2))
            self.lax_beta_down = nn.Parameter(torch.full((1,), 0.2))
        else:
            self.lax_beta_up = None
            self.lax_beta_down = None

    def forward(self, x: torch.Tensor, lax_buffer: Optional[Dict] = None) -> torch.Tensor:
        """Apply the MLP to ``x``.

        If ``lax_buffer`` is given, this layer's latents are stored in it. If LaX is
        also enabled, the previous layer's latents are read from it (missing keys,
        e.g. for layer 0, yield ``None``).
        """
        prev_up_latent = None
        prev_down_latent = None
        if lax_buffer is not None and self.lax_beta_up is not None:
            prev_up_latent = lax_buffer.get(('mlp_up', self.layer_idx - 1))
            prev_down_latent = lax_buffer.get(('mlp_down', self.layer_idx - 1))

        intermediate, up_latent = self.up_proj(x, prev_up_latent, self.lax_beta_up)

        # Stored without .clone(): latents are never modified in place afterwards,
        # so sharing the tensor is safe. It also avoids keeping a duplicate copy alive
        # in the autograd graph.
        if lax_buffer is not None:
            lax_buffer[('mlp_up', self.layer_idx)] = up_latent

        activated = self.dropout(self.activation(intermediate))

        output, down_latent = self.down_proj(activated, prev_down_latent, self.lax_beta_down)

        if lax_buffer is not None:
            lax_buffer[('mlp_down', self.layer_idx)] = down_latent

        return output
| |
|
class GroupedQueryAttention(nn.Module):
    """Causal grouped-query self-attention with FAN input features, QKV RMSNorm, RoPE and LaX.

    Pipeline:
      1. ``_fan_layer_prime``: build ``[cos(P(x)), sin(P(x)), W(x)]``, where ``P`` maps
         to ``d_periodic`` and ``W`` to ``d_standard``. The total width is ``hidden_size``.
      2. CoLA q/k/v projections, optionally LaX-mixed with the previous layer's latents.
      3. Per-head RMSNorm on q, k and v.
      4. RoPE on q and k, then ``flash_attn_func`` with ``causal=True``.
      5. CoLA output projection (no LaX).

    Args:
        config: Model config (hidden size, head counts, rms_norm_eps, rope_theta,
            attention_dropout, lax_enabled, optional ``fanformer_p``).
        layer_idx: Index of the owning decoder layer (used for LaX buffer keys).
    """

    def __init__(self, config, layer_idx: int = 0):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # FAN split: 2 * d_periodic (cos + sin) + d_standard == hidden_size.
        self.fanformer_p = getattr(config, 'fanformer_p', 0.15)

        self.d_periodic = int(self.hidden_size * self.fanformer_p)
        self.d_standard = self.hidden_size - 2 * self.d_periodic

        assert self.d_standard > 0, \
            f"fanformer_p={self.fanformer_p} is too high. d_standard={self.d_standard} must be > 0"

        self.fan_w_p = CoLA_Linear(self.hidden_size, self.d_periodic, bias=False)
        self.fan_w_p_bar = CoLA_Linear(self.hidden_size, self.d_standard, bias=False)

        self.q_proj = CoLA_Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = CoLA_Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = CoLA_Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = CoLA_Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.v_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta
        )

        # LaX inter-layer gates (scalar, trainable) for q/k/v only.
        if config.lax_enabled:
            self.lax_beta_q = nn.Parameter(torch.full((1,), 0.2))
            self.lax_beta_k = nn.Parameter(torch.full((1,), 0.2))
            self.lax_beta_v = nn.Parameter(torch.full((1,), 0.2))
        else:
            self.lax_beta_q = None
            self.lax_beta_k = None
            self.lax_beta_v = None

    def _fan_layer_prime(self, x: torch.Tensor) -> torch.Tensor:
        """Return the FAN feature ``[cos(P(x)), sin(P(x)), W(x)]`` (width ``hidden_size``)."""
        periodic_proj, _ = self.fan_w_p(x)
        standard_proj, _ = self.fan_w_p_bar(x)

        return torch.cat(
            [torch.cos(periodic_proj), torch.sin(periodic_proj), standard_proj],
            dim=-1,
        )

    def _compute_flash_attention(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        seq_len: int,
        position_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply RoPE to q/k and run causal FlashAttention.

        Inputs are ``(batch, seq, heads, head_dim)``, as built by ``forward()``.
        The RoPE tables are ``(seq, head_dim)``, so q/k are temporarily transposed to
        ``(batch, heads, seq, head_dim)`` for broadcasting. ``flash_attn_func`` handles
        GQA natively (fewer KV heads than query heads).
        """
        q_rope = query_states.transpose(1, 2)
        k_rope = key_states.transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, seq_len=seq_len)
        q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos, sin, position_ids)

        # flash_attn_func is imported at module level.
        return flash_attn_func(
            q_rope.transpose(1, 2),
            k_rope.transpose(1, 2),
            value_states,
            dropout_p=self.config.attention_dropout if self.training else 0.0,
            causal=True,
        )

    def forward(self, hidden_states, position_ids=None, attention_mask=None, lax_buffer: Optional[Dict] = None):
        """Self-attention over ``hidden_states`` (batch, seq, hidden_size).

        NOTE(review): ``attention_mask`` is accepted for interface compatibility but is
        not used; masking is purely causal. ``position_ids`` is forwarded to
        ``apply_rotary_pos_emb``.
        """
        batch_size, seq_len, _ = hidden_states.shape

        enhanced_input = self._fan_layer_prime(hidden_states)

        prev_q_latent = None
        prev_k_latent = None
        prev_v_latent = None
        if lax_buffer is not None and self.lax_beta_q is not None:
            prev_q_latent = lax_buffer.get(('attn_q', self.layer_idx - 1))
            prev_k_latent = lax_buffer.get(('attn_k', self.layer_idx - 1))
            prev_v_latent = lax_buffer.get(('attn_v', self.layer_idx - 1))

        query_states, q_latent = self.q_proj(enhanced_input, prev_q_latent, self.lax_beta_q)
        key_states, k_latent = self.k_proj(enhanced_input, prev_k_latent, self.lax_beta_k)
        value_states, v_latent = self.v_proj(enhanced_input, prev_v_latent, self.lax_beta_v)

        # Stored without .clone(): latents are never modified in place, so sharing the
        # tensor is safe and avoids an extra copy kept alive by autograd.
        if lax_buffer is not None:
            lax_buffer[('attn_q', self.layer_idx)] = q_latent
            lax_buffer[('attn_k', self.layer_idx)] = k_latent
            lax_buffer[('attn_v', self.layer_idx)] = v_latent

        # RMSNorm normalises over the last (head_dim) axis, so it applies directly to
        # the 4-D per-head view; no flatten/unflatten is needed.
        query_states = self.q_norm(
            query_states.view(batch_size, seq_len, self.num_heads, self.head_dim)
        )
        key_states = self.k_norm(
            key_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
        )
        value_states = self.v_norm(
            value_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
        )

        attn_output = self._compute_flash_attention(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            seq_len=seq_len,
            position_ids=position_ids
        )

        # Merge heads to the o_proj input width (equals hidden_size only when
        # hidden_size is divisible by num_heads).
        attn_output = attn_output.reshape(batch_size, seq_len, self.num_heads * self.head_dim)

        output, _ = self.o_proj(attn_output)
        return output
| |
|
class DecoderLayer(nn.Module):
    """One pre-norm transformer block: attention sub-block followed by an MLP sub-block.

    Each sub-block computes::

        out = dropout(GPAS(x + sublayer(LNS(RMSNorm(canon(x))))))

    ``canon`` is optional: Canon-A before attention, Canon-C before the MLP. It only
    feeds the sub-layer input; the residual branch always uses the un-canoned ``x``.
    LNS scales by ``1 / sqrt(layer_idx + 1)``.

    Raises:
        ValueError: if ``layer_idx`` is negative.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        if layer_idx < 0:
            raise ValueError(f"layer_idx debe ser >= 0, got {layer_idx}")

        hidden = config.hidden_size
        norm_eps = config.rms_norm_eps
        depth = layer_idx + 1

        # Construction order is kept fixed: it determines RNG consumption during init.
        self.input_layernorm = nn.RMSNorm(hidden, eps=norm_eps)
        self.self_attn = GroupedQueryAttention(config, layer_idx)
        self.post_attention_layernorm = nn.RMSNorm(hidden, eps=norm_eps)

        self.mlp = StandardMLP(
            hidden,
            config.intermediate_size,
            config.mlp_dropout,
            config,
            layer_idx
        )

        # Shared by both sub-blocks; fixed rate.
        self.dropout_output = nn.Dropout(0.01)

        self.lns_attention = LayerNormScaling(layer_depth=depth)
        self.lns_mlp = LayerNormScaling(layer_depth=depth)

        self.gpas_attention = GPAS(hidden)
        self.gpas_mlp = GPAS(hidden)

        canon_on = config.canon_enabled
        self.canon_a = (
            CanonLayer(hidden, config.canon_kernel_size)
            if canon_on and config.canon_a_enabled else None
        )
        self.canon_c = (
            CanonLayer(hidden, config.canon_kernel_size)
            if canon_on and config.canon_c_enabled else None
        )

    def forward(self, hidden_states: torch.Tensor, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, lax_buffer: Optional[Dict] = None) -> torch.Tensor:
        # Attention sub-block.
        skip = hidden_states
        attn_in = hidden_states if self.canon_a is None else self.canon_a(hidden_states)
        attn_in = self.lns_attention(self.input_layernorm(attn_in))
        attn_out = self.self_attn(attn_in, position_ids, attention_mask, lax_buffer)
        hidden_states = self.dropout_output(self.gpas_attention(skip + attn_out))

        # MLP sub-block.
        skip = hidden_states
        mlp_in = hidden_states if self.canon_c is None else self.canon_c(hidden_states)
        mlp_in = self.lns_mlp(self.post_attention_layernorm(mlp_in))
        mlp_out = self.mlp(mlp_in, lax_buffer)
        return self.dropout_output(self.gpas_mlp(skip + mlp_out))
| |
|
class UnifiedModel(PreTrainedModel):
    """
    UnifiedModel that inherits from PreTrainedModel for full HuggingFace compatibility.
    With AutoClass support for seamless Hub integration.

    Layout: token embedding -> dropout -> ``num_hidden_layers`` x DecoderLayer ->
    RMSNorm -> dropout(0.05) -> ``lm_head`` (bias-free Linear, tied to the embedding
    when ``config.tie_word_embeddings`` is true).
    """
    config_class = UnifiedModelConfig

    # Declares lm_head.weight as tied to the input embedding (HF convention).
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: UnifiedModelConfig):
        super().__init__(config)
        self.config = config

        if config.vocab_size is None:
            raise ValueError("config.vocab_size must be set from tokenizer before model initialization")

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embedding_dropout = nn.Dropout(config.embedding_dropout)
        self.output_dropout = nn.Dropout(0.05)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.layers = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            self.layers.append(DecoderLayer(config, i))

        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Runs weight init and tie_weights().
        self.post_init()

        self._print_configuration()

    def tie_weights(self):
        """
        Tie the word embeddings and the output layer.

        When ``config.tie_word_embeddings`` is true, ``lm_head.weight`` becomes the very
        same Parameter object as ``embed_tokens.weight``.
        """
        if self.config.tie_word_embeddings:
            print("🔗 Applying weight tying: lm_head.weight = embed_tokens.weight")
            self.lm_head.weight = self.embed_tokens.weight
            print("✅ Weight tying successful: Parameters are properly shared")

    def _init_weights(self, module):
        """Initialize weights following the custom initialization scheme.

        - ``nn.Linear``: N(0, 0.02), bias zeroed.
        - ``nn.Embedding``: truncated N(0, 0.02) clipped to [-0.04, 0.04].
        - ``CoLA_Linear``: its ``A``/``B`` factors are ``nn.Linear`` children, so the
          generic Linear rule reaches them too. Re-apply ``init_cola_components`` here
          so the CoLA scheme is what survives.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02, a=-0.04, b=0.04)
        elif isinstance(module, CoLA_Linear):
            init_cola_components(module.A, module.B)
            # NOTE(review): HF's weight-init traversal skips modules flagged
            # `_is_hf_initialized`. Flag the factors so that a traversal visiting them
            # after their parent does not overwrite the CoLA init.
            module.A._is_hf_initialized = True
            module.B._is_hf_initialized = True

    def _print_configuration(self):
        """Print a human-readable summary of model size and enabled features."""
        # parameters() already de-duplicates shared tensors (e.g. a tied lm_head), while
        # named_parameters(remove_duplicate=False) counts every registered slot. Their
        # difference is exactly the weight-sharing saving. Subtracting the embedding
        # from parameters() would count that saving twice.
        total_params_actual = sum(p.numel() for p in self.parameters())
        total_params_naive = sum(
            p.numel() for _, p in self.named_parameters(remove_duplicate=False)
        )
        tied_savings = total_params_naive - total_params_actual

        total_linear_params = 0
        total_cola_params = 0
        canon_params = 0
        lax_params = 0

        for name, module in self.named_modules():
            if isinstance(module, CoLA_Linear):
                in_features = module.A.in_features
                out_features = module.B.out_features
                rank = module.rank

                standard_params = in_features * out_features
                cola_params = (in_features * rank) + (rank * out_features)

                total_linear_params += standard_params
                total_cola_params += cola_params
            elif isinstance(module, CanonLayer):
                # depthwise weights + bias
                canon_layer_params = module.hidden_dim * module.kernel_size + module.hidden_dim
                canon_params += canon_layer_params
            elif hasattr(module, 'lax_beta_q') and module.lax_beta_q is not None:
                # attention: q/k/v gates
                lax_params += 3
            elif hasattr(module, 'lax_beta_up') and module.lax_beta_up is not None:
                # MLP: up/down gates
                lax_params += 2

        cola_reduction = ((total_linear_params - total_cola_params) / total_linear_params) * 100 if total_linear_params > 0 else 0
        canon_overhead = (canon_params / total_params_actual) * 100 if total_params_actual > 0 else 0
        lax_overhead = (lax_params / total_params_actual) * 100 if total_params_actual > 0 else 0

        print(f"\n📊 UNIFIED Model + GPAS + LNS + xIELU + CoLA (Linear Only) + LaX + Canon (A+C) + Weight Tying:")

        if self.config.tie_word_embeddings and tied_savings > 0:
            print(f"🎯 Total Parameters: {total_params_actual/1e6:.2f}M (effective)")
            print(f"📊 Parameter Breakdown:")
            print(f" • Naive count: {total_params_naive/1e6:.2f}M (all registered params)")
            print(f" • Actual count: {total_params_actual/1e6:.2f}M (after weight tying)")
            print(f" • Weight tying savings: {tied_savings/1e6:.2f}M ({tied_savings/total_params_naive*100:.1f}%)")
        else:
            print(f"🎯 Total Parameters: {total_params_actual/1e6:.2f}M")

        print(f"📚 DYNAMIC Vocabulary Size: {self.config.vocab_size} (from tokenizer)")
        print(f"🔗 ✅ PROPER Weight Tying: {'ENABLED' if self.config.tie_word_embeddings else 'DISABLED'}")

        if self.config.tie_word_embeddings:
            if tied_savings > 0:
                print(f"💾 Weight Tying Status: ✅ ACTIVE (tensors are shared in memory)")
            else:
                print(f"💾 Weight Tying Status: ⏳ CONFIGURED (will be applied during post_init)")

        print(f"🚀 ACTIVATION: xIELU (αp_init={self.config.xielu_alpha_p_init}, αn_init={self.config.xielu_alpha_n_init}, β={self.config.xielu_beta})")
        print(f"🔄 UPGRADE: SwiGLU → StandardMLP + xIELU (better efficiency & adaptability)")
        print(f"🗜️ CoLA Integration: {cola_reduction:.1f}% parameter reduction in internal projections")
        print(f"🔀 LaX Enabled: {'YES' if self.config.lax_enabled else 'NO'} ✅ WORKING Inter-Layer (Linear Gate)")
        if self.config.lax_enabled:
            print(f" • LaX Method: Inter-Layer with Linear Gate (β scalars)")
            print(f" • LaX Applied to: q_proj, k_proj, v_proj, up_proj, down_proj (NOT o_proj)")
            print(f" • LaX Parameters: {lax_params} β scalars ({lax_overhead:.6f}% overhead)")
            # The gates are created with torch.full((1,), 0.2) in the attention and MLP modules.
            print(f" • LaX Initialization: β=0.2")
        print(f"🎼 Canon Layers Enabled: {'YES' if self.config.canon_enabled else 'NO'} (A+C ONLY)")
        if self.config.canon_enabled:
            print(f" • Canon-A (Before Attention): {'✅' if self.config.canon_a_enabled else '❌'}")
            print(f" • Canon-B (Inside Attention): ❌ PERMANENTLY DISABLED")
            print(f" • Canon-C (Before MLP): {'✅' if self.config.canon_c_enabled else '❌'}")
            print(f" • Canon-D (Inside MLP): ❌ PERMANENTLY DISABLED")
            print(f" • Canon Kernel Size: {self.config.canon_kernel_size}")
            print(f" • Canon Parameters Overhead: {canon_overhead:.3f}% ({canon_params/1e3:.1f}K params)")
        print(f"⚡ GPAS Enabled: ALWAYS (Dynamic variance control)")
        print(f"📏 LNS Enabled: ALWAYS (Static depth scaling)")
        print(f"🔧 Variance Control: Triple-level (LNS + GPAS + Canon A+C) ALWAYS")
        print(f"🔗 Residual Connections: STANDARD + HORIZONTAL (Canon A+C only)")
        print(f"🧹 CLEAN: Standard transformer architecture - CrossEntropyLoss manages PAD naturally")
        print(f"⚡ FlashAttention: Scaled Dot-Product Attention with GQA + automatic causal masking")
        print(f"🎯 TOKENIZER AGNOSTIC: Dynamic vocab_size and pad_token_id")
        print(f"🎯 SIMPLIFIED: CoLA Linear Only + Canon A+C Only = Better performance & less overhead")
        print(f"🔗 ✅ FIXED Weight Tying: _tied_weights_keys as class attribute (HF standard)")
        print(f"🎼 Canon A+C BENEFITS: Strategic horizontal information flow with minimal parameters")
        print(f"🔀 ✅ FIXED LaX: Functional Inter-Layer with ephemeral buffer (no broken reset)")
        print(f"🤗 HUGGINGFACE COMPATIBLE: Full PreTrainedModel integration v4.53.3")
        print(f"⚡ ✅ NATIVE TORCH COMPILE: Will be handled by TrainingArguments")
        print(f"🚀 ✅ AUTOCLASS SUPPORT: Compatible with AutoConfig.from_pretrained() and AutoModel.from_pretrained()")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Run the decoder and, if ``labels`` is given, the shifted next-token loss.

        Args:
            input_ids: Token ids of shape ``(batch, seq)``.
            attention_mask: Forwarded to every layer. NOTE(review): the attention in this
                module is purely causal and does not read it; padding is only handled
                through the label masking below.
            position_ids: Forwarded to every layer.
            labels: Optional targets aligned with ``input_ids``. They are shifted by
                one inside, and ``pad_token_id`` positions are ignored by the loss.
            **kwargs: Ignored. NOTE(review): the HF Trainer may pass
                ``num_items_in_batch`` here; it is not used, so ``loss`` is a per-call
                mean. Confirm that gradient-accumulation normalisation is done elsewhere.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` and (when labels are given) ``loss``.
        """
        batch_size, seq_len = input_ids.shape  # also enforces 2-D (batch, seq) input

        # Fresh per-call buffer carrying LaX latents from layer i to layer i + 1.
        lax_buffer = {} if self.config.lax_enabled else None

        hidden_states = self.embed_tokens(input_ids)
        hidden_states = self.embedding_dropout(hidden_states)

        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids=position_ids, attention_mask=attention_mask, lax_buffer=lax_buffer)

        hidden_states = self.norm(hidden_states)
        hidden_states = self.output_dropout(hidden_states)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)

            if self.config.pad_token_id is not None:
                # Out-of-place: when batch size is 1 the slice above can still be a view
                # of the caller's `labels`, and an in-place write would corrupt it.
                shift_labels = shift_labels.masked_fill(
                    shift_labels == self.config.pad_token_id, -100
                )

            loss = nn.CrossEntropyLoss()(shift_logits, shift_labels)

        from transformers.modeling_outputs import CausalLMOutputWithPast
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @staticmethod
    def _sample_next_token(
        next_token_logits: torch.Tensor,
        temperature: float,
        top_k: Optional[int],
        top_p: float,
    ) -> torch.Tensor:
        """Sample one token per row using temperature, top-k and nucleus (top-p) filtering.

        Returns a ``(batch, 1)`` tensor of token ids.
        """
        if temperature != 1.0:
            next_token_logits = next_token_logits / temperature

        if top_k is not None and top_k > 0:
            # Clamp k so topk() does not fail when top_k exceeds the vocabulary size.
            k = min(top_k, next_token_logits.size(-1))
            kth_value = torch.topk(next_token_logits, k).values[:, [-1]]
            next_token_logits = next_token_logits.masked_fill(
                next_token_logits < kth_value, -float('inf')
            )

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            sorted_to_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept; always keep the best one.
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False

            to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
            next_token_logits = next_token_logits.masked_fill(to_remove, -float('inf'))

        probs = F.softmax(next_token_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
        do_sample: bool = True,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ) -> torch.Tensor:
        """
        Generate sequences using the model.
        Compatible with AutoModelForCausalLM interface.

        The full sequence is re-run through ``forward`` at every step (no KV cache).
        Once a row emits an EOS token, every later position of that row is filled
        with ``pad_token_id``. Generation stops early when all rows have finished.

        Args:
            input_ids: Prompt ids of shape ``(batch, seq)``.
            max_new_tokens: Maximum number of tokens to append.
            temperature, top_p, top_k: Sampling controls (used only when ``do_sample``).
                ``top_k <= 0`` disables top-k.
            do_sample: Sample if True, otherwise greedy argmax.
            pad_token_id: Filler for finished rows. Defaults to the config value, then
                to the (first) EOS id.
            eos_token_id: EOS id or list of ids. Defaults to the config value.
            **kwargs: Ignored (e.g. ``attention_mask``).

        Returns:
            ``(batch, seq + n_generated)`` tensor of token ids.
        """
        if pad_token_id is None:
            pad_token_id = self.config.pad_token_id
        if eos_token_id is None:
            eos_token_id = self.config.eos_token_id

        device = input_ids.device

        eos_ids = None
        if eos_token_id is not None:
            eos_list = list(eos_token_id) if isinstance(eos_token_id, (list, tuple)) else [eos_token_id]
            eos_ids = torch.tensor(eos_list, device=device)
            if pad_token_id is None:
                pad_token_id = eos_list[0]

        generated = input_ids.clone()
        unfinished = torch.ones(input_ids.shape[0], dtype=torch.bool, device=device)

        for _ in range(max_new_tokens):
            next_token_logits = self.forward(generated).logits[:, -1, :]

            if do_sample:
                next_token = self._sample_next_token(next_token_logits, temperature, top_k, top_p)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            if eos_ids is not None:
                # Rows that already produced EOS only receive padding from now on.
                next_token = torch.where(
                    unfinished.unsqueeze(-1), next_token, torch.full_like(next_token, pad_token_id)
                )

            generated = torch.cat([generated, next_token], dim=1)

            if eos_ids is not None:
                unfinished &= ~torch.isin(next_token.squeeze(-1), eos_ids)
                if not bool(unfinished.any()):
                    break

        return generated
| |
|
| |
|
| |
|
| | |
| | |
# Register the custom config/model with the Auto* factories so that "unified_model"
# resolves to UnifiedModelConfig / UnifiedModel.
AutoConfig.register("unified_model", UnifiedModelConfig)
for _auto_factory in (AutoModel, AutoModelForCausalLM):
    _auto_factory.register(UnifiedModelConfig, UnifiedModel)
del _auto_factory

print("\n".join((
    "🚀 ✅ AUTOCLASS REGISTRATION COMPLETE:",
    " • AutoConfig.register('unified_model', UnifiedModelConfig)",
    " • AutoModel.register(UnifiedModelConfig, UnifiedModel)",
    " • AutoModelForCausalLM.register(UnifiedModelConfig, UnifiedModel)",
    " • Users can now load with: AutoModel.from_pretrained('your-repo', trust_remote_code=True)",
)))