KitsuVp committed (verified)
Commit 3fd0d17 · 1 Parent(s): 3e804e8

Update modeling_neollm.py

Files changed (1)
  1. modeling_neollm.py +37 -34
modeling_neollm.py CHANGED
@@ -1,8 +1,8 @@
 #!/usr/bin/env python3
 """
-NeoLLM Model with FANformer Integration and Dropout Regularization
-Updated to include Fourier Analysis Network (FAN) layer for effective periodicity modeling
-and dropout regularization at strategic locations
+NeoLLM Model with FANformer Integration, Dropout Regularization, and Selective Self-Attention (SSA)
+Updated to include Fourier Analysis Network (FAN) layer for effective periodicity modeling,
+dropout regularization at strategic locations
 """
 
 import math
@@ -45,8 +45,6 @@ else:
     from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 
 logger = logging.get_logger(__name__)
-
-
 class FANLayer(nn.Module):
     """
     Fourier Analysis Network (FAN) layer for effective periodicity modeling.
@@ -63,26 +61,27 @@ class FANLayer(nn.Module):
         self.hidden_size = hidden_size
         self.fan_ratio = fan_ratio
 
-        # Calculate dimensions for periodic and non-periodic components
-        self.periodic_dim = int(hidden_size * fan_ratio)
-        self.non_periodic_dim = hidden_size - self.periodic_dim
+        # Calculate dimensions following the paper's approach
+        # Output will be: [cos(p) || sin(p) || g] where total = hidden_size + periodic_dim
+        output_dim = hidden_size + int(hidden_size * fan_ratio)
+        self.p_output_dim = int(output_dim * fan_ratio)
+        self.g_output_dim = output_dim - self.p_output_dim * 2
 
-        # Projection matrices
-        self.Wp = nn.Linear(hidden_size, self.periodic_dim, bias=False)
-        self.Wp_bar = nn.Linear(hidden_size, self.non_periodic_dim, bias=True)
+        # Single fused projection (more efficient than two separate projections)
+        self.input_linear = nn.Linear(
+            hidden_size,
+            self.p_output_dim + self.g_output_dim,
+            bias=True
+        )
 
         # Initialize parameters
         self._init_weights()
 
     def _init_weights(self):
         """Initialize weights following the paper's recommendations."""
-        # Initialize Wp for periodic components
-        nn.init.normal_(self.Wp.weight, mean=0.0, std=0.02)
-
-        # Initialize Wp_bar for non-periodic components
-        nn.init.normal_(self.Wp_bar.weight, mean=0.0, std=0.02)
-        if self.Wp_bar.bias is not None:
-            nn.init.zeros_(self.Wp_bar.bias)
+        nn.init.normal_(self.input_linear.weight, mean=0.0, std=0.02)
+        if self.input_linear.bias is not None:
+            nn.init.zeros_(self.input_linear.bias)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -93,17 +92,14 @@ class FANLayer(nn.Module):
 
         Returns:
             Transformed tensor with Fourier components concatenated
+            Shape: (batch, seq_len, hidden_size + periodic_dim)
         """
-        # Get periodic components
-        x_periodic = self.Wp(x) # (batch, seq_len, periodic_dim)
-        cos_component = torch.cos(x_periodic)
-        sin_component = torch.sin(x_periodic)
-
-        # Get non-periodic component (linear transformation)
-        x_non_periodic = self.Wp_bar(x) # (batch, seq_len, non_periodic_dim)
+        # Single projection followed by split (more efficient)
+        pg = self.input_linear(x)
+        p, g = torch.split(pg, [self.p_output_dim, self.g_output_dim], dim=-1)
 
         # Concatenate all components: [cos(WpX) || sin(WpX) || (Wp¯X + Bp¯)]
-        x_fan = torch.cat([cos_component, sin_component, x_non_periodic], dim=-1)
+        x_fan = torch.cat([torch.cos(p), torch.sin(p), g], dim=-1)
 
         return x_fan
 
@@ -287,7 +283,7 @@ def eager_attention_forward(
 
 
 class NeoLLMAttention(nn.Module):
-    """Multi-headed attention with FANformer integration for periodicity modeling"""
+    """Multi-headed attention with FANformer integration and Selective Self-Attention for periodicity modeling"""
 
     def __init__(self, config: NeoLLMConfig, layer_idx: int):
         super().__init__()
@@ -338,8 +334,10 @@ class NeoLLMAttention(nn.Module):
 
         # Apply FANformer transformation first
         hidden_states_fan = self.fan_layer(hidden_states)
+
         hidden_shape = (*input_shape, -1, self.head_dim)
 
+        # Use FAN-transformed features directly for projections
         query_states, gate = torch.chunk(
             self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
         )
@@ -537,7 +535,7 @@ def torch_recurrent_gated_delta_rule(
     return core_attn_out, last_recurrent_state
 
 class NeoLLMGatedDeltaNet(nn.Module):
-    """Linear attention with FANformer integration for periodicity modeling"""
+    """Linear attention with FANformer integration and Selective Self-Attention for periodicity modeling"""
 
     def __init__(self, config: NeoLLMConfig, layer_idx: int):
         super().__init__()
@@ -659,7 +657,8 @@ class NeoLLMGatedDeltaNet(nn.Module):
 
         # Apply FANformer transformation first
         hidden_states_fan = self.fan_layer(hidden_states)
-
+
+        # Use FAN-transformed features directly for projections
        projected_states_qkvz = self.in_proj_qkvz(hidden_states_fan)
         projected_states_ba = self.in_proj_ba(hidden_states_fan)
         query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
@@ -737,6 +736,7 @@ class PolyNorm(torch.nn.Module):
 
     def forward(self, x):
         return self.weight[0] * self._norm(x**3) + self.weight[1] * self._norm(x**2) + self.weight[2] * self._norm(x) + self.bias
+
 class NeoLLMMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -817,7 +817,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
             **kwargs,
         )
 
-        # Residual connection
+        # Standard residual connection
         hidden_states = residual + hidden_states
 
         # Apply GPAS after attention residual connection
@@ -832,7 +832,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
 
         hidden_states = self.mlp(hidden_states)
 
-        # Residual connection
+        # Standard residual connection
         hidden_states = residual + hidden_states
 
         # Apply GPAS after MLP residual connection
@@ -867,6 +867,8 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
     def __init__(self, config: NeoLLMConfig):
         super().__init__(config)
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
+
+        # Each layer creates its own components (no shared parameters)
         self.layers = nn.ModuleList(
             [NeoLLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
@@ -934,6 +936,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
        if attention_mask is not None and torch.all(attention_mask == 1):
            linear_attn_mask = None
        return linear_attn_mask
+
 @torch.compiler.disable
 def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
     """
@@ -957,7 +960,7 @@ def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, p
         processed_labels,
         bias=lm_head_bias,
         shift=1,
-        impl="cce",
+        impl="cce_kahan_full_c",
         reduction="mean"
     )
 
@@ -1015,6 +1018,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
 # ==================== AUTOMODEL REGISTRATION ====================
 
 __all__ = [
@@ -1025,8 +1029,7 @@ __all__ = [
     "FANLayer",
 ]
 
-
 # Register the configuration and model for AutoClass support
 AutoConfig.register("neollm", NeoLLMConfig)
 AutoModel.register(NeoLLMConfig, NeoLLMModel)
-AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM)
+AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM)
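
The core of this commit is the FANLayer rewrite: the two separate projections (Wp, Wp_bar) are replaced by one fused input_linear whose output is split into a periodic part p and a gate/linear part g, then concatenated as [cos(p) || sin(p) || g]. The sketch below is only a standalone reproduction of that dimension arithmetic so the output width can be sanity-checked; the class name FANLayerSketch and the fan_ratio=0.25 example value are illustrative and not taken from the commit, and dropout plus the surrounding NeoLLM machinery are omitted.

# Minimal sketch of the fused FAN projection from this commit (shapes only).
# fan_ratio=0.25 is an assumed example value, not NeoLLM's configured default.
import torch
import torch.nn as nn


class FANLayerSketch(nn.Module):
    def __init__(self, hidden_size: int, fan_ratio: float = 0.25):
        super().__init__()
        # Output width = hidden_size + periodic_dim, laid out as [cos(p) || sin(p) || g].
        output_dim = hidden_size + int(hidden_size * fan_ratio)
        self.p_output_dim = int(output_dim * fan_ratio)
        self.g_output_dim = output_dim - 2 * self.p_output_dim
        # One fused projection replaces the former separate Wp / Wp_bar matrices.
        self.input_linear = nn.Linear(hidden_size, self.p_output_dim + self.g_output_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        p, g = torch.split(self.input_linear(x), [self.p_output_dim, self.g_output_dim], dim=-1)
        return torch.cat([torch.cos(p), torch.sin(p), g], dim=-1)


# hidden_size=1024, fan_ratio=0.25: output_dim = 1280, p_output_dim = 320,
# g_output_dim = 640, so the layer emits 2*320 + 640 = 1280 features per token.
x = torch.randn(2, 8, 1024)
print(FANLayerSketch(1024)(x).shape)  # torch.Size([2, 8, 1280])

Note that the attention and gated-delta-net projections (q_proj, in_proj_qkvz, in_proj_ba) consume this widened hidden_size + periodic_dim feature, which is why the diff feeds hidden_states_fan rather than hidden_states into them.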
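
The register(...) calls at the bottom wire NeoLLM into the transformers Auto classes. A minimal usage sketch, assuming a hypothetical checkpoint path and that its config.json points at this module; trust_remote_code=True is what allows modeling_neollm.py to be imported so the registration runs:

from transformers import AutoModelForCausalLM, AutoTokenizer

# "path/to/neollm-checkpoint" is a placeholder, not a real repo id.
model = AutoModelForCausalLM.from_pretrained("path/to/neollm-checkpoint", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("path/to/neollm-checkpoint", trust_remote_code=True)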