#!/usr/bin/env python3
"""
NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regularization,
SeeDNorm (Self-Rescaled Dynamic Normalization), and ResFormer Value Residual Learning
for enhanced information flow through deep layers.
Updated to include:
- Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
- FAN layer in FFN for featural periodicity modeling (complementary coverage)
- SeeDNorm: Dynamic normalization with input-dependent scaling for better adaptability
- Dropout regularization at strategic locations
- ResFormer: Feature residual connections from first layer (applied before projections)
"""
import math
from typing import Any, Callable, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from cut_cross_entropy import linear_cross_entropy
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, logging
from transformers.utils.generic import check_model_inputs
from transformers.utils.import_utils import (
is_causal_conv1d_available,
is_flash_linear_attention_available,
)
from .configuration_neollm import NeoLLMConfig
if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None
if is_flash_linear_attention_available():
from fla.modules import FusedRMSNormGated
from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
else:
chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None
FusedRMSNormGated = None
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
logger = logging.get_logger(__name__)
class FANLayer(nn.Module):
"""
Fourier Analysis Network (FAN) layer for effective periodicity modeling.
From "FANformer: Improving Large Language Models Through Effective Periodicity Modeling":
    FANLayer'(X) = [cos(W_p X) || sin(W_p X) || (W̄_p X + B̄_p)]
This is the modified version (FANLayer') without activation function that gave
the best results in the paper.
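    Example (illustrative sketch, assuming hidden_size=512 and fan_ratio=0.25; the
    output width is hidden_size + int(hidden_size * fan_ratio) = 640):
        >>> layer = FANLayer(hidden_size=512, fan_ratio=0.25)
        >>> layer(torch.randn(2, 16, 512)).shape
        torch.Size([2, 16, 640])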
"""
def __init__(self, hidden_size: int, fan_ratio: float = 0.25):
super().__init__()
self.hidden_size = hidden_size
self.fan_ratio = fan_ratio
# Calculate dimensions following the paper's approach
# Output will be: [cos(p) || sin(p) || g] where total = hidden_size + periodic_dim
output_dim = hidden_size + int(hidden_size * fan_ratio)
self.p_output_dim = int(output_dim * fan_ratio)
self.g_output_dim = output_dim - self.p_output_dim * 2
# Single fused projection (more efficient than two separate projections)
self.input_linear = nn.Linear(
hidden_size,
self.p_output_dim + self.g_output_dim,
bias=True
)
# Initialize parameters
self._init_weights()
def _init_weights(self):
"""Initialize weights following the paper's recommendations."""
nn.init.normal_(self.input_linear.weight, mean=0.0, std=0.02)
if self.input_linear.bias is not None:
nn.init.zeros_(self.input_linear.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply Fourier transformation to input.
Args:
x: Input tensor of shape (batch, seq_len, hidden_size)
Returns:
Transformed tensor with Fourier components concatenated
Shape: (batch, seq_len, hidden_size + periodic_dim)
"""
# Single projection followed by split (more efficient)
pg = self.input_linear(x)
p, g = torch.split(pg, [self.p_output_dim, self.g_output_dim], dim=-1)
        # Concatenate all components: [cos(W_p X) || sin(W_p X) || (W̄_p X + B̄_p)]
x_fan = torch.cat([torch.cos(p), torch.sin(p), g], dim=-1)
return x_fan
class LNS(nn.Module):
"""
LayerNorm Scaling (LNS) - applies scaling factor 1/√ℓ as described in the paper.
From "The Curse of Depth in Large Language Models":
h^(ℓ) = LayerNorm(h^(ℓ)) × (1/√ℓ)
This prevents exponential variance growth in deeper layers.
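    Example (minimal sketch; the scale depends only on the 0-based layer index):
        >>> LNS(layer_idx=0).scale
        1.0
        >>> LNS(layer_idx=3).scale
        0.5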
"""
def __init__(self, layer_idx: int):
super().__init__()
# Layer 1 gets index 1, layer 2 gets index 2, etc.
# Avoid division by zero for layer 0
self.layer_idx = max(layer_idx + 1, 1) # +1 because layer_idx starts from 0
self.scale = 1.0 / math.sqrt(self.layer_idx)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * self.scale
class GPAS(nn.Module):
"""
Gradient-Preserving Activation Scaling (GPAS)
Scales activations without penalizing gradients using stop-gradient.
Applied in Pre-Norm style: after sub-layer output but before residual sum.
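    Example (minimal sketch; alpha is initialized to 0, so GPAS starts out as the identity):
        >>> gpas = GPAS(d_model=8)
        >>> x = torch.randn(2, 3, 8)
        >>> torch.allclose(gpas(x), x)
        True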
"""
def __init__(self, d_model: int):
super().__init__()
self.d_model = d_model
self.alpha = nn.Parameter(torch.zeros(1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_detached = x.detach()
scaled_component = F.silu(self.alpha) * x_detached
x_scaled = x - scaled_component
return x_scaled
class SeeDNorm(nn.Module):
"""
Self-Rescaled Dynamic Normalization (SeeDNorm)
From "SeeDNorm: Self-Rescaled Dynamic Normalization":
SeeDNorm(x) = [σ(x·β^T)·α + γ] ⊙ x/RMS(x)
Dynamically adjusts the scaling coefficient based on the current input,
preserving input norm information and enabling data-dependent normalization.
Key features:
- γ: Static scaling factor (like RMSNorm), initialized to 1
- β: Self-rescaling parameter, initialized to 0
- α: Dynamic modulation parameter, initialized to 1
- σ: tanh activation to constrain dynamic scaling range [-1, 1]
Args:
dim: Hidden dimension size
eps: Small constant for numerical stability
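    Example (minimal sketch; at initialization beta = 0, so SeeDNorm reduces to plain
    RMS normalization scaled by gamma = 1):
        >>> norm = SeeDNorm(dim=8)
        >>> x = torch.randn(2, 3, 8)
        >>> ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
        >>> torch.allclose(norm(x), ref)
        True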
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
# Learnable parameters
self.gamma = nn.Parameter(torch.ones(dim)) # γ: static scaling (RMSNorm-like)
self.beta = nn.Parameter(torch.zeros(dim)) # β: self-rescaling parameter
self.alpha = nn.Parameter(torch.ones(dim)) # α: dynamic modulation parameter
def _rms_norm(self, x: torch.Tensor) -> torch.Tensor:
"""Compute RMS normalization: x / RMS(x)"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply Self-Rescaled Dynamic Normalization.
Args:
x: Input tensor of shape (..., dim)
Returns:
Normalized and dynamically scaled tensor of same shape
"""
# Compute input-dependent rescaling: σ(x·β^T)
# x·β^T produces scalar per token via dot product
rescale_factor = torch.tanh(torch.sum(x * self.beta, dim=-1, keepdim=True))
# Dynamic scaling coefficient: σ(x·β^T)·α + γ
dynamic_scale = rescale_factor * self.alpha + self.gamma
# Apply RMS normalization
x_normalized = self._rms_norm(x.float())
# Apply dynamic scaling
output = x_normalized * dynamic_scale.float()
return output.type_as(x)
def extra_repr(self) -> str:
return f"dim={self.dim}, eps={self.eps}"
class NeoLLMRMSNormGated(nn.Module):
"""
Gated RMSNorm variant used in specific contexts.
"""
def __init__(self, hidden_size, eps=1e-6, **kwargs):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states, gate=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
# Norm before gate
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = self.weight * hidden_states.to(input_dtype)
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
return hidden_states.to(input_dtype)
class NeoLLMRotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, config: NeoLLMConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors."""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
# Keep half or full tensor for later concatenation
rotary_dim = cos.shape[-1]
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
# Apply rotary embeddings on the first half or full tensor
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
# Concatenate back to full shape
q_embed = torch.cat([q_embed, q_pass], dim=-1)
k_embed = torch.cat([k_embed, k_pass], dim=-1)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs: Unpack[TransformersKwargs],
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class NeoLLMAttention(nn.Module):
"""
Multi-headed attention with FANformer integration, SeeDNorm for Q/K normalization,
and ResFormer feature residual connections for enhanced information flow.
ResFormer enhancement: Applies learnable feature residual connections from the first layer
BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
"""
def __init__(self, config: NeoLLMConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
# FANformer integration: FAN layer before QKV projections
self.fan_layer = FANLayer(
hidden_size=config.hidden_size,
fan_ratio=getattr(config, 'fan_ratio', 0.125)
)
# Calculate the output dimension after FAN transformation
fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio', 0.125))
# QKV projections operate on FAN-transformed features
self.q_proj = nn.Linear(
fan_output_dim, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias
)
self.k_proj = nn.Linear(
fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
# SeeDNorm for Q/K normalization (replaces RMSNorm)
self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
# Dropout for attention output
self.dropout = nn.Dropout(config.dropout_rate)
# ResFormer: learnable feature residual parameters (initialized to 0.5)
self.lambda_1 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_1 (first layer features)
self.lambda_2 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_n (current layer features)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
first_layer_fan: Optional[torch.Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
input_shape = hidden_states.shape[:-1]
# Apply FANformer transformation first
hidden_states_fan = self.fan_layer(hidden_states)
# ResFormer: Apply feature residual connection BEFORE projections
# This ensures dimensional compatibility across all layer types
if first_layer_fan is not None:
hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
# Store current FAN features for potential use as first_layer_fan in subsequent layers
current_layer_fan = hidden_states_fan.clone()
hidden_shape = (*input_shape, -1, self.head_dim)
# Use FAN-transformed features (with residual applied) for projections
query_states, gate = torch.chunk(
self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
)
gate = gate.reshape(*input_shape, -1)
# Apply SeeDNorm to Q and K
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states_fan).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states_fan).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = attn_output * torch.sigmoid(gate)
attn_output = self.o_proj(attn_output)
attn_output = self.dropout(attn_output)
return attn_output, attn_weights, current_layer_fan
def apply_mask_to_padding_states(hidden_states, attention_mask):
"""
Tunes out the hidden states for padding tokens
"""
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
dtype = hidden_states.dtype
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
return hidden_states
is_fast_path_available = all(
(causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule)
)
def torch_causal_conv1d_update(
hidden_states,
conv_state,
weight,
bias=None,
activation=None,
):
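    """
    Pure-PyTorch fallback for `causal_conv1d_update` (a sketch matching how it is called
    in this file): concatenates the rolling convolution state with the incoming tokens,
    updates `conv_state` in place, and applies a depthwise causal Conv1d followed by SiLU.
    The `activation` argument is accepted for API compatibility, but SiLU is always used.
    """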
_, hidden_size, seq_len = hidden_states.shape
state_len = conv_state.shape[-1]
hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
conv_state.copy_(hidden_states_new[:, :, -state_len:])
out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
out = F.silu(out[:, :, -seq_len:])
out = out.to(hidden_states.dtype)
return out
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
"""This function is intended to align with the l2norm implementation in the FLA library."""
inv_norm = 1 / torch.sqrt((x * x).sum(dim=dim, keepdim=True) + eps)
return x * inv_norm
def torch_chunk_gated_delta_rule(
query,
key,
value,
g,
beta,
chunk_size=64,
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=False,
):
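    """
    Pure-PyTorch fallback for the FLA chunked gated delta rule. As called in this file,
    query/key/value have shape (batch, seq_len, num_heads, head_dim) and g/beta have shape
    (batch, seq_len, num_heads); the output is returned in the same layout, together with
    the final recurrent state when `output_final_state` is True. Note that after the
    transpose below, the variables named `sequence_length` and `num_heads` actually hold
    the head and (chunk-padded) time dimensions, respectively.
    """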
initial_dtype = query.dtype
if use_qk_l2norm_in_kernel:
query = l2norm(query, dim=-1, eps=1e-6)
key = l2norm(key, dim=-1, eps=1e-6)
query, key, value, beta, g = [
x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
]
batch_size, sequence_length, num_heads, k_head_dim = key.shape
v_head_dim = value.shape[-1]
pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
query = F.pad(query, (0, 0, 0, pad_size))
key = F.pad(key, (0, 0, 0, pad_size))
value = F.pad(value, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size))
g = F.pad(g, (0, pad_size))
tot_heads = num_heads + pad_size
scale = 1 / (query.shape[-1] ** 0.5)
query = query * scale
v_beta = value * beta.unsqueeze(-1)
k_beta = key * beta.unsqueeze(-1)
# reshape to chunks
query, key, value, k_beta, v_beta = [
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
]
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
# chunk decay
g = g.cumsum(dim=-1)
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
for i in range(1, chunk_size):
row = attn[..., i, :i].clone()
sub = attn[..., :i, :i].clone()
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
value = attn @ v_beta
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
last_recurrent_state = (
torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
if initial_state is None
else initial_state.to(value)
)
core_attn_out = torch.zeros_like(value)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
# for each chunk
for i in range(0, tot_heads // chunk_size):
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
v_new = v_i - v_prime
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
core_attn_out[:, :, i] = attn_inter + attn @ v_new
last_recurrent_state = (
last_recurrent_state * g[:, :, i, -1, None, None].exp()
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
)
if not output_final_state:
last_recurrent_state = None
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
core_attn_out = core_attn_out[:, :, :num_heads]
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
return core_attn_out, last_recurrent_state
def torch_recurrent_gated_delta_rule(
query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
):
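    """
    Sequential (token-by-token) pure-PyTorch fallback for the FLA fused recurrent gated
    delta rule; it uses the same tensor layout and variable naming as
    `torch_chunk_gated_delta_rule` above.
    """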
initial_dtype = query.dtype
if use_qk_l2norm_in_kernel:
query = l2norm(query, dim=-1, eps=1e-6)
key = l2norm(key, dim=-1, eps=1e-6)
query, key, value, beta, g = [
x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
]
batch_size, sequence_length, num_heads, k_head_dim = key.shape
v_head_dim = value.shape[-1]
scale = 1 / (query.shape[-1] ** 0.5)
query = query * scale
core_attn_out = torch.zeros(batch_size, sequence_length, num_heads, v_head_dim).to(value)
last_recurrent_state = (
torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
if initial_state is None
else initial_state.to(value)
)
for i in range(num_heads):
q_t = query[:, :, i]
k_t = key[:, :, i]
v_t = value[:, :, i]
g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
beta_t = beta[:, :, i].unsqueeze(-1)
last_recurrent_state = last_recurrent_state * g_t
kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
delta = (v_t - kv_mem) * beta_t
last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
if not output_final_state:
last_recurrent_state = None
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
return core_attn_out, last_recurrent_state
class NeoLLMGatedDeltaNet(nn.Module):
"""
Linear attention with FANformer integration, SeeDNorm for normalization,
and ResFormer feature residual connections for enhanced information flow.
ResFormer enhancement: Applies learnable feature residual connections from the first layer
BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
"""
def __init__(self, config: NeoLLMConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.num_v_heads = config.linear_num_value_heads
self.num_k_heads = config.linear_num_key_heads
self.head_k_dim = config.linear_key_head_dim
self.head_v_dim = config.linear_value_head_dim
self.key_dim = self.head_k_dim * self.num_k_heads
self.value_dim = self.head_v_dim * self.num_v_heads
self.conv_kernel_size = config.linear_conv_kernel_dim
self.layer_idx = layer_idx
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
self.layer_norm_epsilon = config.rms_norm_eps
# FANformer integration: FAN layer before projections
self.fan_layer = FANLayer(
hidden_size=config.hidden_size,
fan_ratio=getattr(config, 'fan_ratio', 0.125)
)
# Calculate the output dimension after FAN transformation
fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio', 0.125))
# QKV - operates on FAN-transformed features
self.conv_dim = self.key_dim * 2 + self.value_dim
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
bias=False,
kernel_size=self.conv_kernel_size,
groups=self.conv_dim,
padding=self.conv_kernel_size - 1,
)
# projection of the FAN-transformed hidden states
projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
projection_size_ba = self.num_v_heads * 2
self.in_proj_qkvz = nn.Linear(fan_output_dim, projection_size_qkvz, bias=False)
self.in_proj_ba = nn.Linear(fan_output_dim, projection_size_ba, bias=False)
# time step projection (discretization)
self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
A = torch.empty(self.num_v_heads).uniform_(0, 16)
self.A_log = nn.Parameter(torch.log(A))
        # FLA compatibility: FusedRMSNormGated only supports swish/silu/sigmoid gate activations; fall back to "silu" otherwise
fla_compatible_activation = "silu" if self.activation not in ['swish', 'silu', 'sigmoid'] else self.activation
self.norm = (
NeoLLMRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
if FusedRMSNormGated is None
else FusedRMSNormGated(
self.head_v_dim,
eps=self.layer_norm_epsilon,
activation=fla_compatible_activation,
device=torch.cuda.current_device(),
dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
)
)
self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
# Dropout for attention output
self.dropout = nn.Dropout(config.dropout_rate)
self.causal_conv1d_fn = causal_conv1d_fn
self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule
# ResFormer: learnable feature residual parameters (initialized to 0.5)
self.lambda_1 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_1 (first layer features)
self.lambda_2 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_n (current layer features)
if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because one of the required library is not installed. Falling back to "
"torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and"
" https://github.com/Dao-AILab/causal-conv1d"
)
def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
"""
Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`.
"""
new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
self.num_k_heads,
2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads,
)
new_tensor_shape_ba = mixed_ba.size()[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads)
mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
split_arg_list_qkvz = [
self.head_k_dim,
self.head_k_dim,
(self.num_v_heads // self.num_k_heads * self.head_v_dim),
(self.num_v_heads // self.num_k_heads * self.head_v_dim),
]
split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads]
query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3)
b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3)
# [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim)
z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim)
b = b.reshape(b.size(0), b.size(1), self.num_v_heads)
a = a.reshape(a.size(0), a.size(1), self.num_v_heads)
return query, key, value, z, b, a
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
first_layer_fan: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
# Set up dimensions for reshapes later
batch_size, seq_len, _ = hidden_states.shape
# Apply FANformer transformation first
hidden_states_fan = self.fan_layer(hidden_states)
# ResFormer: Apply feature residual connection BEFORE projections
# This ensures dimensional compatibility across all layer types
if first_layer_fan is not None:
hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
# Store current FAN features for potential use as first_layer_fan in subsequent layers
current_layer_fan = hidden_states_fan.clone()
# Use FAN-transformed features (with residual applied) for projections
projected_states_qkvz = self.in_proj_qkvz(hidden_states_fan)
projected_states_ba = self.in_proj_ba(hidden_states_fan)
query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value))
mixed_qkv = torch.cat((query, key, value), dim=-1)
mixed_qkv = mixed_qkv.transpose(1, 2)
# Simple convolution without cache
if self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation="silu", # Keep original activation for conv1d
seq_idx=None,
)
else:
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
mixed_qkv = mixed_qkv.transpose(1, 2)
query, key, value = torch.split(
mixed_qkv,
[
self.key_dim,
self.key_dim,
self.value_dim,
],
dim=-1,
)
query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)
key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)
value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)
beta = b.sigmoid()
# If the model is loaded in fp16, without the .float() here, A might be -inf
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
if self.num_v_heads // self.num_k_heads > 1:
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
# Use chunk-based implementation without cache
core_attn_out, _ = self.chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=True,
)
z_shape_og = z.shape
# reshape input data into 2D tensor
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(z_shape_og)
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)
output = self.out_proj(core_attn_out)
output = self.dropout(output) # Apply dropout after output projection
return output, current_layer_fan
class PolyNorm(torch.nn.Module):
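    """
    PolyNorm activation: a learnable, weighted sum of RMS-normalized powers of the input,
        PolyNorm(x) = w_0 * norm(x^3) + w_1 * norm(x^2) + w_2 * norm(x) + b,
    with the three weights initialized to 1/3 and the bias to 0. Used below as the gate
    activation in NeoLLMMLP.
    """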
def __init__(self, eps=1e-6):
super(PolyNorm, self).__init__()
self.weight = torch.nn.Parameter(torch.ones(3) / 3)
self.bias = torch.nn.Parameter(torch.zeros(1))
self.eps = eps
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self.weight[0] * self._norm(x**3) + self.weight[1] * self._norm(x**2) + self.weight[2] * self._norm(x) + self.bias
class NeoLLMMLP(nn.Module):
"""
MLP with FANformer integration for featural periodicity modeling.
This captures periodicities in the feature space (semantic/embedding dimensions)
complementary to the relational periodicities captured by attention mechanisms.
Works in conjunction with ResFormer for comprehensive information flow.
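    Example (illustrative sketch with a stand-in config; the SimpleNamespace fields are
    placeholders for the relevant NeoLLMConfig attributes, not the real config class):
        >>> from types import SimpleNamespace
        >>> cfg = SimpleNamespace(hidden_size=64, intermediate_size=128, dropout_rate=0.0)
        >>> mlp = NeoLLMMLP(cfg)
        >>> mlp(torch.randn(2, 5, 64)).shape
        torch.Size([2, 5, 64])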
"""
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
        # FANformer integration for featural-space periodicity
self.fan_layer = FANLayer(
hidden_size=config.hidden_size,
fan_ratio=getattr(config, 'fan_ratio_ffn', 0.0625) # Half of attention's fan_ratio
)
# Calculate the output dimension after FAN transformation
fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio_ffn', 0.0625))
# SwiGLU/Gated architecture - now operates on FAN-transformed features
self.gate_proj = nn.Linear(fan_output_dim, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(fan_output_dim, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = PolyNorm()
# Dropout for MLP hidden layer
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, x):
        # Apply FAN transformation before projections
x_fan = self.fan_layer(x)
# Use FAN-transformed features for gate and up projections
gate_output = self.act_fn(self.gate_proj(x_fan))
up_output = self.up_proj(x_fan)
hidden = gate_output * up_output
hidden = self.dropout(hidden)
return self.down_proj(hidden)
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: NeoLLMConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
# token mixer
self.layer_type = config.layer_types[layer_idx]
if self.layer_type == "linear_attention":
self.linear_attn = NeoLLMGatedDeltaNet(config, layer_idx)
elif self.layer_type == "full_attention":
self.self_attn = NeoLLMAttention(config, layer_idx)
# MLP with FANformer integration
self.mlp = NeoLLMMLP(config)
# SeeDNorm for input and post-attention normalization (replaces RMSNorm)
self.input_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
# LNS (LayerNorm Scaling) - applies 1/√ℓ scaling
self.lns_attn = LNS(layer_idx)
self.lns_mlp = LNS(layer_idx)
# GPAS (Gradient-Preserving Activation Scaling) - applied after residual connections
self.gpas_attn = GPAS(config.hidden_size)
self.gpas_mlp = GPAS(config.hidden_size)
# ResFormer: storage for current layer's FAN features
self.current_layer_fan = None
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
first_layer_fan: Optional[torch.Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> torch.FloatTensor:
residual = hidden_states
# Apply SeeDNorm normalization
hidden_states = self.input_layernorm(hidden_states)
# Apply LNS scaling after normalization
hidden_states = self.lns_attn(hidden_states)
# Token Mixer with ResFormer feature residual connections
if self.layer_type == "linear_attention":
hidden_states, self.current_layer_fan = self.linear_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
first_layer_fan=first_layer_fan,
)
elif self.layer_type == "full_attention":
# Self Attention
hidden_states, _, self.current_layer_fan = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
first_layer_fan=first_layer_fan,
**kwargs,
)
# Standard residual connection
hidden_states = residual + hidden_states
# Apply GPAS after attention residual connection
hidden_states = self.gpas_attn(hidden_states)
# Fully Connected with FANformer
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
# Apply LNS scaling after normalization
hidden_states = self.lns_mlp(hidden_states)
# MLP now includes FAN transformation internally
hidden_states = self.mlp(hidden_states)
# Standard residual connection
hidden_states = residual + hidden_states
# Apply GPAS after MLP residual connection
hidden_states = self.gpas_mlp(hidden_states)
return hidden_states
class NeoLLMPreTrainedModel(PreTrainedModel):
config: NeoLLMConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["NeoLLMDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_is_stateful = True
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, NeoLLMGatedDeltaNet):
module.dt_bias.data.fill_(1.0)
module.A_log.data.uniform_(0, 16).log_()
# ResFormer: initialize lambda parameters for linear attention
if hasattr(module, 'lambda_1'):
module.lambda_1.data.fill_(0.5)
if hasattr(module, 'lambda_2'):
module.lambda_2.data.fill_(0.5)
elif isinstance(module, NeoLLMAttention):
# ResFormer: initialize lambda parameters for full attention
if hasattr(module, 'lambda_1'):
module.lambda_1.data.fill_(0.5)
if hasattr(module, 'lambda_2'):
module.lambda_2.data.fill_(0.5)
elif isinstance(module, GPAS):
# Initialize GPAS alpha to 0 as per paper
module.alpha.data.fill_(0.0)
elif isinstance(module, FANLayer):
# FANLayer initialization is handled within the class
pass
elif isinstance(module, SeeDNorm):
# SeeDNorm initialization:
# gamma (γ) initialized to 1 (default in Parameter definition)
# beta (β) initialized to 0 (default in Parameter definition)
# alpha (α) initialized to 1 (default in Parameter definition)
pass
class NeoLLMModel(NeoLLMPreTrainedModel):
def __init__(self, config: NeoLLMConfig):
super().__init__(config)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
# Each layer creates its own components (no shared parameters)
self.layers = nn.ModuleList(
[NeoLLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
# SeeDNorm for final output normalization (replaces RMSNorm)
self.norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# ResFormer: storage for first layer's FAN features (H_fan_1)
self.first_layer_fan = None
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if position_ids is None:
position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=position_ids.squeeze(0),
past_key_values=None,
position_ids=position_ids,
)
linear_attn_mask = self._update_linear_attn_mask(attention_mask, position_ids.squeeze(0))
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# ResFormer: reset first_layer_fan at the start of each forward pass
self.first_layer_fan = None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
hidden_states = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=layer_mask,
first_layer_fan=self.first_layer_fan, # Pass H_fan_1 to all layers
**kwargs,
)
# ResFormer: capture H_fan_1 from the first layer
if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
self.first_layer_fan = decoder_layer.current_layer_fan
# Apply SeeDNorm for final normalization
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=None,
)
def _update_linear_attn_mask(self, attention_mask, cache_position):
"""
NOTE: Left-padding is used for linear attention mask.
No need for zeroing states when attending to all inputs
"""
linear_attn_mask = attention_mask
if attention_mask is not None and torch.all(attention_mask == 1):
linear_attn_mask = None
return linear_attn_mask
@torch.compiler.disable
def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
"""
CCE loss computation excluded from compilation.
Preprocesses labels to eliminate torch.compile warnings.
"""
# Ensure labels are on the correct device
processed_labels = labels.to(hidden_states.device)
# Handle pad tokens: convert pad_token_id to -100 for proper masking
if pad_token_id is not None:
processed_labels = torch.where(
processed_labels == pad_token_id,
torch.tensor(-100, dtype=processed_labels.dtype, device=processed_labels.device),
processed_labels
)
return linear_cross_entropy(
hidden_states,
lm_head_weight,
processed_labels,
bias=lm_head_bias,
shift=1,
impl="cce_kahan_full_c",
reduction="mean"
)
class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = NeoLLMModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# CCE Loss computation for training
if labels is not None:
loss = compute_cce_loss(
hidden_states,
labels,
self.lm_head.weight,
getattr(self.lm_head, 'bias', None),
self.config.pad_token_id
)
logits = None
else:
# Inference mode - compute logits normally
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# ==================== AUTOMODEL REGISTRATION ====================
__all__ = [
"NeoLLMForCausalLM",
"NeoLLMModel",
"NeoLLMPreTrainedModel",
"NeoLLMConfig",
"FANLayer",
"SeeDNorm",
]
# Register the configuration and model for AutoClass support
AutoConfig.register("neollm", NeoLLMConfig)
AutoModel.register(NeoLLMConfig, NeoLLMModel)
AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM)
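# Example usage (illustrative sketch; the checkpoint path is a placeholder and assumes the
# repository ships this file together with configuration_neollm.py as custom code):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("path/to/neollm-checkpoint", trust_remote_code=True)
#   model = AutoModelForCausalLM.from_pretrained("path/to/neollm-checkpoint", trust_remote_code=True)
#   inputs = tokenizer("Periodicity is", return_tensors="pt")
#   generated = model.generate(**inputs, max_new_tokens=16)
#   print(tokenizer.decode(generated[0], skip_special_tokens=True))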