Commit 8b6d087 (verified) by katuni4ka · 1 Parent(s): 37debcc

Upload 8 files

README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,77 @@
+ {
+   "architectures": [
+     "Mistral4ForCausalLM"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_mistral4.Mistral4Config",
+     "AutoModelForCausalLM": "modeling_mistral4.Mistral4ForCausalLM"
+   },
+   "bos_token_id": 1,
+   "dtype": "float32",
+   "eos_token_id": 2,
+   "first_k_dense_replace": 0,
+   "head_dim": 24,
+   "hidden_act": "silu",
+   "hidden_size": 64,
+   "initializer_range": 0.02,
+   "intermediate_size": 32,
+   "kv_lora_rank": 16,
+   "max_position_embeddings": 1048576,
+   "mlp_bias": false,
+   "model_type": "mistral4",
+   "moe_intermediate_size": 16,
+   "n_group": 1,
+   "n_routed_experts": 4,
+   "n_shared_experts": 1,
+   "norm_topk_prob": true,
+   "num_attention_heads": 4,
+   "num_experts_per_tok": 2,
+   "num_hidden_layers": 3,
+   "num_key_value_heads": 4,
+   "pad_token_id": 11,
+   "pretraining_tp": 1,
+   "q_lora_rank": 32,
+   "qk_head_dim": 24,
+   "qk_nope_head_dim": 16,
+   "qk_rope_head_dim": 8,
+   "rms_norm_eps": 1e-06,
+   "rope_interleave": true,
+   "rope_parameters": {
+     "beta_fast": 32.0,
+     "beta_slow": 1.0,
+     "factor": 128.0,
+     "llama_4_scaling_beta": 0.1,
+     "max_position_embeddings": 1048576,
+     "mscale": 1.0,
+     "mscale_all_dim": 1.0,
+     "original_max_position_embeddings": 8192,
+     "partial_rotary_factor": 0.3333333333333333,
+     "rope_theta": 10000.0,
+     "rope_type": "yarn",
+     "type": "yarn"
+   },
+   "rope_scaling": {
+     "beta_fast": 32.0,
+     "beta_slow": 1.0,
+     "factor": 128.0,
+     "llama_4_scaling_beta": 0.1,
+     "max_position_embeddings": 1048576,
+     "mscale": 1.0,
+     "mscale_all_dim": 1.0,
+     "original_max_position_embeddings": 8192,
+     "partial_rotary_factor": 0.3333333333333333,
+     "rope_theta": 10000.0,
+     "rope_type": "yarn",
+     "type": "yarn"
+   },
+   "rope_theta": 10000.0,
+   "routed_scaling_factor": 1.0,
+   "tie_word_embeddings": false,
+   "topk_group": 1,
+   "transformers_version": "4.57.1",
+   "use_cache": true,
+   "v_head_dim": 16,
+   "vocab_size": 32000
+ }
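Several values in this config.json depend on one another: the query/key head size is the sum of its non-rotary and rotary parts, and `partial_rotary_factor` is the rotary share of that size. The sketch below checks those relations on the values from this commit. It is only a sanity check under the assumption that these are the intended relations (they match how the modeling code in this commit splits and routes tensors); `check_config` is a hypothetical helper, not part of the repository.

```python
import math

# Fields copied from the config.json in this commit; unrelated keys are omitted and
# partial_rotary_factor is lifted out of the nested "rope_parameters" dict for brevity.
config = {
    "qk_head_dim": 24,
    "qk_nope_head_dim": 16,
    "qk_rope_head_dim": 8,
    "n_routed_experts": 4,
    "n_group": 1,
    "topk_group": 1,
    "num_experts_per_tok": 2,
    "partial_rotary_factor": 0.3333333333333333,
}


def check_config(cfg: dict) -> list[str]:
    """Hypothetical helper: return the violated relations (empty list if all hold)."""
    problems = []
    qk_head_dim = cfg["qk_nope_head_dim"] + cfg["qk_rope_head_dim"]
    if cfg["qk_head_dim"] != qk_head_dim:
        problems.append("qk_head_dim != qk_nope_head_dim + qk_rope_head_dim")
    rotary_fraction = cfg["qk_rope_head_dim"] / qk_head_dim
    if not math.isclose(cfg["partial_rotary_factor"], rotary_fraction):
        problems.append("partial_rotary_factor != qk_rope_head_dim / qk_head_dim")
    if cfg["n_routed_experts"] % cfg["n_group"] != 0:
        problems.append("n_routed_experts is not divisible by n_group")
    if cfg["topk_group"] > cfg["n_group"]:
        problems.append("topk_group > n_group")
    if cfg["num_experts_per_tok"] > cfg["n_routed_experts"]:
        problems.append("num_experts_per_tok > n_routed_experts")
    return problems


problems = check_config(config)
print(problems if problems else "config is consistent")
```

For the values above every relation holds: 16 + 8 = 24 and 8 / 24 = 1/3.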
configuration_mistral4.py ADDED
@@ -0,0 +1,162 @@
+ from transformers.configuration_utils import PretrainedConfig
+
+ class Mistral4Config(PretrainedConfig):
+     r"""
+     n_group (`int`, *optional*, defaults to 1):
+         Number of groups for routed experts.
+     first_k_dense_replace (`int`, *optional*, defaults to 0):
+         Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
+                                                          \--k dense layers--/
+     rope_interleave (`bool`, *optional*, defaults to `True`):
+         Whether to interleave the rotary position embeddings.
+
+     Example:
+
+     ```python
+     >>> from transformers import Mistral4Model, Mistral4Config
+
+     >>> # Initializing a Mistral4 style configuration
+     >>> configuration = Mistral4Config()
+
+     >>> # Initializing a model from the configuration
+     >>> model = Mistral4Model(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "mistral4"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     base_model_tp_plan = {
+         "layers.*.mlp.experts.gate_up_proj": "packed_colwise",
+         "layers.*.mlp.experts.down_proj": "rowwise",
+         "layers.*.mlp.experts": "moe_tp_experts",
+         "layers.*.mlp.shared_experts.gate_proj": "colwise",
+         "layers.*.mlp.shared_experts.up_proj": "colwise",
+         "layers.*.mlp.shared_experts.down_proj": "rowwise",
+         "layers.*.mlp.gate_proj": "colwise",
+         "layers.*.mlp.up_proj": "colwise",
+         "layers.*.mlp.down_proj": "rowwise",
+     }
+     base_model_pp_plan = {
+         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+         "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+         "norm": (["hidden_states"], ["hidden_states"]),
+     }
+     attribute_map = {
+         "num_local_experts": "n_routed_experts",
+     }
+
+     def __init__(
+         self,
+         vocab_size: int = 131072,
+         hidden_size: int = 4096,
+         intermediate_size: int = 12288,
+         moe_intermediate_size: int = 2048,
+         num_hidden_layers: int = 36,
+         num_attention_heads: int = 32,
+         num_key_value_heads: int | None = 32,
+         n_shared_experts: int = 1,
+         n_routed_experts: int = 128,
+         routed_scaling_factor: float = 1.0,
+         kv_lora_rank: int = 256,
+         q_lora_rank: int | None = 1024,
+         qk_rope_head_dim: int = 64,
+         v_head_dim: int | None = 128,
+         qk_nope_head_dim: int = 64,
+         n_group: int | None = 1,
+         topk_group: int | None = 1,
+         num_experts_per_tok: int | None = 4,
+         first_k_dense_replace: int | None = 0,
+         norm_topk_prob: bool | None = True,
+         hidden_act: str = "silu",
+         max_position_embeddings: int = 1048576,
+         initializer_range: float = 0.02,
+         rms_norm_eps: float = 1e-6,
+         use_cache: bool = True,
+         pad_token_id: int | None = 11,
+         bos_token_id: int | None = 1,
+         eos_token_id: int | list[int] | None = 2,
+         pretraining_tp: int | None = 1,
+         tie_word_embeddings: bool = False,
+         rope_parameters: dict | None = None,
+         rope_interleave: bool | None = True,
+         attention_bias: bool = False,
+         attention_dropout: float | int | None = 0.0,
+         mlp_bias: bool = False,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.moe_intermediate_size = moe_intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.n_shared_experts = n_shared_experts
+         self.n_routed_experts = n_routed_experts
+         self.routed_scaling_factor = routed_scaling_factor
+         self.kv_lora_rank = kv_lora_rank
+         self.q_lora_rank = q_lora_rank
+         self.qk_rope_head_dim = qk_rope_head_dim
+         self.v_head_dim = v_head_dim
+         self.qk_nope_head_dim = qk_nope_head_dim
+         # The modeling code reads `config.qk_head_dim`; config.json supplies it (and `head_dim`)
+         # through **kwargs. Derive both as a fallback so a default-constructed config also works.
+         if not hasattr(self, "qk_head_dim"):
+             self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+         if not hasattr(self, "head_dim"):
+             self.head_dim = self.qk_head_dim
+         self.n_group = n_group
+         self.topk_group = topk_group
+         self.num_experts_per_tok = num_experts_per_tok
+         self.first_k_dense_replace = first_k_dense_replace
+         self.norm_topk_prob = norm_topk_prob
+         self.hidden_act = hidden_act
+         self.max_position_embeddings = max_position_embeddings
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.pad_token_id = pad_token_id
+         self.bos_token_id = bos_token_id
+         self.eos_token_id = eos_token_id
+         self.pretraining_tp = pretraining_tp
+         self.mlp_bias = mlp_bias
+         if rope_parameters is None:
+             rope_parameters = {
+                 "type": "yarn",
+                 "rope_theta": 10000.0,
+                 "factor": 128.0,
+                 "original_max_position_embeddings": 8192,
+                 "max_position_embeddings": self.max_position_embeddings,
+                 "beta_fast": 32.0,
+                 "beta_slow": 1.0,
+                 "mscale_all_dim": 1.0,
+                 "mscale": 1.0,
+                 "llama_4_scaling_beta": 0.1,
+                 "partial_rotary_factor": self.qk_rope_head_dim / (self.qk_nope_head_dim + self.qk_rope_head_dim),
+             }
+         if "partial_rotary_factor" not in rope_parameters:
+             rope_parameters["partial_rotary_factor"] = self.qk_rope_head_dim / (self.qk_nope_head_dim + self.qk_rope_head_dim)
+         self.rope_parameters = rope_parameters
+         self.rope_theta = rope_parameters["rope_theta"]
+         self.rope_scaling = rope_parameters
+         self.rope_interleave = rope_interleave
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.tie_word_embeddings = tie_word_embeddings
+
+
+ __all__ = ["Mistral4Config"]
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "pad_token_id": 11,
+   "transformers_version": "4.57.1"
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bfdf912f19d09e9fbd44cafdfef1981e2d65945f39aa0c66de89c6aae128d4fb
+ size 16880152
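The three lines above are a Git LFS pointer, not the weights themselves: the repository stores the pointer and the actual file is fetched by its SHA-256 digest. As a small illustration, the sketch below parses the pointer text into its fields; `parse_lfs_pointer` is a hypothetical helper and handles only the three keys shown here.

```python
def parse_lfs_pointer(text: str) -> dict:
    """Parse a Git LFS pointer file into its fields (hypothetical helper)."""
    # Each pointer line has the form "<key> <value>".
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    hash_algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "hash_algo": hash_algo,
        "digest": digest,
        "size": int(fields["size"]),  # size of the real file in bytes
    }


pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:bfdf912f19d09e9fbd44cafdfef1981e2d65945f39aa0c66de89c6aae128d4fb
size 16880152
"""

info = parse_lfs_pointer(pointer)
print(f"{info['size'] / 2**20:.1f} MiB, {info['hash_algo']} digest of {len(info['digest'])} hex chars")
```

At roughly 16 MiB, the file is consistent with a small test-sized checkpoint stored in `float32`, as the `dtype` in config.json indicates.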
modeling_mistral4.py ADDED
@@ -0,0 +1,465 @@
+ from collections.abc import Callable
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.generation import GenerationMixin
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from transformers.activations import ACT2FN
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from transformers.utils import logging
+ from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
+     DeepseekV3Attention,
+     DeepseekV3DecoderLayer,
+     apply_rotary_pos_emb_interleave,
+ )
+ from transformers.models.llama.modeling_llama import (
+     LlamaRMSNorm,
+     LlamaRotaryEmbedding,
+     apply_rotary_pos_emb,
+     eager_attention_forward,
+ )
+ from transformers.masking_utils import create_causal_mask
+ from transformers.models.gemma.modeling_gemma import GemmaMLP
+ from .configuration_mistral4 import Mistral4Config
+
+
+ logger = logging.get_logger(__name__)
+
+
+ def get_llama_4_attn_scale(position_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
+     scaling = 1 + beta * torch.log(1 + torch.floor(position_ids / max_position_embeddings))
+     return scaling[:, None, :, None]
+
+ class Mistral4RMSNorm(LlamaRMSNorm):
+     pass
+
+
+ class Mistral4RotaryEmbedding(LlamaRotaryEmbedding):
+     pass
+
+
+ class Mistral4MLP(GemmaMLP):
+     def __init__(self, config, intermediate_size=None):
+         super().__init__(config)
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+
+ class Mistral4TopkRouter(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.n_routed_experts = config.n_routed_experts
+
+         self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
+
+     def forward(self, hidden_states):
+         hidden_states = hidden_states.view(-1, self.config.hidden_size)
+         router_logits = F.linear(hidden_states, self.weight)
+         return router_logits
+
+
+ class Mistral4NaiveMoe(nn.Module):
+     """Collection of expert weights stored as 3D tensors."""
+
+     def __init__(self, config):
+         super().__init__()
+         self.num_experts = config.num_local_experts
+         self.hidden_dim = config.hidden_size
+         self.intermediate_dim = config.intermediate_size
+         self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
+         self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         top_k_index: torch.Tensor,
+         top_k_weights: torch.Tensor,
+     ) -> torch.Tensor:
+         final_hidden_states = torch.zeros_like(hidden_states)
+         with torch.no_grad():
+             expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
+             expert_mask = expert_mask.permute(2, 1, 0)
+             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+         for expert_idx in expert_hit:
+             expert_idx = expert_idx[0]
+             if expert_idx == self.num_experts:
+                 continue
+             top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+             current_state = hidden_states[token_idx]
+             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
+             current_hidden_states = self.act_fn(gate) * up
+             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
+             current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
+             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
+
+         return final_hidden_states
+
+
+ class Mistral4MoE(nn.Module):
+     """
+     A mixed expert module containing shared experts.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.experts = Mistral4NaiveMoe(config)
+         self.gate = Mistral4TopkRouter(config)
+         self.shared_experts = Mistral4MLP(
+             config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+         )
+         self.n_routed_experts = config.n_routed_experts
+         self.n_group = config.n_group
+         self.topk_group = config.topk_group
+         self.norm_topk_prob = config.norm_topk_prob
+         self.routed_scaling_factor = config.routed_scaling_factor
+         self.top_k = config.num_experts_per_tok
+
+     def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         router_logits = router_logits.softmax(-1)
+         group_scores = (
+             router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
+         )
+         group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
+         group_mask = torch.zeros_like(group_scores)
+         group_mask.scatter_(1, group_idx, 1)
+         score_mask = (
+             group_mask.unsqueeze(-1)
+             .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
+             .reshape(-1, self.n_routed_experts)
+         )
+         scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
+         topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
+         topk_weights = router_logits.gather(1, topk_indices)
+         if self.norm_topk_prob:
+             denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+             topk_weights /= denominator
+         topk_weights = topk_weights * self.routed_scaling_factor
+         return topk_indices, topk_weights
+
+     def forward(self, hidden_states):
+         residuals = hidden_states
+         orig_shape = hidden_states.shape
+         router_logits = self.gate(hidden_states)
+         topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
+         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+         hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+         hidden_states = hidden_states + self.shared_experts(residuals)
+         return hidden_states
+
+
+ class Mistral4Attention(DeepseekV3Attention):
+     def __init__(self, config: Mistral4Config, layer_idx: int):
+         nn.Module.__init__(self)
+         self.config = config
+         self.layer_idx = layer_idx
+         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+         self.attention_dropout = config.attention_dropout
+         self.num_heads = config.num_attention_heads
+
+         self.q_lora_rank = config.q_lora_rank
+         self.qk_rope_head_dim = config.qk_rope_head_dim
+         self.kv_lora_rank = config.kv_lora_rank
+         self.v_head_dim = config.v_head_dim
+         self.qk_nope_head_dim = config.qk_nope_head_dim
+         self.qk_head_dim = config.qk_head_dim
+
+         self.is_causal = True
+         if self.q_lora_rank is None:
+             self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
+         else:
+             self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
+             self.q_a_layernorm = Mistral4RMSNorm(config.q_lora_rank)
+             self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
+
+         self.kv_a_proj_with_mqa = nn.Linear(
+             config.hidden_size,
+             self.kv_lora_rank + self.qk_rope_head_dim,
+             bias=config.attention_bias,
+         )
+         self.kv_a_layernorm = Mistral4RMSNorm(self.kv_lora_rank)
+         self.kv_b_proj = nn.Linear(
+             self.kv_lora_rank,
+             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+             bias=False,
+         )
+
+         self.o_proj = nn.Linear(
+             self.num_heads * self.v_head_dim,
+             config.hidden_size,
+             bias=config.attention_bias,
+         )
+
+         self.scaling = self.qk_head_dim ** (-0.5)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor],
+         attention_mask: torch.Tensor | None,
+         position_ids: torch.Tensor,
+         past_key_values: Cache | None = None,
+         **kwargs,
+     ) -> tuple[torch.Tensor, torch.Tensor | None]:
+         batch_size, seq_length = hidden_states.shape[:-1]
+         query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
+         key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
+
+         if self.q_lora_rank is None:
+             q_states = self.q_proj(hidden_states)
+         else:
+             q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+         q_states = q_states.view(query_shape).transpose(1, 2)
+         q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+         k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+
+         k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
+         k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+         k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
+
+         cos, sin = position_embeddings
+         if self.config.rope_interleave:  # support using interleaved weights for efficiency
+             q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
+         else:
+             q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
+         k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
+
+         query_states = torch.cat((q_pass, q_rot), dim=-1)
+         key_states = torch.cat((k_pass, k_rot), dim=-1)
+
+         query_states = query_states * get_llama_4_attn_scale(
+             position_ids,
+             self.config.rope_parameters.get("llama_4_scaling_beta"),
+             self.config.rope_parameters.get("original_max_position_embeddings"),
+         ).to(query_states.dtype)
+
+         if past_key_values is not None:
+             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
+
+         attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+             self.config._attn_implementation, eager_attention_forward
+         )
+
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             **kwargs,
+         )
+
+         attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights
+
+
+ class Mistral4DecoderLayer(DeepseekV3DecoderLayer):
+     def __init__(self, config: Mistral4Config, layer_idx: int):
+         nn.Module.__init__(self)
+         self.hidden_size = config.hidden_size
+
+         self.self_attn = Mistral4Attention(config=config, layer_idx=layer_idx)
+
+         if layer_idx >= config.first_k_dense_replace:
+             self.mlp = Mistral4MoE(config)
+         else:
+             self.mlp = Mistral4MLP(config)
+
+         self.input_layernorm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+
+ class Mistral4PreTrainedModel(PreTrainedModel):
+     config: Mistral4Config
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["Mistral4DecoderLayer"]
+     _skip_keys_device_placement = ["past_key_values"]
+     _supports_flash_attn = True
+     _supports_sdpa = True
+     _supports_flex_attn = True
+
+     _can_compile_fullgraph = True
+     _supports_attention_backend = True
+     _can_record_outputs = {
+         "hidden_states": Mistral4DecoderLayer,
+         "attentions": Mistral4Attention,
+     }
+     _keep_in_fp32_modules_strict = []
+     _keys_to_ignore_on_load_unexpected = []
+
+     @torch.no_grad()
+     def _init_weights(self, module):
+         super()._init_weights(module)
+         if isinstance(module, Mistral4TopkRouter):
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+         elif isinstance(module, Mistral4NaiveMoe):
+             module.gate_up_proj.data.normal_(mean=0.0, std=self.config.initializer_range)
+             module.down_proj.normal_(mean=0.0, std=self.config.initializer_range)
+
+
+ class Mistral4Model(Mistral4PreTrainedModel):
+     def __init__(self, config: Mistral4Config):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+         self.layers = nn.ModuleList(
+             [Mistral4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+         )
+         self.norm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.rotary_emb = Mistral4RotaryEmbedding(config=config)
+         self.gradient_checkpointing = False
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor | None = None,
+         attention_mask: torch.Tensor | None = None,
+         position_ids: torch.LongTensor | None = None,
+         past_key_values: Cache | None = None,
+         inputs_embeds: torch.FloatTensor | None = None,
+         use_cache: bool | None = None,
+         **kwargs,
+     ):
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if inputs_embeds is None:
+             inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+         if use_cache and past_key_values is None:
+             past_key_values = DynamicCache(config=self.config)
+
+         if position_ids is None:
+             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+             position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
+             position_ids = position_ids.unsqueeze(0)
+
+         causal_mask = create_causal_mask(
+             config=self.config,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             position_ids=position_ids,
+         )
+
+         hidden_states = inputs_embeds
+         position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
+         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+             hidden_states = decoder_layer(
+                 hidden_states,
+                 attention_mask=causal_mask,
+                 position_embeddings=position_embeddings,
+                 position_ids=position_ids,
+                 past_key_values=past_key_values,
+                 use_cache=use_cache,
+                 **kwargs,
+             )
+
+         hidden_states = self.norm(hidden_states)
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=past_key_values,
+         )
+
+
+ class Mistral4ForCausalLM(Mistral4PreTrainedModel, GenerationMixin):
+     _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
+     _tp_plan = {"lm_head": "colwise_gather_output"}
+     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = Mistral4Model(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor | None = None,
+         attention_mask: torch.Tensor | None = None,
+         position_ids: torch.LongTensor | None = None,
+         past_key_values: Cache | None = None,
+         inputs_embeds: torch.FloatTensor | None = None,
+         labels: torch.LongTensor | None = None,
+         use_cache: bool | None = None,
+         logits_to_keep: int | torch.Tensor = 0,
+         **kwargs,
+     ) -> CausalLMOutputWithPast:
+         r"""
+         Example:
+
+         ```python
+         >>> from transformers import AutoTokenizer, Mistral4ForCausalLM
+
+         >>> model = Mistral4ForCausalLM.from_pretrained("meta-mistral4/Mistral4-2-7b-hf")
+         >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral4/Mistral4-2-7b-hf")
+
+         >>> prompt = "Hey, are you conscious? Can you talk to me?"
+         >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+         >>> # Generate
+         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+         ```"""
+         outputs: BaseModelOutputWithPast = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             **kwargs,
+         )
+
+         hidden_states = outputs.last_hidden_state
+         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+         logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+
+
+ __all__ = [
+     "Mistral4PreTrainedModel",
+     "Mistral4Model",
+     "Mistral4ForCausalLM",
+ ]
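Two pieces of arithmetic in modeling_mistral4.py are easier to follow on scalar inputs: the position-dependent query scale from `get_llama_4_attn_scale`, and the top-k expert routing in `route_tokens_to_experts`. The sketch below restates both in plain Python for a single position and a single token. It assumes the single-group case (`n_group == 1`, `topk_group == 1`, as in this commit's config.json), where the group mask keeps every expert. The function names `attn_scale` and `route` are hypothetical, and `route` returns experts sorted by score, whereas the tensor code uses `sorted=False`.

```python
import math


def attn_scale(position: int, beta: float, original_max_position_embeddings: int) -> float:
    """Scalar version of get_llama_4_attn_scale: grows by beta * log(1 + k) in the k-th window."""
    window = math.floor(position / original_max_position_embeddings)
    return 1.0 + beta * math.log(1 + window)


def route(logits: list[float], top_k: int, norm_topk_prob: bool = True,
          routed_scaling_factor: float = 1.0) -> tuple[list[int], list[float]]:
    """Single-token, single-group sketch of Mistral4MoE.route_tokens_to_experts."""
    # Softmax over the expert logits (shifted by the max for numerical stability).
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the top_k experts with the highest probability.
    indices = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:top_k]
    weights = [probs[i] for i in indices]
    if norm_topk_prob:
        # Renormalise so the kept weights sum to (almost exactly) one.
        denominator = sum(weights) + 1e-20
        weights = [w / denominator for w in weights]
    return indices, [w * routed_scaling_factor for w in weights]


# Values from this commit's config.json: beta = 0.1, original context = 8192, top-2 of 4 experts.
print(attn_scale(0, 0.1, 8192), attn_scale(8192, 0.1, 8192))
experts, weights = route([0.0, 2.0, 1.0, -1.0], top_k=2)
print(experts, weights)
```

Inside the original context window the scale stays at 1, so the query is only rescaled once positions exceed `original_max_position_embeddings`.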
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "add_prefix_space": null,
+   "backend": "tokenizers",
+   "tokenizer_class": "PreTrainedTokenizer",
+   "bos_token": "<s>",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "</s>",
+   "extra_special_tokens": {},
+   "is_local": false,
+   "legacy": false,
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "</s>",
+   "sp_model_kwargs": {},
+   "spaces_between_special_tokens": false,
+   "unk_token": "<unk>",
+   "use_default_system_prompt": false
+ }
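The `model_max_length` value above looks arbitrary but is exactly `int(1e30)`, which, to the best of my knowledge, is the placeholder `transformers` writes when a tokenizer has no configured maximum length. A caller that needs a usable limit therefore has to fall back to the model's own context size. The sketch below shows one way to do that; `effective_max_length` and `NO_LIMIT` are hypothetical names, and the fallback policy is an assumption rather than library behaviour.

```python
import json

# Only the relevant fields from tokenizer_config.json and config.json in this commit.
tokenizer_config = json.loads('{"model_max_length": 1000000000000000019884624838656}')
model_config = {"max_position_embeddings": 1048576}

# 1e30 is not exactly representable as a float, hence the odd trailing digits.
NO_LIMIT = int(1e30)


def effective_max_length(tokenizer_cfg: dict, model_cfg: dict) -> int:
    """Hypothetical helper: prefer a real tokenizer limit, else use the model's context size."""
    limit = tokenizer_cfg.get("model_max_length", NO_LIMIT)
    if limit >= NO_LIMIT:
        return model_cfg["max_position_embeddings"]
    return min(limit, model_cfg["max_position_embeddings"])


max_length = effective_max_length(tokenizer_config, model_config)
print(max_length)
```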