likhonsheikh committed
Commit eb82cd0 · verified · 1 Parent(s): 85348e2

Add model.py

Files changed (1):
  1. model.py (+362, -0)
model.py ADDED
@@ -0,0 +1,362 @@
"""
Sheikh-2.5-Coder Model Implementation
=====================================

This module implements the Sheikh-2.5-Coder model architecture, a 3B parameter
transformer model optimized for code generation and on-device deployment.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
)
import json

class SheikhConfig(PretrainedConfig):
    """Configuration class for Sheikh-2.5-Coder model.

    Subclasses `transformers.PretrainedConfig` so the config can back a
    `PreTrainedModel` and be saved/loaded with `save_pretrained` / `from_pretrained`.
    """

    model_type = "sheikh"

    def __init__(
        self,
        # Model architecture
        num_attention_heads: int = 16,
        num_key_value_heads: int = 2,
        hidden_size: int = 3072,
        intermediate_size: int = 8192,
        num_hidden_layers: int = 36,
        vocab_size: int = 50257,
        # Position embeddings
        max_position_embeddings: int = 32768,
        rope_theta: float = 10000.0,
        # Attention
        attention_dropout: float = 0.1,
        hidden_dropout: float = 0.1,
        # Normalization
        layer_norm_epsilon: float = 1e-6,
        rms_norm_eps: float = 1e-6,
        # Activation
        activation_function: str = "swiglu",
        # Precision
        torch_dtype: str = "bfloat16",
        # Cache
        use_cache: bool = True,
        # Tie word embeddings
        tie_word_embeddings: bool = True,
        **kwargs,
    ):
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.rms_norm_eps = rms_norm_eps
        self.activation_function = activation_function
        self.use_cache = use_cache
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            torch_dtype=torch_dtype,
            **kwargs,
        )

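# Quick sanity check of the attention geometry implied by the defaults above
# (illustrative only, mirroring the arithmetic used in SheikhAttention below):
# 3072 hidden dims / 16 heads = 192 dims per head, and the 16 query heads share
# 2 key/value heads, i.e. 8 query heads per KV head (grouped query attention).
#
#   >>> cfg = SheikhConfig()
#   >>> cfg.hidden_size // cfg.num_attention_heads
#   192
#   >>> cfg.num_attention_heads // cfg.num_key_value_heads
#   8
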
class SheikhRMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)

class SheikhRotaryEmbedding(nn.Module):
    """Rotary Positional Embedding."""

    def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
        )

    def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Cached as (1, 1, seq_len, dim) so the tables broadcast over (batch, heads, seq, dim).
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: Optional[int] = None):
        if seq_len is None:
            seq_len = x.shape[-2]
        # Grow the cache if a longer sequence than seen so far is requested.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

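# Cache layout note (matches the buffers above): `cos_cached` / `sin_cached` are
# (1, 1, max_position_embeddings, dim) tensors, so forward() only slices the first
# seq_len positions. Illustrative call with the default head_dim of 192:
#
#   rope = SheikhRotaryEmbedding(dim=192)
#   cos, sin = rope(torch.empty(1, 16, 128, 192), seq_len=128)  # each (1, 1, 128, 192)
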
class SheikhAttention(nn.Module):
    """Multi-head attention with Grouped Query Attention."""

    def __init__(self, config: SheikhConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.rotary_emb = SheikhRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        bsz, q_len, _ = hidden_states.size()

        # Query, Key, Value projections
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape to (batch, heads, seq, head_dim) for grouped query attention
        q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        if position_ids is None:
            position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0)
        cos, sin = self.rotary_emb(v, seq_len=q_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        # Expand key/value heads so every query head has a matching KV head
        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)

        # Scaled dot-product attention. F.scaled_dot_product_attention does not
        # allow is_causal together with an explicit mask, so the mask (expected to
        # be a boolean or additive float mask broadcastable to
        # (batch, heads, q_len, kv_len)) takes precedence when provided.
        attn_output = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attention_mask,
            dropout_p=self.config.attention_dropout if self.training else 0.0,
            is_causal=attention_mask is None,
        )

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        # F.scaled_dot_product_attention does not expose attention weights,
        # so None is returned even when output_attentions is requested.
        attn_weights = None
        return attn_output, attn_weights

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat key/value heads for grouped query attention.

    Expects (batch, num_key_value_heads, seq_len, head_dim) and returns
    (batch, num_key_value_heads * n_rep, seq_len, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

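# Shape example with the default config (illustrative): a key tensor of shape
# (batch, 2, seq, 192) becomes (batch, 16, seq, 192) after repeat_kv(k, 8), so it
# lines up with the 16 query heads inside F.scaled_dot_product_attention.
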
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor):
    """Apply rotary positional embeddings."""
    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    cos = cos.squeeze(1).squeeze(0)
    sin = sin.squeeze(1).squeeze(0)

    cos = cos[position_ids].unsqueeze(1)
    sin = sin[position_ids].unsqueeze(1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class SheikhMLP(nn.Module):
    """SwiGLU MLP."""

    def __init__(self, config: SheikhConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

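# The SwiGLU feed-forward computes down_proj(SiLU(gate_proj(x)) * up_proj(x)).
# With the default hidden_size=3072 and intermediate_size=8192, the three
# projections hold roughly 3 * 3072 * 8192 ≈ 75.5M parameters per layer
# (illustrative arithmetic based on the defaults in SheikhConfig).
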
class SheikhTransformerBlock(nn.Module):
    """Transformer block for Sheikh-2.5-Coder."""

    def __init__(self, config: SheikhConfig):
        super().__init__()
        self.self_attn = SheikhAttention(config)
        self.mlp = SheikhMLP(config)
        self.input_layernorm = SheikhRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SheikhRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        # Self-attention
        attn_output, _ = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_output

        # MLP
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_output

        return hidden_states

class SheikhModel(PreTrainedModel):
    """Sheikh-2.5-Coder base model."""

    config_class = SheikhConfig

    def __init__(self, config: SheikhConfig):
        super().__init__(config)
        self.config = config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([SheikhTransformerBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = SheikhRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize model weights."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # Minimal forward pass: embed, run the decoder stack, apply the final norm.
        # KV caching, hidden-state collection and dict-style outputs are not implemented;
        # attention_mask, if provided, must already be in the SDPA format described in
        # SheikhAttention.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=bool(output_attentions),
                use_cache=bool(use_cache),
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states

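# Illustrative smoke test (a scaled-down config so it runs quickly on CPU; the
# full 3B default would need far more memory):
#
#   cfg = SheikhConfig(num_hidden_layers=2, hidden_size=256, intermediate_size=512,
#                      num_attention_heads=8, num_key_value_heads=2)
#   model = SheikhModel(cfg)
#   input_ids = torch.randint(0, cfg.vocab_size, (1, 16))
#   hidden = model(input_ids=input_ids)  # -> tensor of shape (1, 16, cfg.hidden_size)
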
# Model loading utilities
def load_sheikh_model(
    model_name_or_path: str,
    device_map: Optional[str] = "auto",
    torch_dtype: torch.dtype = torch.bfloat16,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load the Sheikh-2.5-Coder model and tokenizer, with optional quantization."""

    # Setup quantization config
    quantization_config = None
    if load_in_8bit:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    elif load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map=device_map,
        torch_dtype=torch_dtype,
        quantization_config=quantization_config,
    )

    return model, tokenizer

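# Illustrative usage (the checkpoint id below is a placeholder; substitute the
# actual Hub repository name):
#
#   model, tokenizer = load_sheikh_model("likhonsheikh/Sheikh-2.5-Coder", load_in_4bit=True)
#   inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to(model.device)
#   outputs = model.generate(**inputs, max_new_tokens=64)
#   print(tokenizer.decode(outputs[0], skip_special_tokens=True))
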
# Model training utilities
def setup_training_args(output_dir: str, learning_rate: float = 1e-4) -> TrainingArguments:
    """Setup training arguments for Sheikh-2.5-Coder."""

    return TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        max_steps=100000,  # max_steps takes precedence over num_train_epochs when both are set
        logging_steps=100,
        save_steps=2000,
        eval_steps=1000,
        warmup_steps=2000,
        bf16=True,  # fp16 and bf16 are mutually exclusive; bf16 matches the config's torch_dtype
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        remove_unused_columns=False,
        dataloader_pin_memory=True,
        report_to="wandb",
        run_name="sheikh-2.5-coder",
    )

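# Sketch of how these arguments might feed a standard Trainer run (the dataset and
# collator objects are placeholders, not provided by this module):
#
#   from transformers import Trainer
#   trainer = Trainer(
#       model=model,
#       args=setup_training_args("./sheikh-2.5-coder-checkpoints"),
#       train_dataset=train_dataset,
#       eval_dataset=eval_dataset,
#       data_collator=data_collator,
#   )
#   trainer.train()
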
if __name__ == "__main__":
    # Example usage
    config = SheikhConfig()
    model = SheikhModel(config)

    # Save configuration
    with open("config.json", "w") as f:
        json.dump(config.to_dict(), f, indent=2)

    print("Sheikh-2.5-Coder model configuration created successfully!")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")