Update modeling_neollm.py
modeling_neollm.py  CHANGED  (+37 -34)
@@ -1,8 +1,8 @@
 #!/usr/bin/env python3
 """
-NeoLLM Model with FANformer Integration and
-Updated to include Fourier Analysis Network (FAN) layer for effective periodicity modeling
-
+NeoLLM Model with FANformer Integration, Dropout Regularization, and Selective Self-Attention (SSA)
+Updated to include Fourier Analysis Network (FAN) layer for effective periodicity modeling,
+dropout regularization at strategic locations
 """

 import math

@@ -45,8 +45,6 @@ else:
 from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

 logger = logging.get_logger(__name__)
-
-
 class FANLayer(nn.Module):
     """
     Fourier Analysis Network (FAN) layer for effective periodicity modeling.

@@ -63,26 +61,27 @@ class FANLayer(nn.Module):
         self.hidden_size = hidden_size
         self.fan_ratio = fan_ratio

-        # Calculate dimensions
-
-
+        # Calculate dimensions following the paper's approach
+        # Output will be: [cos(p) || sin(p) || g] where total = hidden_size + periodic_dim
+        output_dim = hidden_size + int(hidden_size * fan_ratio)
+        self.p_output_dim = int(output_dim * fan_ratio)
+        self.g_output_dim = output_dim - self.p_output_dim * 2

-        #
-        self.
-
+        # Single fused projection (more efficient than two separate projections)
+        self.input_linear = nn.Linear(
+            hidden_size,
+            self.p_output_dim + self.g_output_dim,
+            bias=True
+        )

         # Initialize parameters
         self._init_weights()

     def _init_weights(self):
         """Initialize weights following the paper's recommendations."""
-
-
-
-        # Initialize Wp_bar for non-periodic components
-        nn.init.normal_(self.Wp_bar.weight, mean=0.0, std=0.02)
-        if self.Wp_bar.bias is not None:
-            nn.init.zeros_(self.Wp_bar.bias)
+        nn.init.normal_(self.input_linear.weight, mean=0.0, std=0.02)
+        if self.input_linear.bias is not None:
+            nn.init.zeros_(self.input_linear.bias)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """

@@ -93,17 +92,14 @@

         Returns:
             Transformed tensor with Fourier components concatenated
+                Shape: (batch, seq_len, hidden_size + periodic_dim)
         """
-        #
-
-
-        sin_component = torch.sin(x_periodic)
-
-        # Get non-periodic component (linear transformation)
-        x_non_periodic = self.Wp_bar(x)  # (batch, seq_len, non_periodic_dim)
+        # Single projection followed by split (more efficient)
+        pg = self.input_linear(x)
+        p, g = torch.split(pg, [self.p_output_dim, self.g_output_dim], dim=-1)

         # Concatenate all components: [cos(WpX) || sin(WpX) || (Wp¯X + Bp¯)]
-        x_fan = torch.cat([
+        x_fan = torch.cat([torch.cos(p), torch.sin(p), g], dim=-1)

         return x_fan

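Pieced together, the two FANLayer hunks reduce to the following self-contained sketch (the fan_ratio default of 0.25 is an assumed value for illustration; the real one comes from NeoLLMConfig):

```python
import torch
import torch.nn as nn

class FANLayerSketch(nn.Module):
    """The patched FANLayer boiled down: one fused projection, then
    [cos(p) || sin(p) || g] concatenation."""

    def __init__(self, hidden_size: int, fan_ratio: float = 0.25):
        super().__init__()
        output_dim = hidden_size + int(hidden_size * fan_ratio)
        self.p_output_dim = int(output_dim * fan_ratio)
        self.g_output_dim = output_dim - self.p_output_dim * 2
        self.input_linear = nn.Linear(hidden_size, self.p_output_dim + self.g_output_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pg = self.input_linear(x)
        p, g = torch.split(pg, [self.p_output_dim, self.g_output_dim], dim=-1)
        # cos(p) and sin(p) each contribute p_output_dim features, g the rest,
        # so the output width is hidden_size + int(hidden_size * fan_ratio)
        return torch.cat([torch.cos(p), torch.sin(p), g], dim=-1)

layer = FANLayerSketch(512, fan_ratio=0.25)
print(layer(torch.randn(2, 16, 512)).shape)  # torch.Size([2, 16, 640])
```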
@@ -287,7 +283,7 @@ def eager_attention_forward(


 class NeoLLMAttention(nn.Module):
-    """Multi-headed attention with FANformer integration for periodicity modeling"""
+    """Multi-headed attention with FANformer integration and Selective Self-Attention for periodicity modeling"""

     def __init__(self, config: NeoLLMConfig, layer_idx: int):
         super().__init__()

@@ -338,8 +334,10 @@ class NeoLLMAttention(nn.Module):

         # Apply FANformer transformation first
         hidden_states_fan = self.fan_layer(hidden_states)
+
         hidden_shape = (*input_shape, -1, self.head_dim)

+        # Use FAN-transformed features directly for projections
         query_states, gate = torch.chunk(
             self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
         )

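The doubled q_proj width is easy to miss here: half of each head's projection becomes the query, the other half a gate. A toy reproduction of just that shape logic, with made-up dimensions rather than the model's config (in the real code, fan_dim is fixed by the FANLayer output width, not a free choice):

```python
import torch
import torch.nn as nn

# Made-up dimensions for illustration
batch, seq, num_heads, head_dim, fan_dim = 2, 8, 4, 32, 640
input_shape = (batch, seq)

# q_proj emits 2 * head_dim per head: one half is the query, the other a gate
q_proj = nn.Linear(fan_dim, num_heads * head_dim * 2)
hidden_states_fan = torch.randn(*input_shape, fan_dim)

query_states, gate = torch.chunk(
    q_proj(hidden_states_fan).view(*input_shape, -1, head_dim * 2), 2, dim=-1
)
print(query_states.shape, gate.shape)  # both torch.Size([2, 8, 4, 32])
```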
@@ -537,7 +535,7 @@ def torch_recurrent_gated_delta_rule(
     return core_attn_out, last_recurrent_state

 class NeoLLMGatedDeltaNet(nn.Module):
-    """Linear attention with FANformer integration for periodicity modeling"""
+    """Linear attention with FANformer integration and Selective Self-Attention for periodicity modeling"""

     def __init__(self, config: NeoLLMConfig, layer_idx: int):
         super().__init__()

@@ -659,7 +657,8 @@ class NeoLLMGatedDeltaNet(nn.Module):

         # Apply FANformer transformation first
         hidden_states_fan = self.fan_layer(hidden_states)
-
+
+        # Use FAN-transformed features directly for projections
         projected_states_qkvz = self.in_proj_qkvz(hidden_states_fan)
         projected_states_ba = self.in_proj_ba(hidden_states_fan)
         query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)

@@ -737,6 +736,7 @@ class PolyNorm(torch.nn.Module):

     def forward(self, x):
         return self.weight[0] * self._norm(x**3) + self.weight[1] * self._norm(x**2) + self.weight[2] * self._norm(x) + self.bias
+
 class NeoLLMMLP(nn.Module):
     def __init__(self, config):
         super().__init__()

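PolyNorm's _norm helper sits outside this diff; assuming it is the usual RMS-style normalization, the forward above behaves like this hypothetical reconstruction (the weight and bias shapes are guesses consistent with the indexing):

```python
import torch

class PolyNormSketch(torch.nn.Module):
    """Hypothetical reconstruction: _norm is assumed to be RMS normalization,
    and the parameter shapes are inferred, not taken from the diff."""

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.eps = eps

    def _norm(self, x):
        # RMS-normalize along the feature dimension (assumption)
        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Weighted sum of normalized x**3, x**2, and x terms, as in the hunk above
        return (self.weight[0] * self._norm(x**3)
                + self.weight[1] * self._norm(x**2)
                + self.weight[2] * self._norm(x)
                + self.bias)
```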
@@ -817,7 +817,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
             **kwargs,
         )

-        #
+        # Standard residual connection
         hidden_states = residual + hidden_states

         # Apply GPAS after attention residual connection

@@ -832,7 +832,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):

         hidden_states = self.mlp(hidden_states)

-        #
+        # Standard residual connection
         hidden_states = residual + hidden_states

         # Apply GPAS after MLP residual connection

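Both hunks only relabel the comment on the residual add, but the ordering they document (sublayer, residual add, then GPAS) is worth seeing in one place. A toy layer with stand-ins for the real attention, MLP, and GPAS modules:

```python
import torch
import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    """Order-of-operations sketch only: attn, mlp, and gpas are stand-ins,
    not the real NeoLLM modules."""

    def __init__(self, d: int):
        super().__init__()
        self.attn = nn.Linear(d, d)   # stand-in for the attention block
        self.mlp = nn.Linear(d, d)    # stand-in for NeoLLMMLP
        self.gate = nn.Parameter(torch.zeros(1))

    def gpas(self, x: torch.Tensor) -> torch.Tensor:
        # stand-in gated scaling, applied *after* each residual add
        return x * torch.sigmoid(self.gate)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.attn(hidden_states)
        hidden_states = residual + hidden_states   # standard residual connection
        hidden_states = self.gpas(hidden_states)   # GPAS after attention residual

        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states   # standard residual connection
        hidden_states = self.gpas(hidden_states)   # GPAS after MLP residual
        return hidden_states

print(ToyDecoderLayer(8)(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])
```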
@@ -867,6 +867,8 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
     def __init__(self, config: NeoLLMConfig):
         super().__init__(config)
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
+
+        # Each layer creates its own components (no shared parameters)
         self.layers = nn.ModuleList(
             [NeoLLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )

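Worth noting in passing: the third positional argument of nn.Embedding is padding_idx, so pad embeddings start at zero and receive no gradient updates:

```python
import torch
import torch.nn as nn

# Toy numbers; the real values come from NeoLLMConfig
vocab_size, hidden_size, pad_token_id = 100, 16, 0
embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id)
print(embed_tokens(torch.tensor([pad_token_id])).abs().sum())  # tensor(0., ...)
```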
@@ -934,6 +936,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
         if attention_mask is not None and torch.all(attention_mask == 1):
             linear_attn_mask = None
         return linear_attn_mask
+
 @torch.compiler.disable
 def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
     """

@@ -957,7 +960,7 @@ def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, p
         processed_labels,
         bias=lm_head_bias,
         shift=1,
-        impl="
+        impl="cce_kahan_full_c",
         reduction="mean"
     )

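For context, compute_cce_loss appears to wrap a fused linear-cross-entropy kernel. A hedged sketch of the call, assuming the cut-cross-entropy package's linear_cross_entropy is what is being invoked (the preprocessing that produces processed_labels is outside this diff):

```python
# Assumption: compute_cce_loss wraps the cut-cross-entropy kernel;
# the import below is that package's public entry point.
from cut_cross_entropy import linear_cross_entropy

def cce_loss_sketch(hidden_states, labels, lm_head_weight, lm_head_bias=None):
    # shift=1 performs the next-token shift inside the kernel, so position t
    # is scored against the label at t + 1; no manual slicing needed.
    return linear_cross_entropy(
        hidden_states,            # (batch, seq_len, hidden)
        lm_head_weight,           # (vocab_size, hidden) classifier matrix
        labels,                   # (batch, seq_len) int64 targets
        bias=lm_head_bias,
        shift=1,
        impl="cce_kahan_full_c",  # value taken from the hunk above
        reduction="mean",
    )
```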
@@ -1015,6 +1018,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
 # ==================== AUTOMODEL REGISTRATION ====================

 __all__ = [

@@ -1025,8 +1029,7 @@ __all__ = [
     "FANLayer",
 ]

-
 # Register the configuration and model for AutoClass support
 AutoConfig.register("neollm", NeoLLMConfig)
 AutoModel.register(NeoLLMConfig, NeoLLMModel)
-AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM)
+AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM)
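With the three register() calls in place, importing this module is enough for the Auto classes to resolve the custom architecture. A typical loading sketch; the checkpoint path is a placeholder:

```python
# Assumption: this file is importable as modeling_neollm; importing it runs
# the register() calls above.
import modeling_neollm  # noqa: F401
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("path/to/neollm-checkpoint")
assert config.model_type == "neollm"
model = AutoModelForCausalLM.from_pretrained("path/to/neollm-checkpoint")
```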