qherreros committed
Commit bf1c93a · 1 Parent(s): 7fa51ea

add: max_query/doc_length parametrization

Files changed (1)
  1. modeling.py +52 -27
modeling.py CHANGED
@@ -1,10 +1,11 @@
-import numpy as np
 from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
 import torch
 from torch import nn
-from typing import Optional, List, Dict, Tuple
-from transformers.models.qwen3 import modeling_qwen3
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.qwen3 import modeling_qwen3
 
 
 @dataclass
@@ -51,9 +52,12 @@ def format_docs_prompts_func(
     )
 
     if instruction:
-        prompt += f'<instruct>\n{instruction}\n</instruct>\n'
+        prompt += f"<instruct>\n{instruction}\n</instruct>\n"
 
-    doc_prompts = [f'<passage id="{i}">\n{doc}{doc_emb_token}\n</passage>' for i, doc in enumerate(docs)]
+    doc_prompts = [
+        f'<passage id="{i}">\n{doc}{doc_emb_token}\n</passage>'
+        for i, doc in enumerate(docs)
+    ]
     prompt += "\n".join(doc_prompts) + "\n"
     prompt += f"<query>\n{query}{query_emb_token}\n</query>"
 
@@ -76,14 +80,19 @@ class JinaForRanking(modeling_qwen3.Qwen3ForCausalLM):
 
         self.post_init()
 
-        self.special_tokens = {"query_embed_token": "<|rerank_token|>", "doc_embed_token": "<|embed_token|>"}
+        self.special_tokens = {
+            "query_embed_token": "<|rerank_token|>",
+            "doc_embed_token": "<|embed_token|>",
+        }
         self.doc_embed_token_id = 151670
         self.query_embed_token_id = 151671
 
     def forward(self, *args, **kwargs) -> CausalLMOutputWithScores:
         kwargs.pop("output_hidden_states", None)
         kwargs.pop("use_cache", None)
-        assert kwargs.pop("labels", None) is None, "labels should not be passed to forward()"
+        assert kwargs.pop("labels", None) is None, (
+            "labels should not be passed to forward()"
+        )
         input_ids = kwargs.pop("input_ids", None)
 
         outputs = super().forward(
@@ -107,7 +116,9 @@ class JinaForRanking(modeling_qwen3.Qwen3ForCausalLM):
         query_embeds = self.projector(query_embeds)
 
         query_embeds_expanded = query_embeds.expand_as(doc_embeds)
-        scores = torch.nn.functional.cosine_similarity(doc_embeds, query_embeds_expanded, dim=-1).squeeze(-1)
+        scores = torch.nn.functional.cosine_similarity(
+            doc_embeds, query_embeds_expanded, dim=-1
+        ).squeeze(-1)
 
         return CausalLMOutputWithScores(
             loss=None,
@@ -124,13 +135,17 @@ class JinaForRanking(modeling_qwen3.Qwen3ForCausalLM):
         if not hasattr(self, "_tokenizer"):
             from transformers import AutoTokenizer
 
-            self._tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                self.name_or_path, trust_remote_code=True
+            )
 
             if self._tokenizer.pad_token is None:
                 self._tokenizer.pad_token = self._tokenizer.unk_token
-                self._tokenizer.pad_token_id = self._tokenizer.convert_tokens_to_ids(self._tokenizer.pad_token)
+                self._tokenizer.pad_token_id = self._tokenizer.convert_tokens_to_ids(
+                    self._tokenizer.pad_token
+                )
 
-            self._tokenizer.padding_side = 'left'
+            self._tokenizer.padding_side = "left"
 
     def _truncate_texts(
         self,
@@ -144,17 +159,21 @@ class JinaForRanking(modeling_qwen3.Qwen3ForCausalLM):
         docs = []
         doc_lengths = []
         for doc in documents:
-            doc_tokens = self._tokenizer(doc, truncation=True, max_length=max_doc_length)
-            if len(doc_tokens['input_ids']) >= max_doc_length:
-                doc = self._tokenizer.decode(doc_tokens['input_ids'])
-            doc_lengths.append(len(doc_tokens['input_ids']))
+            doc_tokens = self._tokenizer(
+                doc, truncation=True, max_length=max_doc_length
+            )
+            if len(doc_tokens["input_ids"]) >= max_doc_length:
+                doc = self._tokenizer.decode(doc_tokens["input_ids"])
+            doc_lengths.append(len(doc_tokens["input_ids"]))
             docs.append(doc)
 
-        query_tokens = self._tokenizer(query, truncation=True, max_length=max_query_length)
-        if len(query_tokens['input_ids']) >= max_query_length:
-            query = self._tokenizer.decode(query_tokens['input_ids'])
+        query_tokens = self._tokenizer(
+            query, truncation=True, max_length=max_query_length
+        )
+        if len(query_tokens["input_ids"]) >= max_query_length:
+            query = self._tokenizer.decode(query_tokens["input_ids"])
 
-        query_length = len(query_tokens['input_ids'])
+        query_length = len(query_tokens["input_ids"])
 
         return query, docs, doc_lengths, query_length
 
@@ -200,6 +219,8 @@ class JinaForRanking(modeling_qwen3.Qwen3ForCausalLM):
         documents: List[str],
         top_n: Optional[int] = None,
         return_embeddings: bool = False,
+        max_doc_length: int = 2048,
+        max_query_length: int = 512,
     ) -> List[dict]:
         """
         Rerank documents by relevance to a query.
@@ -221,14 +242,14 @@ class JinaForRanking(modeling_qwen3.Qwen3ForCausalLM):
 
         # Derived from model configuration
        max_length = self._tokenizer.model_max_length
-        max_query_length = 512
-        max_doc_length = 2048
 
         # Derive block_size from max_length to fit documents efficiently
         # Heuristic: allow ~125 docs per batch for typical doc sizes
         block_size = 125
 
-        query, docs, doc_lengths, query_length = self._truncate_texts(query, documents, max_query_length, max_doc_length)
+        query, docs, doc_lengths, query_length = self._truncate_texts(
+            query, documents, max_query_length, max_doc_length
+        )
 
         length_capacity = max_length - 2 * query_length
 
@@ -242,7 +263,9 @@ class JinaForRanking(modeling_qwen3.Qwen3ForCausalLM):
                 length_capacity -= length
 
             if len(block_docs) >= block_size or length_capacity <= max_doc_length:
-                outputs = self._compute_single_batch(query, block_docs, instruction=None)
+                outputs = self._compute_single_batch(
+                    query, block_docs, instruction=None
+                )
 
                 doc_embeddings.extend(outputs.doc_embeds[0].cpu().float().numpy())
                 query_embeddings.append(outputs.query_embeds[0].cpu().float().numpy())
@@ -277,10 +300,12 @@ class JinaForRanking(modeling_qwen3.Qwen3ForCausalLM):
 
         return [
             {
-                'document': documents[scores_argsort[i]],
-                'relevance_score': scores[0][scores_argsort[i]],
-                'index': scores_argsort[i],
-                'embedding': doc_embeddings[scores_argsort[i]] if return_embeddings else None,
+                "document": documents[scores_argsort[i]],
+                "relevance_score": scores[0][scores_argsort[i]],
+                "index": scores_argsort[i],
+                "embedding": doc_embeddings[scores_argsort[i]]
+                if return_embeddings
+                else None,
             }
             for i in range(top_n)
         ]
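Note on the change: the truncation limits that were hardcoded inside the reranking method (max_query_length = 512, max_doc_length = 2048) become keyword arguments with the same defaults, so existing callers keep the old behavior while new callers can tune them. A minimal usage sketch, assuming the method shown in the last hunks is named rerank() and the checkpoint is loaded via remote code (both assumptions; only the keyword arguments and result keys come from this diff):

    from transformers import AutoModelForCausalLM

    # Hedged sketch: placeholder checkpoint id; the method name `rerank`
    # is assumed, since the def line itself is outside these hunks.
    model = AutoModelForCausalLM.from_pretrained(
        "org/reranker-checkpoint",  # placeholder, not from the diff
        trust_remote_code=True,     # JinaForRanking ships as remote code
    )

    results = model.rerank(
        query="what does cosine similarity measure?",
        documents=[
            "Cosine similarity measures the angle between two vectors.",
            "The Eiffel Tower is in Paris.",
        ],
        top_n=1,
        max_doc_length=1024,   # hardcoded to 2048 before this commit
        max_query_length=256,  # hardcoded to 512 before this commit
    )
    print(results[0]["index"], results[0]["relevance_score"])

Keeping 2048/512 as defaults makes the parametrization backward compatible: only callers that pass the new keywords see different truncation.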
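The lazy tokenizer setup in the diff also encodes a detail worth noting: when the checkpoint defines no pad token, the unk token is reused, and padding goes on the left, presumably because the embedding tokens sit at the end of each prompt and left padding keeps those final positions aligned across a batch. A standalone sketch of the same pattern (gpt2 is just a placeholder checkpoint that lacks a pad token):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder without a pad token
    if tok.pad_token is None:
        tok.pad_token = tok.unk_token
        tok.pad_token_id = tok.convert_tokens_to_ids(tok.pad_token)
    # Left padding keeps the last tokens (where the embed/rerank tokens
    # live) aligned across a batch for a decoder-only model.
    tok.padding_side = "left"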
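Finally, the block-building logic in rerank() packs documents greedily: each document's token length is subtracted from a budget of max_length - 2 * query_length, and a block is flushed once it holds block_size documents or the remaining budget could not fit another worst-case document. A standalone sketch of that rule; the budget reset after a flush and the final leftover flush are assumptions not shown in the hunks, while the flush condition is taken verbatim:

    # Illustrative numbers; only the flush condition is from the diff.
    max_length = 8192
    max_doc_length = 2048
    block_size = 125
    query_length = 32

    budget = max_length - 2 * query_length
    doc_lengths = [2048, 2048, 1500, 2000, 100]

    block, batches = [], []
    for i, length in enumerate(doc_lengths):
        block.append(i)
        budget -= length
        if len(block) >= block_size or budget <= max_doc_length:
            batches.append(block)                               # flush the block
            block, budget = [], max_length - 2 * query_length   # assumed reset
    if block:
        batches.append(block)  # assumed: leftovers form a final batch
    print(batches)  # [[0, 1, 2, 3], [4]]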