Ahmed-El-Sharkawy committed on
Commit
7d4ed22
·
verified ·
1 Parent(s): fab5200

Upload Reranker and Embedding Model

Browse files
Reranker/RerankerModel.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import CrossEncoder
2
+
3
+
4
class Reranker:
    """Cross-encoder reranker that reorders retrieval results by relevance to a query."""

    def __init__(self, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"):
        # CrossEncoder scores each (query, passage) pair jointly; higher = more relevant.
        self.model = CrossEncoder(model_name)

    def rerank_results(self, query: str, results: list[dict], top_n: int = 5) -> list[dict]:
        """Return the top_n results most relevant to `query`, best first.

        Each result dict is expected to carry a non-empty "text" key;
        entries without usable text are skipped.
        """
        # Filter FIRST so scores stay aligned with the results actually scored.
        # (The previous code zipped scores against the unfiltered list, which
        # mispaired scores with results whenever an entry lacked "text".)
        candidates = [r for r in results if r.get("text")]
        if not candidates:
            # Nothing scoreable; also avoids calling predict() on an empty batch.
            return []
        pairs = [(query, r["text"]) for r in candidates]
        scores = self.model.predict(pairs)
        ranked = sorted(zip(scores, candidates), key=lambda sr: sr[0], reverse=True)
        return [r for _, r in ranked[:top_n]]
Reranker/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .RerankerModel import Reranker
embedder/EmbeddingModels.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ import numpy as np
3
+ import requests
4
+
5
+ class EmbeddingModel:
6
+ def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L12-v2"):
7
+
8
+ self.model = model_name
9
+
10
+ def get_embedder(self):
11
+ return SentenceTransformer(self.model)
12
+
13
+ # Remote (insert & search)
14
+ def _embed_texts(self, texts: list[str]) -> np.ndarray:
15
+ model = self.get_embedder()
16
+ embs = model.encode(
17
+ texts, batch_size=64, show_progress_bar=False,
18
+ convert_to_numpy=True, normalize_embeddings=True
19
+ )
20
+ # Ensure float32
21
+ return embs.astype("float32")
22
+
23
+ def search_remote(self, query: str, k: int = 5, HOST: str="") -> list[dict]:
24
+ """
25
+ Embeds the query and searches the remote vector store.
26
+ Returns a list of result dicts. We expect each item to include at least:
27
+ - score (float)
28
+ - payload (dict) with 'text' and optional metadata
29
+ """
30
+ q = self._embed_texts([query])[0].tolist()
31
+ try:
32
+ resp = requests.post(
33
+ f"{HOST}/search",
34
+ json={"vector": q, "k": k},
35
+ headers={"Content-Type": "application/json"},
36
+ timeout=30
37
+ )
38
+ resp.raise_for_status()
39
+ data = resp.json()
40
+ # print("Raw remote search response:", data)
41
+ # print(f"Row Data: {data}")
42
+ # Each result ideally has {'scores': ..., 'payloads': {...}}.
43
+ payload = data.get("payloads")
44
+ scores = data.get("scores")
45
+ dict = {"scores": scores, "payloads": payload}
46
+ return dict
47
+
48
+ except Exception as e:
49
+ print(f"Remote search failed: {e}")
50
+ return []
51
+
52
+ def retrieve_top_k_remote_texts(self, query: str, k: int = 5, HOST: str="") -> list[str]:
53
+ """
54
+ Uses search_remote() and extracts 'text' from payloads.
55
+ """
56
+ results = self.search_remote(query, k=k, HOST=HOST)
57
+ # print(f"Remote search returned {len(results)} results.")
58
+ # print("res-1:", results)
59
+ texts = []
60
+ sources = []
61
+ # print(results)
62
+ for r in results.get("payloads"):
63
+ t = r.get("text")
64
+ src = r.get("source")
65
+ if isinstance(t, str) and len(t.strip()) > 0:
66
+ texts.append(t.strip())
67
+ if isinstance(src, str) and src:
68
+ sources.append(src)
69
+ # print(f"Retrieved {len(texts)} remote texts for query.")
70
+ # print("Sources:", {len(sources)})
71
+ results = []
72
+ for i in range(len(sources)):
73
+ results.append({"text": texts[i], "source": sources[i]})
74
+ # print("Results-2:", results)
75
+ return results
76
+
embedder/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .EmbeddingModels import EmbeddingModel