File size: 2,859 Bytes
7d4ed22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from sentence_transformers import SentenceTransformer
import numpy as np
import requests

class EmbeddingModel:
    """Embed text with a SentenceTransformer and query a remote vector store.

    The SentenceTransformer instance is created lazily on first use and then
    reused, so the model is not re-loaded for every query.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L12-v2"):
        # NOTE: ``model`` stores the model *name* (a string), not the loaded
        # model; the attribute name is kept for existing callers.
        self.model = model_name
        # (model_name, SentenceTransformer) for the most recently loaded model.
        self._embedder_cache = None

    def get_embedder(self):
        """Return the SentenceTransformer for ``self.model``.

        The instance is cached and only rebuilt when ``self.model`` has been
        changed since the last call.
        """
        cache = getattr(self, "_embedder_cache", None)
        if cache is None or cache[0] != self.model:
            cache = (self.model, SentenceTransformer(self.model))
            self._embedder_cache = cache
        return cache[1]

    # Remote (insert & search)
    def _embed_texts(self, texts: list[str]) -> np.ndarray:
        """Encode ``texts`` and return the embeddings as a float32 array.

        Encoding is requested with ``normalize_embeddings=True``.
        """
        embedder = self.get_embedder()
        embs = embedder.encode(
            texts, batch_size=64, show_progress_bar=False,
            convert_to_numpy=True, normalize_embeddings=True
        )
        # Ensure float32
        return embs.astype("float32")

    def search_remote(self, query: str, k: int = 5, HOST: str = "") -> "dict | list":
        """Embed ``query`` and POST it to ``{HOST}/search``.

        Args:
            query: Text to embed and search for.
            k: Number of results requested from the remote store.
            HOST: Base URL of the remote store; a trailing slash is ignored.

        Returns:
            On success, ``{"scores": ..., "payloads": ...}`` taken from the
            JSON response (either value is ``None`` when the key is absent).
            On any request/decoding failure the error is printed and an empty
            list is returned, so callers must handle both shapes.
        """
        q = self._embed_texts([query])[0].tolist()
        try:
            # Avoid ".../" + "/search" producing a double slash.
            url = f"{HOST.rstrip('/')}/search"
            resp = requests.post(
                url,
                json={"vector": q, "k": k},
                headers={"Content-Type": "application/json"},
                timeout=30
            )
            resp.raise_for_status()
            data = resp.json()
            return {"scores": data.get("scores"), "payloads": data.get("payloads")}
        except Exception as e:
            # Deliberately best-effort: a failed remote search must not crash
            # the caller, it just yields no results.
            print(f"Remote search failed: {e}")
            return []

    def retrieve_top_k_remote_texts(self, query: str, k: int = 5, HOST: str = "") -> list[dict]:
        """Run ``search_remote`` and extract text/source pairs from the payloads.

        Each payload is handled on its own, so a text is always paired with
        the source from the *same* payload. Payloads that are not dicts, or
        that lack a non-blank string ``text`` or a non-empty string
        ``source``, are skipped.

        Returns:
            A list of ``{"text": <stripped text>, "source": <source>}`` dicts;
            empty when the search failed or returned no usable payloads.
        """
        results = self.search_remote(query, k=k, HOST=HOST)
        # search_remote returns [] on failure and a dict on success.
        payloads = results.get("payloads") if isinstance(results, dict) else None
        if not isinstance(payloads, (list, tuple)):
            return []

        paired = []
        for item in payloads:
            if not isinstance(item, dict):
                continue
            text = item.get("text")
            source = item.get("source")
            if not isinstance(text, str) or not isinstance(source, str):
                continue
            text = text.strip()
            if text and source:
                paired.append({"text": text, "source": source})
        return paired