Spaces:
Running
Running
Merge pull request #10 from The-Obstacle-Is-The-Way/feat/phase6-embeddings
Browse files- Dockerfile +15 -4
- docs/architecture/overview.md +1 -1
- docs/implementation/08_phase_report.md +4 -4
- docs/implementation/roadmap.md +1 -1
- pyproject.toml +4 -0
- src/agents/search_agent.py +67 -14
- src/app.py +15 -8
- src/orchestrator.py +2 -2
- src/orchestrator_magentic.py +44 -4
- src/services/__init__.py +1 -0
- src/services/embeddings.py +166 -0
- src/tools/pubmed.py +1 -1
- src/utils/config.py +2 -2
- src/utils/models.py +1 -0
- tests/unit/agents/test_search_agent.py +44 -1
- tests/unit/services/test_embeddings.py +146 -0
- uv.lock +0 -0
Dockerfile
CHANGED
|
@@ -4,9 +4,10 @@ FROM python:3.11-slim
|
|
| 4 |
# Set working directory
|
| 5 |
WORKDIR /app
|
| 6 |
|
| 7 |
-
# Install system dependencies
|
| 8 |
RUN apt-get update && apt-get install -y \
|
| 9 |
git \
|
|
|
|
| 10 |
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
|
| 12 |
# Install uv
|
|
@@ -18,12 +19,22 @@ COPY uv.lock .
|
|
| 18 |
COPY src/ src/
|
| 19 |
COPY README.md .
|
| 20 |
|
| 21 |
-
# Install dependencies
|
| 22 |
-
RUN uv sync --frozen --no-dev
|
| 23 |
|
| 24 |
-
# Create non-root user
|
| 25 |
RUN useradd --create-home --shell /bin/bash appuser
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
USER appuser
|
|
|
|
| 27 |
|
| 28 |
# Expose port
|
| 29 |
EXPOSE 7860
|
|
|
|
| 4 |
# Set working directory
|
| 5 |
WORKDIR /app
|
| 6 |
|
| 7 |
+
# Install system dependencies (curl needed for HEALTHCHECK)
|
| 8 |
RUN apt-get update && apt-get install -y \
|
| 9 |
git \
|
| 10 |
+
curl \
|
| 11 |
&& rm -rf /var/lib/apt/lists/*
|
| 12 |
|
| 13 |
# Install uv
|
|
|
|
| 19 |
COPY src/ src/
|
| 20 |
COPY README.md .
|
| 21 |
|
| 22 |
+
# Install runtime dependencies only (no dev/test tools)
|
| 23 |
+
RUN uv sync --frozen --no-dev --extra embeddings --extra magentic
|
| 24 |
|
| 25 |
+
# Create non-root user BEFORE downloading models
|
| 26 |
RUN useradd --create-home --shell /bin/bash appuser
|
| 27 |
+
|
| 28 |
+
# Set cache directory for HuggingFace models (must be writable by appuser)
|
| 29 |
+
ENV HF_HOME=/app/.cache
|
| 30 |
+
ENV TRANSFORMERS_CACHE=/app/.cache
|
| 31 |
+
|
| 32 |
+
# Create cache dir with correct ownership
|
| 33 |
+
RUN mkdir -p /app/.cache && chown -R appuser:appuser /app/.cache
|
| 34 |
+
|
| 35 |
+
# Pre-download the embedding model during build (as appuser to set correct ownership)
|
| 36 |
USER appuser
|
| 37 |
+
RUN uv run python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
|
| 38 |
|
| 39 |
# Expose port
|
| 40 |
EXPOSE 7860
|
docs/architecture/overview.md
CHANGED
|
@@ -65,7 +65,7 @@ Using existing approved drugs to treat NEW diseases they weren't originally desi
|
|
| 65 |
|
| 66 |
### High-Level Design (Phases 1-8)
|
| 67 |
|
| 68 |
-
```
|
| 69 |
User Query
|
| 70 |
↓
|
| 71 |
Gradio UI (Phase 4)
|
|
|
|
| 65 |
|
| 66 |
### High-Level Design (Phases 1-8)
|
| 67 |
|
| 68 |
+
```text
|
| 69 |
User Query
|
| 70 |
↓
|
| 71 |
Gradio UI (Phase 4)
|
docs/implementation/08_phase_report.md
CHANGED
|
@@ -11,7 +11,7 @@
|
|
| 11 |
Current limitation: **Synthesis is basic markdown, not a scientific report.**
|
| 12 |
|
| 13 |
Current output:
|
| 14 |
-
```
|
| 15 |
## Drug Repurposing Analysis
|
| 16 |
### Drug Candidates
|
| 17 |
- Metformin
|
|
@@ -22,7 +22,7 @@ Current output:
|
|
| 22 |
```
|
| 23 |
|
| 24 |
With Report Agent:
|
| 25 |
-
```
|
| 26 |
## Executive Summary
|
| 27 |
One-paragraph summary for busy readers...
|
| 28 |
|
|
@@ -59,7 +59,7 @@ Properly formatted citations...
|
|
| 59 |
## 2. Architecture
|
| 60 |
|
| 61 |
### Phase 8 Addition
|
| 62 |
-
```
|
| 63 |
Evidence + Hypotheses + Assessment
|
| 64 |
↓
|
| 65 |
Report Agent
|
|
@@ -68,7 +68,7 @@ Evidence + Hypotheses + Assessment
|
|
| 68 |
```
|
| 69 |
|
| 70 |
### Report Generation Flow
|
| 71 |
-
```
|
| 72 |
1. JudgeAgent says "synthesize"
|
| 73 |
2. Magentic Manager selects ReportAgent
|
| 74 |
3. ReportAgent gathers:
|
|
|
|
| 11 |
Current limitation: **Synthesis is basic markdown, not a scientific report.**
|
| 12 |
|
| 13 |
Current output:
|
| 14 |
+
```markdown
|
| 15 |
## Drug Repurposing Analysis
|
| 16 |
### Drug Candidates
|
| 17 |
- Metformin
|
|
|
|
| 22 |
```
|
| 23 |
|
| 24 |
With Report Agent:
|
| 25 |
+
```markdown
|
| 26 |
## Executive Summary
|
| 27 |
One-paragraph summary for busy readers...
|
| 28 |
|
|
|
|
| 59 |
## 2. Architecture
|
| 60 |
|
| 61 |
### Phase 8 Addition
|
| 62 |
+
```text
|
| 63 |
Evidence + Hypotheses + Assessment
|
| 64 |
↓
|
| 65 |
Report Agent
|
|
|
|
| 68 |
```
|
| 69 |
|
| 70 |
### Report Generation Flow
|
| 71 |
+
```text
|
| 72 |
1. JudgeAgent says "synthesize"
|
| 73 |
2. Magentic Manager selects ReportAgent
|
| 74 |
3. ReportAgent gathers:
|
docs/implementation/roadmap.md
CHANGED
|
@@ -165,7 +165,7 @@ tests/
|
|
| 165 |
|
| 166 |
## Complete Architecture (Phases 1-8)
|
| 167 |
|
| 168 |
-
```
|
| 169 |
User Query
|
| 170 |
↓
|
| 171 |
Gradio UI (Phase 4)
|
|
|
|
| 165 |
|
| 166 |
## Complete Architecture (Phases 1-8)
|
| 167 |
|
| 168 |
+
```text
|
| 169 |
User Query
|
| 170 |
↓
|
| 171 |
Gradio UI (Phase 4)
|
pyproject.toml
CHANGED
|
@@ -49,6 +49,10 @@ dev = [
|
|
| 49 |
magentic = [
|
| 50 |
"agent-framework-core",
|
| 51 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
[build-system]
|
| 54 |
requires = ["hatchling"]
|
|
|
|
| 49 |
magentic = [
|
| 50 |
"agent-framework-core",
|
| 51 |
]
|
| 52 |
+
embeddings = [
|
| 53 |
+
"chromadb>=0.4.0",
|
| 54 |
+
"sentence-transformers>=2.2.0",
|
| 55 |
+
]
|
| 56 |
|
| 57 |
[build-system]
|
| 58 |
requires = ["hatchling"]
|
src/agents/search_agent.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from collections.abc import AsyncIterable
|
| 2 |
-
from typing import Any
|
| 3 |
|
| 4 |
from agent_framework import (
|
| 5 |
AgentRunResponse,
|
|
@@ -11,7 +11,10 @@ from agent_framework import (
|
|
| 11 |
)
|
| 12 |
|
| 13 |
from src.orchestrator import SearchHandlerProtocol
|
| 14 |
-
from src.utils.models import Evidence, SearchResult
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
class SearchAgent(BaseAgent): # type: ignore[misc]
|
|
@@ -21,6 +24,7 @@ class SearchAgent(BaseAgent): # type: ignore[misc]
|
|
| 21 |
self,
|
| 22 |
search_handler: SearchHandlerProtocol,
|
| 23 |
evidence_store: dict[str, list[Evidence]],
|
|
|
|
| 24 |
) -> None:
|
| 25 |
super().__init__(
|
| 26 |
name="SearchAgent",
|
|
@@ -28,6 +32,7 @@ class SearchAgent(BaseAgent): # type: ignore[misc]
|
|
| 28 |
)
|
| 29 |
self._handler = search_handler
|
| 30 |
self._evidence_store = evidence_store
|
|
|
|
| 31 |
|
| 32 |
async def run(
|
| 33 |
self,
|
|
@@ -61,31 +66,79 @@ class SearchAgent(BaseAgent): # type: ignore[misc]
|
|
| 61 |
# Execute search
|
| 62 |
result: SearchResult = await self._handler.execute(query, max_results_per_tool=10)
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
# Update shared evidence store
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
evidence_text = "\n".join(
|
| 75 |
[
|
| 76 |
f"- [{e.citation.title}]({e.citation.url}): {e.content[:200]}..."
|
| 77 |
-
for e in
|
| 78 |
]
|
| 79 |
)
|
| 80 |
|
| 81 |
response_text = (
|
| 82 |
-
f"Found {result.total_found} sources ({
|
|
|
|
| 83 |
)
|
| 84 |
|
| 85 |
return AgentRunResponse(
|
| 86 |
messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)],
|
| 87 |
response_id=f"search-{result.total_found}",
|
| 88 |
-
additional_properties={"evidence": [e.model_dump() for e in
|
| 89 |
)
|
| 90 |
|
| 91 |
async def run_stream(
|
|
|
|
| 1 |
from collections.abc import AsyncIterable
|
| 2 |
+
from typing import TYPE_CHECKING, Any
|
| 3 |
|
| 4 |
from agent_framework import (
|
| 5 |
AgentRunResponse,
|
|
|
|
| 11 |
)
|
| 12 |
|
| 13 |
from src.orchestrator import SearchHandlerProtocol
|
| 14 |
+
from src.utils.models import Citation, Evidence, SearchResult
|
| 15 |
+
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
from src.services.embeddings import EmbeddingService
|
| 18 |
|
| 19 |
|
| 20 |
class SearchAgent(BaseAgent): # type: ignore[misc]
|
|
|
|
| 24 |
self,
|
| 25 |
search_handler: SearchHandlerProtocol,
|
| 26 |
evidence_store: dict[str, list[Evidence]],
|
| 27 |
+
embedding_service: "EmbeddingService | None" = None,
|
| 28 |
) -> None:
|
| 29 |
super().__init__(
|
| 30 |
name="SearchAgent",
|
|
|
|
| 32 |
)
|
| 33 |
self._handler = search_handler
|
| 34 |
self._evidence_store = evidence_store
|
| 35 |
+
self._embeddings = embedding_service
|
| 36 |
|
| 37 |
async def run(
|
| 38 |
self,
|
|
|
|
| 66 |
# Execute search
|
| 67 |
result: SearchResult = await self._handler.execute(query, max_results_per_tool=10)
|
| 68 |
|
| 69 |
+
# Track what to show in response (initialized to search results as default)
|
| 70 |
+
evidence_to_show: list[Evidence] = result.evidence
|
| 71 |
+
total_new = 0
|
| 72 |
+
|
| 73 |
# Update shared evidence store
|
| 74 |
+
if self._embeddings:
|
| 75 |
+
# Deduplicate by semantic similarity (async-safe)
|
| 76 |
+
unique_evidence = await self._embeddings.deduplicate(result.evidence)
|
| 77 |
+
|
| 78 |
+
# Also search for semantically related evidence (async-safe)
|
| 79 |
+
related = await self._embeddings.search_similar(query, n_results=5)
|
| 80 |
+
|
| 81 |
+
# Merge related evidence not already in results
|
| 82 |
+
existing_urls = {e.citation.url for e in unique_evidence}
|
| 83 |
+
|
| 84 |
+
# Reconstruct Evidence objects from stored vector DB data
|
| 85 |
+
related_evidence: list[Evidence] = []
|
| 86 |
+
for item in related:
|
| 87 |
+
if item["id"] not in existing_urls:
|
| 88 |
+
meta = item.get("metadata", {})
|
| 89 |
+
# Parse authors (stored as comma-separated string)
|
| 90 |
+
authors_str = meta.get("authors", "")
|
| 91 |
+
authors = [a.strip() for a in authors_str.split(",") if a.strip()]
|
| 92 |
+
|
| 93 |
+
ev = Evidence(
|
| 94 |
+
content=item["content"],
|
| 95 |
+
citation=Citation(
|
| 96 |
+
title=meta.get("title", "Related Evidence"),
|
| 97 |
+
url=item["id"],
|
| 98 |
+
source=meta.get("source", "vector_db"),
|
| 99 |
+
date=meta.get("date", "n.d."),
|
| 100 |
+
authors=authors,
|
| 101 |
+
),
|
| 102 |
+
# Convert distance to relevance (lower distance = higher relevance)
|
| 103 |
+
relevance=max(0.0, 1.0 - item.get("distance", 0.5)),
|
| 104 |
+
)
|
| 105 |
+
related_evidence.append(ev)
|
| 106 |
+
|
| 107 |
+
# Combine unique from search + related from vector DB
|
| 108 |
+
final_new_evidence = unique_evidence + related_evidence
|
| 109 |
+
|
| 110 |
+
# Add to global store (deduping against global store)
|
| 111 |
+
global_urls = {e.citation.url for e in self._evidence_store["current"]}
|
| 112 |
+
really_new = [e for e in final_new_evidence if e.citation.url not in global_urls]
|
| 113 |
+
self._evidence_store["current"].extend(really_new)
|
| 114 |
+
|
| 115 |
+
total_new = len(really_new)
|
| 116 |
+
evidence_to_show = unique_evidence + related_evidence
|
| 117 |
+
|
| 118 |
+
else:
|
| 119 |
+
# Fallback to URL-based deduplication (no embeddings)
|
| 120 |
+
existing_urls = {e.citation.url for e in self._evidence_store["current"]}
|
| 121 |
+
new_unique = [e for e in result.evidence if e.citation.url not in existing_urls]
|
| 122 |
+
self._evidence_store["current"].extend(new_unique)
|
| 123 |
+
total_new = len(new_unique)
|
| 124 |
+
evidence_to_show = result.evidence
|
| 125 |
+
|
| 126 |
evidence_text = "\n".join(
|
| 127 |
[
|
| 128 |
f"- [{e.citation.title}]({e.citation.url}): {e.content[:200]}..."
|
| 129 |
+
for e in evidence_to_show[:5]
|
| 130 |
]
|
| 131 |
)
|
| 132 |
|
| 133 |
response_text = (
|
| 134 |
+
f"Found {result.total_found} sources ({total_new} new added to context):\n\n"
|
| 135 |
+
f"{evidence_text}"
|
| 136 |
)
|
| 137 |
|
| 138 |
return AgentRunResponse(
|
| 139 |
messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)],
|
| 140 |
response_id=f"search-{result.total_found}",
|
| 141 |
+
additional_properties={"evidence": [e.model_dump() for e in evidence_to_show]},
|
| 142 |
)
|
| 143 |
|
| 144 |
async def run_stream(
|
src/app.py
CHANGED
|
@@ -72,23 +72,30 @@ async def research_agent(
|
|
| 72 |
yield "Please enter a research question."
|
| 73 |
return
|
| 74 |
|
| 75 |
-
#
|
| 76 |
-
|
|
|
|
| 77 |
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
if mode == "magentic" and use_mock:
|
| 80 |
yield (
|
| 81 |
-
"⚠️ **Warning**: Magentic mode requires
|
| 82 |
"Falling back to Mock Simple mode."
|
| 83 |
)
|
| 84 |
mode = "simple"
|
| 85 |
|
| 86 |
-
orchestrator = configure_orchestrator(use_mock=use_mock, mode=mode)
|
| 87 |
-
|
| 88 |
# Run the agent and stream events
|
| 89 |
-
response_parts = []
|
| 90 |
|
| 91 |
try:
|
|
|
|
| 92 |
async for event in orchestrator.run(message):
|
| 93 |
# Format event as markdown
|
| 94 |
event_md = event.to_markdown()
|
|
@@ -144,7 +151,7 @@ def create_demo() -> Any:
|
|
| 144 |
choices=["simple", "magentic"],
|
| 145 |
value="simple",
|
| 146 |
label="Orchestrator Mode",
|
| 147 |
-
info="Simple: Linear
|
| 148 |
)
|
| 149 |
],
|
| 150 |
)
|
|
|
|
| 72 |
yield "Please enter a research question."
|
| 73 |
return
|
| 74 |
|
| 75 |
+
# Decide whether to use real LLMs or mock based on mode and available keys
|
| 76 |
+
has_openai = bool(os.getenv("OPENAI_API_KEY"))
|
| 77 |
+
has_anthropic = bool(os.getenv("ANTHROPIC_API_KEY"))
|
| 78 |
|
| 79 |
+
if mode == "magentic":
|
| 80 |
+
# Magentic currently supports OpenAI only
|
| 81 |
+
use_mock = not has_openai
|
| 82 |
+
else:
|
| 83 |
+
# Simple mode can work with either provider
|
| 84 |
+
use_mock = not (has_openai or has_anthropic)
|
| 85 |
+
|
| 86 |
+
# If magentic mode requested but no OpenAI key, fallback/warn
|
| 87 |
if mode == "magentic" and use_mock:
|
| 88 |
yield (
|
| 89 |
+
"⚠️ **Warning**: Magentic mode requires OpenAI API key. "
|
| 90 |
"Falling back to Mock Simple mode."
|
| 91 |
)
|
| 92 |
mode = "simple"
|
| 93 |
|
|
|
|
|
|
|
| 94 |
# Run the agent and stream events
|
| 95 |
+
response_parts: list[str] = []
|
| 96 |
|
| 97 |
try:
|
| 98 |
+
orchestrator = configure_orchestrator(use_mock=use_mock, mode=mode)
|
| 99 |
async for event in orchestrator.run(message):
|
| 100 |
# Format event as markdown
|
| 101 |
event_md = event.to_markdown()
|
|
|
|
| 151 |
choices=["simple", "magentic"],
|
| 152 |
value="simple",
|
| 153 |
label="Orchestrator Mode",
|
| 154 |
+
info="Simple: Linear (OpenAI/Anthropic) | Magentic: Multi-Agent (OpenAI)",
|
| 155 |
)
|
| 156 |
],
|
| 157 |
)
|
src/orchestrator.py
CHANGED
|
@@ -263,7 +263,7 @@ class Orchestrator:
|
|
| 263 |
|
| 264 |
citations = "\n".join(
|
| 265 |
[
|
| 266 |
-
f"{i+1}. [{e.citation.title}]({e.citation.url}) "
|
| 267 |
f"({e.citation.source.upper()}, {e.citation.date})"
|
| 268 |
for i, e in enumerate(evidence[:10]) # Limit to 10 citations
|
| 269 |
]
|
|
@@ -312,7 +312,7 @@ class Orchestrator:
|
|
| 312 |
"""
|
| 313 |
citations = "\n".join(
|
| 314 |
[
|
| 315 |
-
f"{i+1}. [{e.citation.title}]({e.citation.url}) ({e.citation.source.upper()})"
|
| 316 |
for i, e in enumerate(evidence[:10])
|
| 317 |
]
|
| 318 |
)
|
|
|
|
| 263 |
|
| 264 |
citations = "\n".join(
|
| 265 |
[
|
| 266 |
+
f"{i + 1}. [{e.citation.title}]({e.citation.url}) "
|
| 267 |
f"({e.citation.source.upper()}, {e.citation.date})"
|
| 268 |
for i, e in enumerate(evidence[:10]) # Limit to 10 citations
|
| 269 |
]
|
|
|
|
| 312 |
"""
|
| 313 |
citations = "\n".join(
|
| 314 |
[
|
| 315 |
+
f"{i + 1}. [{e.citation.title}]({e.citation.url}) ({e.citation.source.upper()})"
|
| 316 |
for i, e in enumerate(evidence[:10])
|
| 317 |
]
|
| 318 |
)
|
src/orchestrator_magentic.py
CHANGED
|
@@ -1,4 +1,9 @@
|
|
| 1 |
-
"""Magentic-based orchestrator for DeepCritical.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from collections.abc import AsyncGenerator
|
| 4 |
|
|
@@ -17,6 +22,7 @@ from src.agents.judge_agent import JudgeAgent
|
|
| 17 |
from src.agents.search_agent import SearchAgent
|
| 18 |
from src.orchestrator import JudgeHandlerProtocol, SearchHandlerProtocol
|
| 19 |
from src.utils.config import settings
|
|
|
|
| 20 |
from src.utils.models import AgentEvent, Evidence
|
| 21 |
|
| 22 |
logger = structlog.get_logger()
|
|
@@ -27,6 +33,11 @@ class MagenticOrchestrator:
|
|
| 27 |
Magentic-based orchestrator - same API as Orchestrator.
|
| 28 |
|
| 29 |
Uses Microsoft Agent Framework's MagenticBuilder for multi-agent coordination.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
"""
|
| 31 |
|
| 32 |
def __init__(
|
|
@@ -54,12 +65,32 @@ class MagenticOrchestrator:
|
|
| 54 |
iteration=0,
|
| 55 |
)
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# Create agent wrappers
|
| 58 |
-
search_agent = SearchAgent(
|
|
|
|
|
|
|
| 59 |
judge_agent = JudgeAgent(self._judge_handler, self._evidence_store)
|
| 60 |
|
| 61 |
# Build Magentic workflow
|
| 62 |
-
# Note: MagenticBuilder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
workflow = (
|
| 64 |
MagenticBuilder()
|
| 65 |
.participants(
|
|
@@ -78,8 +109,17 @@ class MagenticOrchestrator:
|
|
| 78 |
)
|
| 79 |
|
| 80 |
# Task instruction for the manager
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
|
|
|
|
|
|
| 83 |
Instructions:
|
| 84 |
1. Use SearcherAgent to find evidence. SEND ONLY A SIMPLE KEYWORD QUERY (e.g. "metformin aging")
|
| 85 |
as the instruction. Complex queries fail.
|
|
|
|
| 1 |
+
"""Magentic-based orchestrator for DeepCritical.
|
| 2 |
+
|
| 3 |
+
NOTE: Magentic mode currently requires OpenAI API keys. The MagenticBuilder's
|
| 4 |
+
standard manager uses OpenAIChatClient. Anthropic support may be added when
|
| 5 |
+
the agent_framework provides an AnthropicChatClient.
|
| 6 |
+
"""
|
| 7 |
|
| 8 |
from collections.abc import AsyncGenerator
|
| 9 |
|
|
|
|
| 22 |
from src.agents.search_agent import SearchAgent
|
| 23 |
from src.orchestrator import JudgeHandlerProtocol, SearchHandlerProtocol
|
| 24 |
from src.utils.config import settings
|
| 25 |
+
from src.utils.exceptions import ConfigurationError
|
| 26 |
from src.utils.models import AgentEvent, Evidence
|
| 27 |
|
| 28 |
logger = structlog.get_logger()
|
|
|
|
| 33 |
Magentic-based orchestrator - same API as Orchestrator.
|
| 34 |
|
| 35 |
Uses Microsoft Agent Framework's MagenticBuilder for multi-agent coordination.
|
| 36 |
+
|
| 37 |
+
Note:
|
| 38 |
+
Magentic mode requires OPENAI_API_KEY. The MagenticBuilder's standard
|
| 39 |
+
manager currently only supports OpenAI. If you have only an Anthropic
|
| 40 |
+
key, use the "simple" orchestrator mode instead.
|
| 41 |
"""
|
| 42 |
|
| 43 |
def __init__(
|
|
|
|
| 65 |
iteration=0,
|
| 66 |
)
|
| 67 |
|
| 68 |
+
# Initialize embedding service (optional)
|
| 69 |
+
embedding_service = None
|
| 70 |
+
try:
|
| 71 |
+
from src.services.embeddings import get_embedding_service
|
| 72 |
+
|
| 73 |
+
embedding_service = get_embedding_service()
|
| 74 |
+
logger.info("Embedding service enabled")
|
| 75 |
+
except ImportError:
|
| 76 |
+
logger.info("Embedding service not available (dependencies missing)")
|
| 77 |
+
except Exception as e:
|
| 78 |
+
logger.warning("Failed to initialize embedding service", error=str(e))
|
| 79 |
+
|
| 80 |
# Create agent wrappers
|
| 81 |
+
search_agent = SearchAgent(
|
| 82 |
+
self._search_handler, self._evidence_store, embedding_service=embedding_service
|
| 83 |
+
)
|
| 84 |
judge_agent = JudgeAgent(self._judge_handler, self._evidence_store)
|
| 85 |
|
| 86 |
# Build Magentic workflow
|
| 87 |
+
# Note: MagenticBuilder requires OpenAI - validate key exists
|
| 88 |
+
if not settings.openai_api_key:
|
| 89 |
+
raise ConfigurationError(
|
| 90 |
+
"Magentic mode requires OPENAI_API_KEY. "
|
| 91 |
+
"Set the key or use mode='simple' with Anthropic."
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
workflow = (
|
| 95 |
MagenticBuilder()
|
| 96 |
.participants(
|
|
|
|
| 109 |
)
|
| 110 |
|
| 111 |
# Task instruction for the manager
|
| 112 |
+
semantic_note = ""
|
| 113 |
+
if embedding_service:
|
| 114 |
+
semantic_note = """
|
| 115 |
+
The system has semantic search enabled. When evidence is found:
|
| 116 |
+
1. Related concepts will be automatically surfaced
|
| 117 |
+
2. Duplicates are removed by meaning, not just URL
|
| 118 |
+
3. Use the surfaced related concepts to refine searches
|
| 119 |
+
"""
|
| 120 |
|
| 121 |
+
task = f"""Research drug repurposing opportunities for: {query}
|
| 122 |
+
{semantic_note}
|
| 123 |
Instructions:
|
| 124 |
1. Use SearcherAgent to find evidence. SEND ONLY A SIMPLE KEYWORD QUERY (e.g. "metformin aging")
|
| 125 |
as the instruction. Complex queries fail.
|
src/services/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Services for DeepCritical."""
|
src/services/embeddings.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Embedding service for semantic search.
|
| 2 |
+
|
| 3 |
+
IMPORTANT: All public methods are async to avoid blocking the event loop.
|
| 4 |
+
The sentence-transformers model is CPU-bound, so we use run_in_executor().
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
import chromadb
|
| 11 |
+
import structlog
|
| 12 |
+
from sentence_transformers import SentenceTransformer
|
| 13 |
+
|
| 14 |
+
from src.utils.models import Evidence
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class EmbeddingService:
    """Embed text and keep an in-memory vector index of evidence.

    The sentence-transformers model is synchronous, so the public coroutines
    push that work onto the event loop's default executor with
    ``run_in_executor`` rather than calling it inline. See
    src/tools/websearch.py for the pattern.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Load the embedding model and open the ``evidence`` collection.

        Args:
            model_name: Model name handed to ``SentenceTransformer``.
        """
        self._model = SentenceTransformer(model_name)
        self._client = chromadb.Client()  # In-memory for hackathon
        # get_or_create instead of create: create_collection raises when an
        # "evidence" collection already exists, which would make constructing
        # a second service in the same process fail.
        self._collection = self._client.get_or_create_collection(
            name="evidence", metadata={"hnsw:space": "cosine"}
        )

    # ─────────────────────────────────────────────────────────────────
    # Sync internal methods (run in thread pool)
    # ─────────────────────────────────────────────────────────────────

    def _sync_embed(self, text: str) -> list[float]:
        """Synchronous embedding - DO NOT call directly from async code."""
        result: list[float] = self._model.encode(text).tolist()
        return result

    def _sync_batch_embed(self, texts: list[str]) -> list[list[float]]:
        """Batch embedding for efficiency - DO NOT call directly from async code."""
        embeddings = self._model.encode(texts)
        return [e.tolist() for e in embeddings]

    # ─────────────────────────────────────────────────────────────────
    # Async public methods (safe for event loop)
    # ─────────────────────────────────────────────────────────────────

    async def embed(self, text: str) -> list[float]:
        """Embed a single text (async-safe).

        Uses run_in_executor to avoid blocking the event loop.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_embed, text)

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Batch embed multiple texts (async-safe, more efficient)."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_batch_embed, texts)

    async def add_evidence(self, evidence_id: str, content: str, metadata: dict[str, Any]) -> None:
        """Embed ``content`` and add it to the vector store (async-safe).

        Args:
            evidence_id: Identifier stored with the entry (callers in this
                class pass the citation URL).
            content: Text that is embedded and stored as the document.
            metadata: Metadata dict stored alongside the document.
        """
        embedding = await self.embed(content)
        # ChromaDB operations are fast, but wrap for consistency
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            lambda: self._collection.add(
                ids=[evidence_id],
                embeddings=[embedding],  # type: ignore[arg-type]
                metadatas=[metadata],
                documents=[content],
            ),
        )

    async def search_similar(self, query: str, n_results: int = 5) -> list[dict[str, Any]]:
        """Find semantically similar evidence (async-safe).

        Args:
            query: Text to embed and search with.
            n_results: Maximum number of matches requested from the store.

        Returns:
            One dict per match with the keys ``id``, ``content``,
            ``metadata`` and ``distance``; an empty list when the store
            returns no usable result.
        """
        query_embedding = await self.embed(query)

        loop = asyncio.get_running_loop()
        results = await loop.run_in_executor(
            None,
            lambda: self._collection.query(
                query_embeddings=[query_embedding],  # type: ignore[arg-type]
                n_results=n_results,
            ),
        )

        # Handle empty results gracefully
        ids = results.get("ids")
        docs = results.get("documents")
        metas = results.get("metadatas")
        dists = results.get("distances")

        if not ids or not ids[0] or not docs or not metas or not dists:
            return []

        # Only one query embedding is sent, so only the first row of each
        # result field is read. ``item_id`` avoids shadowing the builtin ``id``.
        return [
            {"id": item_id, "content": doc, "metadata": meta, "distance": dist}
            for item_id, doc, meta, dist in zip(
                ids[0],
                docs[0],
                metas[0],
                dists[0],
                strict=False,
            )
        ]

    async def deduplicate(
        self, new_evidence: list[Evidence], threshold: float = 0.9
    ) -> list[Evidence]:
        """Remove semantically duplicate evidence (async-safe).

        Each item is compared against what is already stored; items that are
        kept are added to the store, so later items in the same call are also
        compared against earlier ones.

        Args:
            new_evidence: List of evidence items to deduplicate
            threshold: Similarity threshold (0.9 = 90% similar is duplicate).
                ChromaDB cosine distance: 0=identical, 2=opposite.
                We consider duplicate if distance < (1 - threshold).

        Returns:
            List of unique evidence items (not already in vector store). An
            item whose lookup or storage fails is logged and still returned,
            exactly once.
        """
        unique: list[Evidence] = []
        for evidence in new_evidence:
            try:
                similar = await self.search_similar(evidence.content, n_results=1)
                # ChromaDB cosine distance: 0 = identical, 2 = opposite
                # threshold=0.9 means distance < 0.1 is considered duplicate
                if bool(similar) and similar[0]["distance"] < (1 - threshold):
                    continue

                # Store FULL citation metadata for reconstruction later
                await self.add_evidence(
                    evidence_id=evidence.citation.url,
                    content=evidence.content,
                    metadata={
                        "source": evidence.citation.source,
                        "title": evidence.citation.title,
                        "date": evidence.citation.date,
                        "authors": ",".join(evidence.citation.authors or []),
                    },
                )
            except Exception as e:
                # Log but don't fail entire deduplication for one bad item;
                # better to have duplicates than lose data.
                structlog.get_logger().warning(
                    "Failed to process evidence in deduplicate",
                    url=evidence.citation.url,
                    error=str(e),
                )

            # Single append point: reached for every non-duplicate item whether
            # or not storing it succeeded, so a failed store write can no
            # longer put the same item in the result twice.
            unique.append(evidence)

        return unique
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
_embedding_service: EmbeddingService | None = None
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def get_embedding_service() -> EmbeddingService:
    """Return the shared EmbeddingService, building it on the first call."""
    global _embedding_service  # noqa: PLW0603
    if _embedding_service is not None:
        return _embedding_service
    # First call: build the service and remember it for later callers.
    _embedding_service = EmbeddingService()
    return _embedding_service
|
src/tools/pubmed.py
CHANGED
|
@@ -22,7 +22,7 @@ class PubMedTool:
|
|
| 22 |
def __init__(self, api_key: str | None = None) -> None:
|
| 23 |
self.api_key = api_key or settings.ncbi_api_key
|
| 24 |
# Ignore placeholder values from .env.example
|
| 25 |
-
if self.api_key
|
| 26 |
self.api_key = None
|
| 27 |
self._last_request_time = 0.0
|
| 28 |
|
|
|
|
| 22 |
def __init__(self, api_key: str | None = None) -> None:
|
| 23 |
self.api_key = api_key or settings.ncbi_api_key
|
| 24 |
# Ignore placeholder values from .env.example
|
| 25 |
+
if self.api_key == "your-ncbi-key-here":
|
| 26 |
self.api_key = None
|
| 27 |
self._last_request_time = 0.0
|
| 28 |
|
src/utils/config.py
CHANGED
|
@@ -26,8 +26,8 @@ class Settings(BaseSettings):
|
|
| 26 |
llm_provider: Literal["openai", "anthropic"] = Field(
|
| 27 |
default="openai", description="Which LLM provider to use"
|
| 28 |
)
|
| 29 |
-
openai_model: str = Field(default="gpt-
|
| 30 |
-
anthropic_model: str = Field(default="claude-sonnet-4-
|
| 31 |
|
| 32 |
# PubMed Configuration
|
| 33 |
ncbi_api_key: str | None = Field(
|
|
|
|
| 26 |
llm_provider: Literal["openai", "anthropic"] = Field(
|
| 27 |
default="openai", description="Which LLM provider to use"
|
| 28 |
)
|
| 29 |
+
openai_model: str = Field(default="gpt-4o", description="OpenAI model name")
|
| 30 |
+
anthropic_model: str = Field(default="claude-sonnet-4-20250514", description="Anthropic model")
|
| 31 |
|
| 32 |
# PubMed Configuration
|
| 33 |
ncbi_api_key: str | None = Field(
|
src/utils/models.py
CHANGED
|
@@ -125,6 +125,7 @@ class AgentEvent(BaseModel):
|
|
| 125 |
"synthesizing": "📝",
|
| 126 |
"complete": "🎉",
|
| 127 |
"error": "❌",
|
|
|
|
| 128 |
}
|
| 129 |
icon = icons.get(self.type, "•")
|
| 130 |
return f"{icon} **{self.type.upper()}**: {self.message}"
|
|
|
|
| 125 |
"synthesizing": "📝",
|
| 126 |
"complete": "🎉",
|
| 127 |
"error": "❌",
|
| 128 |
+
"streaming": "📡",
|
| 129 |
}
|
| 130 |
icon = icons.get(self.type, "•")
|
| 131 |
return f"{icon} **{self.type.upper()}**: {self.message}"
|
tests/unit/agents/test_search_agent.py
CHANGED
|
@@ -81,5 +81,48 @@ async def test_run_handles_list_input(mock_handler: AsyncMock) -> None:
|
|
| 81 |
ChatMessage(role=Role.USER, text="test query"),
|
| 82 |
]
|
| 83 |
await agent.run(messages)
|
| 84 |
-
|
| 85 |
mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
ChatMessage(role=Role.USER, text="test query"),
|
| 82 |
]
|
| 83 |
await agent.run(messages)
|
|
|
|
| 84 |
mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@pytest.mark.asyncio
|
| 88 |
+
async def test_run_uses_embeddings(mock_handler: AsyncMock) -> None:
|
| 89 |
+
"""Test that run uses embedding service if provided."""
|
| 90 |
+
store: dict = {"current": []}
|
| 91 |
+
|
| 92 |
+
# Mock embedding service
|
| 93 |
+
mock_embeddings = AsyncMock()
|
| 94 |
+
# Mock deduplicate to return the evidence as is (or filtered)
|
| 95 |
+
mock_embeddings.deduplicate.return_value = [
|
| 96 |
+
Evidence(
|
| 97 |
+
content="unique content",
|
| 98 |
+
citation=Citation(source="pubmed", url="u1", title="t1", date="2024"),
|
| 99 |
+
)
|
| 100 |
+
]
|
| 101 |
+
# Mock search_similar to return related items
|
| 102 |
+
mock_embeddings.search_similar.return_value = [
|
| 103 |
+
{
|
| 104 |
+
"id": "u2",
|
| 105 |
+
"content": "related content",
|
| 106 |
+
"metadata": {"source": "web", "title": "related", "date": "2024"},
|
| 107 |
+
"distance": 0.1,
|
| 108 |
+
}
|
| 109 |
+
]
|
| 110 |
+
|
| 111 |
+
agent = SearchAgent(mock_handler, store, embedding_service=mock_embeddings)
|
| 112 |
+
|
| 113 |
+
await agent.run("test query")
|
| 114 |
+
|
| 115 |
+
# Verify deduplicate called
|
| 116 |
+
mock_embeddings.deduplicate.assert_awaited_once()
|
| 117 |
+
|
| 118 |
+
# Verify semantic search called
|
| 119 |
+
mock_embeddings.search_similar.assert_awaited_once_with("test query", n_results=5)
|
| 120 |
+
|
| 121 |
+
# Verify the store contains both the deduplicated and the related evidence
|
| 122 |
+
# Note: the assertions below require SearchAgent to add related evidence to the store
|
| 123 |
+
# The spec says: "Merge related evidence not already in results"
|
| 124 |
+
|
| 125 |
+
# Check if u1 (deduplicated result) is in store
|
| 126 |
+
assert any(e.citation.url == "u1" for e in store["current"])
|
| 127 |
+
# Check if u2 (related result) is in store
|
| 128 |
+
assert any(e.citation.url == "u2" for e in store["current"])
|
tests/unit/services/test_embeddings.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for EmbeddingService."""
|
| 2 |
+
|
| 3 |
+
from unittest.mock import patch
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
# Skip if embeddings dependencies are not installed
|
| 9 |
+
pytest.importorskip("chromadb")
|
| 10 |
+
pytest.importorskip("sentence_transformers")
|
| 11 |
+
|
| 12 |
+
from src.services.embeddings import EmbeddingService
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TestEmbeddingService:
|
| 16 |
+
@pytest.fixture
|
| 17 |
+
def mock_sentence_transformer(self):
|
| 18 |
+
with patch("src.services.embeddings.SentenceTransformer") as mock_st_class:
|
| 19 |
+
mock_model = mock_st_class.return_value
|
| 20 |
+
# Mock encode to return a numpy array
|
| 21 |
+
mock_model.encode.return_value = np.array([0.1, 0.2, 0.3])
|
| 22 |
+
yield mock_model
|
| 23 |
+
|
| 24 |
+
@pytest.fixture
|
| 25 |
+
def mock_chroma_client(self):
|
| 26 |
+
with patch("src.services.embeddings.chromadb.Client") as mock_client_class:
|
| 27 |
+
mock_client = mock_client_class.return_value
|
| 28 |
+
mock_collection = mock_client.create_collection.return_value
|
| 29 |
+
# Mock query return structure
|
| 30 |
+
mock_collection.query.return_value = {
|
| 31 |
+
"ids": [["id1"]],
|
| 32 |
+
"documents": [["doc1"]],
|
| 33 |
+
"metadatas": [[{"source": "pubmed"}]],
|
| 34 |
+
"distances": [[0.1]],
|
| 35 |
+
}
|
| 36 |
+
yield mock_client
|
| 37 |
+
|
| 38 |
+
@pytest.mark.asyncio
|
| 39 |
+
async def test_embed_returns_vector(self, mock_sentence_transformer, mock_chroma_client):
|
| 40 |
+
"""Embedding should return a float vector (async check)."""
|
| 41 |
+
service = EmbeddingService()
|
| 42 |
+
embedding = await service.embed("metformin diabetes")
|
| 43 |
+
|
| 44 |
+
assert isinstance(embedding, list)
|
| 45 |
+
assert len(embedding) == 3 # noqa: PLR2004
|
| 46 |
+
assert all(isinstance(x, float) for x in embedding)
|
| 47 |
+
# Ensure the mocked model's encode was called exactly once (executor usage itself is not checked)
|
| 48 |
+
mock_sentence_transformer.encode.assert_called_once()
|
| 49 |
+
|
| 50 |
+
@pytest.mark.asyncio
|
| 51 |
+
async def test_batch_embed_efficient(self, mock_sentence_transformer, mock_chroma_client):
|
| 52 |
+
"""Batch embedding should call encode with list."""
|
| 53 |
+
# Set up mock for batch return (one 2D array, one row per input text)
|
| 54 |
+
mock_sentence_transformer.encode.return_value = np.array([[0.1, 0.2], [0.3, 0.4]])
|
| 55 |
+
|
| 56 |
+
service = EmbeddingService()
|
| 57 |
+
texts = ["text one", "text two"]
|
| 58 |
+
|
| 59 |
+
batch_results = await service.embed_batch(texts)
|
| 60 |
+
|
| 61 |
+
assert len(batch_results) == 2 # noqa: PLR2004
|
| 62 |
+
assert isinstance(batch_results[0], list)
|
| 63 |
+
mock_sentence_transformer.encode.assert_called_with(texts)
|
| 64 |
+
|
| 65 |
+
@pytest.mark.asyncio
|
| 66 |
+
async def test_add_and_search(self, mock_sentence_transformer, mock_chroma_client):
|
| 67 |
+
"""Should be able to add evidence and search for similar."""
|
| 68 |
+
service = EmbeddingService()
|
| 69 |
+
await service.add_evidence(
|
| 70 |
+
evidence_id="test1",
|
| 71 |
+
content="Metformin activates AMPK pathway",
|
| 72 |
+
metadata={"source": "pubmed"},
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
# Verify add was called
|
| 76 |
+
mock_collection = mock_chroma_client.create_collection.return_value
|
| 77 |
+
mock_collection.add.assert_called_once()
|
| 78 |
+
|
| 79 |
+
results = await service.search_similar("AMPK activation drugs", n_results=1)
|
| 80 |
+
|
| 81 |
+
# Verify query was called
|
| 82 |
+
mock_collection.query.assert_called_once()
|
| 83 |
+
assert len(results) == 1
|
| 84 |
+
assert results[0]["id"] == "id1"
|
| 85 |
+
|
| 86 |
+
@pytest.mark.asyncio
|
| 87 |
+
async def test_search_similar_empty_collection(
|
| 88 |
+
self, mock_sentence_transformer, mock_chroma_client
|
| 89 |
+
):
|
| 90 |
+
"""Search on empty collection should return empty list, not error."""
|
| 91 |
+
mock_collection = mock_chroma_client.create_collection.return_value
|
| 92 |
+
mock_collection.query.return_value = {
|
| 93 |
+
"ids": [[]],
|
| 94 |
+
"documents": [[]],
|
| 95 |
+
"metadatas": [[]],
|
| 96 |
+
"distances": [[]],
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
service = EmbeddingService()
|
| 100 |
+
results = await service.search_similar("anything", n_results=5)
|
| 101 |
+
assert results == []
|
| 102 |
+
|
| 103 |
+
@pytest.mark.asyncio
|
| 104 |
+
async def test_deduplicate(self, mock_sentence_transformer, mock_chroma_client):
|
| 105 |
+
"""Deduplicate should remove similar items."""
|
| 106 |
+
from src.utils.models import Citation, Evidence
|
| 107 |
+
|
| 108 |
+
service = EmbeddingService()
|
| 109 |
+
|
| 110 |
+
# Mock search to return a match for the first item (duplicate)
|
| 111 |
+
# and no match for the second (unique)
|
| 112 |
+
mock_collection = mock_chroma_client.create_collection.return_value
|
| 113 |
+
|
| 114 |
+
# First call returns match (distance 0.05 < threshold)
|
| 115 |
+
# Second call returns no match or high distance
|
| 116 |
+
mock_collection.query.side_effect = [
|
| 117 |
+
{
|
| 118 |
+
"ids": [["existing_id"]],
|
| 119 |
+
"documents": [["doc"]],
|
| 120 |
+
"metadatas": [[{}]],
|
| 121 |
+
"distances": [[0.05]], # Very similar
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"ids": [[]], # No match
|
| 125 |
+
"documents": [[]],
|
| 126 |
+
"metadatas": [[]],
|
| 127 |
+
"distances": [[]],
|
| 128 |
+
},
|
| 129 |
+
]
|
| 130 |
+
|
| 131 |
+
evidence = [
|
| 132 |
+
Evidence(
|
| 133 |
+
content="Duplicate content",
|
| 134 |
+
citation=Citation(source="web", url="u1", title="t1", date="2024"),
|
| 135 |
+
),
|
| 136 |
+
Evidence(
|
| 137 |
+
content="Unique content",
|
| 138 |
+
citation=Citation(source="web", url="u2", title="t2", date="2024"),
|
| 139 |
+
),
|
| 140 |
+
]
|
| 141 |
+
|
| 142 |
+
unique = await service.deduplicate(evidence, threshold=0.9)
|
| 143 |
+
|
| 144 |
+
# Only the unique one should remain
|
| 145 |
+
assert len(unique) == 1
|
| 146 |
+
assert unique[0].citation.url == "u2"
|
uv.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|