Commit 7fab6d4 · Parent: 2e4a760
feat: implement HFInferenceJudgeHandler for free-tier AI analysis
Replace MockJudgeHandler with real AI analysis using the HuggingFace Inference API:
- Add HFInferenceJudgeHandler with chat_completion API
- Model fallback chain: Llama 3.1 → Mistral → Zephyr (ungated)
- Robust JSON extraction (handles markdown blocks, nested braces)
- Tenacity retry with exponential backoff for rate limits
- Fix app.py to use HF Inference when no paid API keys present
Priority: User API key → Env API key → HF Inference (free)
Hackathon judges now get real AI analysis without needing API keys.
Set HF_TOKEN as Space secret for best model (Llama 3.1).
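In outline, the key-resolution order described above amounts to the following (a minimal sketch, not the exact code from `src/app.py`; the real version also builds the judge handler and a backend label):

```python
# Minimal sketch of the priority chain described above (illustrative only).
import os

def pick_backend(user_api_key: str | None) -> str:
    if user_api_key:  # 1. User-provided API key (BYOK)
        return "paid-api (user key)"
    if os.getenv("OPENAI_API_KEY") or os.getenv("ANTHROPIC_API_KEY"):
        return "paid-api (env key)"  # 2. Key configured in the environment
    return "hf-inference (free)"  # 3. Free HuggingFace Inference fallback
```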
- .env.example +14 -0
- docs/implementation/03_phase_judge.md +414 -14
- docs/implementation/04_phase_ui.md +84 -18
- pyproject.toml +1 -0
- src/agent_factory/judges.py +199 -1
- src/app.py +45 -34
- src/prompts/report.py +4 -4
- tests/unit/agent_factory/test_judges_hf.py +137 -0
- uv.lock +2 -0
.env.example
CHANGED

@@ -11,6 +11,20 @@ ANTHROPIC_API_KEY=sk-ant-your-key-here
 OPENAI_MODEL=gpt-5.1
 ANTHROPIC_MODEL=claude-sonnet-4-5-20250929
 
+# ============== HUGGINGFACE (FREE TIER) ==============
+
+# HuggingFace Token - enables Llama 3.1 (best quality free model)
+# Get yours at: https://huggingface.co/settings/tokens
+#
+# WITHOUT HF_TOKEN: Falls back to ungated models (zephyr-7b-beta)
+# WITH HF_TOKEN: Uses Llama 3.1 8B Instruct (requires accepting license)
+#
+# For HuggingFace Spaces deployment:
+# Set this as a "Secret" in Space Settings → Variables and secrets
+# Users/judges don't need their own token - the Space secret is used
+#
+HF_TOKEN=hf_your-token-here
+
 # ============== AGENT CONFIGURATION ==============
 
 MAX_ITERATIONS=10
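Worth noting for this file: `InferenceClient` picks the token up from the environment on its own, so the Space secret never has to be passed around explicitly. A minimal sketch, assuming `HF_TOKEN` is exported as described above:

```python
# Sketch: InferenceClient reads HF_TOKEN from the environment automatically.
import os
from huggingface_hub import InferenceClient

client = InferenceClient()  # token picked up from HF_TOKEN if present
print("Gated models available:", bool(os.getenv("HF_TOKEN")))
```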
docs/implementation/03_phase_judge.md
CHANGED

@@ -350,20 +350,333 @@ class JudgeHandler:
 )
 
 
-class MockJudgeHandler:
+class HFInferenceJudgeHandler:
     """
-    Mock JudgeHandler for demo mode without LLM calls.
+    JudgeHandler using HuggingFace Inference API for FREE LLM calls.
+
+    This is the DEFAULT for demo mode - provides real AI analysis without
+    requiring users to have OpenAI/Anthropic API keys.
 
+    Model Fallback Chain (handles gated models and rate limits):
+    1. meta-llama/Llama-3.1-8B-Instruct (best quality, requires HF_TOKEN)
+    2. mistralai/Mistral-7B-Instruct-v0.3 (good quality, may require token)
+    3. HuggingFaceH4/zephyr-7b-beta (ungated, always works)
+
+    Rate Limit Handling:
+    - Exponential backoff with 3 retries
+    - Falls back to next model on persistent 429/503 errors
     """
 
+    # Model fallback chain: gated (best) → ungated (fallback)
+    FALLBACK_MODELS = [
+        "meta-llama/Llama-3.1-8B-Instruct",  # Best quality (gated)
+        "mistralai/Mistral-7B-Instruct-v0.3",  # Good quality
+        "HuggingFaceH4/zephyr-7b-beta",  # Ungated fallback
+    ]
+
+    def __init__(self, model_id: str | None = None):
+        """
+        Initialize with HF Inference client.
+
+        Args:
+            model_id: HuggingFace model ID. If None, uses fallback chain.
+                Will automatically use HF_TOKEN from env if available.
+        """
+        from huggingface_hub import InferenceClient
+        import os
+
+        self.model_id = model_id or self.FALLBACK_MODELS[0]
+        self._fallback_models = self.FALLBACK_MODELS.copy()
+
+        # InferenceClient auto-reads HF_TOKEN from env
+        self.client = InferenceClient(model=self.model_id)
+        self._has_token = bool(os.getenv("HF_TOKEN"))
+
+        self.call_count = 0
+        self.last_question = None
+        self.last_evidence = None
+
+        logger.info(
+            "HFInferenceJudgeHandler initialized",
+            model=self.model_id,
+            has_token=self._has_token,
+        )
+
+    def _extract_json(self, response: str) -> dict | None:
+        """
+        Robustly extract JSON from LLM response.
+
+        Handles:
+        - Raw JSON: {"key": "value"}
+        - Markdown code blocks: ```json\n{"key": "value"}\n```
+        - Preamble text: "Here is the JSON:\n{"key": "value"}"
+        - Nested braces: {"outer": {"inner": "value"}}
+
+        Returns:
+            Parsed dict or None if extraction fails
+        """
+        import json
+        import re
+
+        # Strategy 1: Try markdown code block first
+        code_block_match = re.search(r"```(?:json)?\s*(\{[\s\S]*?\})\s*```", response)
+        if code_block_match:
+            try:
+                return json.loads(code_block_match.group(1))
+            except json.JSONDecodeError:
+                pass
+
+        # Strategy 2: Find outermost JSON object with brace matching
+        # This handles nested objects correctly
+        start = response.find("{")
+        if start == -1:
+            return None
+
+        depth = 0
+        end = start
+        in_string = False
+        escape_next = False
+
+        for i, char in enumerate(response[start:], start):
+            if escape_next:
+                escape_next = False
+                continue
+
+            if char == "\\":
+                escape_next = True
+                continue
+
+            if char == '"' and not escape_next:
+                in_string = not in_string
+                continue
+
+            if in_string:
+                continue
+
+            if char == "{":
+                depth += 1
+            elif char == "}":
+                depth -= 1
+                if depth == 0:
+                    end = i + 1
+                    break
+
+        if depth == 0 and end > start:
+            try:
+                return json.loads(response[start:end])
+            except json.JSONDecodeError:
+                pass
+
+        return None
+
+    async def _call_with_retry(
+        self,
+        messages: list[dict],
+        max_retries: int = 3,
+    ) -> str:
+        """
+        Call HF Inference with exponential backoff retry.
+
+        Args:
+            messages: Chat messages in OpenAI format
+            max_retries: Max retry attempts
+
+        Returns:
+            Response text
+
+        Raises:
+            Exception if all retries fail
+        """
+        import asyncio
+        import time
+
+        last_error = None
+
+        for attempt in range(max_retries):
+            try:
+                loop = asyncio.get_event_loop()
+                response = await loop.run_in_executor(
+                    None,
+                    lambda: self.client.chat_completion(
+                        messages=messages,
+                        max_tokens=1024,
+                        temperature=0.1,
+                    )
+                )
+                return response.choices[0].message.content
+
+            except Exception as e:
+                last_error = e
+                error_str = str(e).lower()
+
+                # Check if rate limited or service unavailable
+                is_rate_limit = "429" in error_str or "rate" in error_str
+                is_unavailable = "503" in error_str or "unavailable" in error_str
+                is_auth_error = "401" in error_str or "403" in error_str
+
+                if is_auth_error:
+                    # Gated model without token - try fallback immediately
+                    logger.warning("Auth error, trying fallback model", error=str(e))
+                    if self._try_fallback_model():
+                        continue
+                    raise
+
+                if is_rate_limit or is_unavailable:
+                    # Exponential backoff: 1s, 2s, 4s
+                    wait_time = 2 ** attempt
+                    logger.warning(
+                        "Rate limited, retrying",
+                        attempt=attempt + 1,
+                        wait=wait_time,
+                        error=str(e),
+                    )
+                    await asyncio.sleep(wait_time)
+                    continue
+
+                # Other errors - raise immediately
+                raise
+
+        # All retries failed - try fallback model
+        if self._try_fallback_model():
+            return await self._call_with_retry(messages, max_retries=1)
+
+        raise last_error or Exception("All retries failed")
+
+    def _try_fallback_model(self) -> bool:
+        """
+        Try to switch to a fallback model.
+
+        Returns:
+            True if successfully switched, False if no fallbacks left
+        """
+        from huggingface_hub import InferenceClient
+
+        # Remove current model from fallbacks
+        if self.model_id in self._fallback_models:
+            self._fallback_models.remove(self.model_id)
+
+        if not self._fallback_models:
+            return False
+
+        # Switch to next model
+        self.model_id = self._fallback_models[0]
+        self.client = InferenceClient(model=self.model_id)
+        logger.info("Switched to fallback model", model=self.model_id)
+        return True
+
+    async def assess(
+        self,
+        question: str,
+        evidence: List[Evidence],
+    ) -> JudgeAssessment:
         """
+        Assess evidence using HuggingFace Inference API.
+
+        Uses chat_completion API for model-agnostic prompts.
+        Includes retry logic and fallback model chain.
 
         Args:
+            question: The user's research question
+            evidence: List of Evidence objects from search
+
+        Returns:
+            JudgeAssessment with evaluation results
         """
+        self.call_count += 1
+        self.last_question = question
+        self.last_evidence = evidence
+
+        # Format the prompt
+        if evidence:
+            user_prompt = format_user_prompt(question, evidence)
+        else:
+            user_prompt = format_empty_evidence_prompt(question)
+
+        # Build messages in OpenAI-compatible format (works with chat_completion)
+        json_schema = """{
+    "details": {
+        "mechanism_score": <int 0-10>,
+        "mechanism_reasoning": "<string>",
+        "clinical_evidence_score": <int 0-10>,
+        "clinical_reasoning": "<string>",
+        "drug_candidates": ["<string>", ...],
+        "key_findings": ["<string>", ...]
+    },
+    "sufficient": <bool>,
+    "confidence": <float 0-1>,
+    "recommendation": "continue" | "synthesize",
+    "next_search_queries": ["<string>", ...],
+    "reasoning": "<string>"
+}"""
+
+        messages = [
+            {
+                "role": "system",
+                "content": f"{SYSTEM_PROMPT}\n\nIMPORTANT: Respond with ONLY valid JSON matching this schema:\n{json_schema}",
+            },
+            {
+                "role": "user",
+                "content": user_prompt,
+            },
+        ]
+
+        try:
+            # Call with retry and fallback
+            response = await self._call_with_retry(messages)
+
+            # Robust JSON extraction
+            data = self._extract_json(response)
+            if data:
+                return JudgeAssessment(**data)
+
+            # If no valid JSON, return fallback
+            logger.warning(
+                "HF Inference returned invalid JSON",
+                response=response[:200],
+                model=self.model_id,
+            )
+            return self._create_fallback_assessment(question, "Invalid JSON response")
+
+        except Exception as e:
+            logger.error("HF Inference failed", error=str(e), model=self.model_id)
+            return self._create_fallback_assessment(question, str(e))
+
+    def _create_fallback_assessment(
+        self,
+        question: str,
+        error: str,
+    ) -> JudgeAssessment:
+        """Create a fallback assessment when inference fails."""
+        return JudgeAssessment(
+            details=AssessmentDetails(
+                mechanism_score=0,
+                mechanism_reasoning=f"Assessment failed: {error}",
+                clinical_evidence_score=0,
+                clinical_reasoning=f"Assessment failed: {error}",
+                drug_candidates=[],
+                key_findings=[],
+            ),
+            sufficient=False,
+            confidence=0.0,
+            recommendation="continue",
+            next_search_queries=[
+                f"{question} mechanism",
+                f"{question} clinical trials",
+                f"{question} drug candidates",
+            ],
+            reasoning=f"HF Inference failed: {error}. Recommend retrying.",
+        )
+
+
+class MockJudgeHandler:
+    """
+    Mock JudgeHandler for UNIT TESTING ONLY.
+
+    NOT for production use. Use HFInferenceJudgeHandler for demo mode.
+    """
+
+    def __init__(self, mock_response: JudgeAssessment | None = None):
+        """Initialize with optional mock response for testing."""
         self.mock_response = mock_response
         self.call_count = 0
         self.last_question = None

@@ -374,7 +687,7 @@ class MockJudgeHandler:
         question: str,
         evidence: List[Evidence],
     ) -> JudgeAssessment:
-        """Return the mock response."""
+        """Return the mock response (for testing only)."""
         self.call_count += 1
         self.last_question = question
         self.last_evidence = evidence

@@ -382,21 +695,21 @@ class MockJudgeHandler:
         if self.mock_response:
             return self.mock_response
 
-        # Default mock response
+        # Default mock response for tests
         return JudgeAssessment(
             details=AssessmentDetails(
                 mechanism_score=7,
+                mechanism_reasoning="Mock assessment for testing",
                 clinical_evidence_score=6,
+                clinical_reasoning="Mock assessment for testing",
+                drug_candidates=["TestDrug"],
+                key_findings=["Test finding"],
             ),
             sufficient=len(evidence) >= 3,
             confidence=0.75,
             recommendation="synthesize" if len(evidence) >= 3 else "continue",
             next_search_queries=["query 1", "query 2"] if len(evidence) < 3 else [],
+            reasoning="Mock assessment for unit testing only",
         )
 ```
 

@@ -547,8 +860,89 @@ class TestJudgeHandler:
         assert "failed" in result.reasoning.lower()
 
 
+class TestHFInferenceJudgeHandler:
+    """Tests for HFInferenceJudgeHandler."""
+
+    @pytest.mark.asyncio
+    async def test_extract_json_raw(self):
+        """Should extract raw JSON."""
+        from src.agent_factory.judges import HFInferenceJudgeHandler
+
+        handler = HFInferenceJudgeHandler.__new__(HFInferenceJudgeHandler)
+        # Bypass __init__ for unit testing extraction
+
+        result = handler._extract_json('{"key": "value"}')
+        assert result == {"key": "value"}
+
+    @pytest.mark.asyncio
+    async def test_extract_json_markdown_block(self):
+        """Should extract JSON from markdown code block."""
+        from src.agent_factory.judges import HFInferenceJudgeHandler
+
+        handler = HFInferenceJudgeHandler.__new__(HFInferenceJudgeHandler)
+
+        response = '''Here is the assessment:
+```json
+{"key": "value", "nested": {"inner": 1}}
+```
+'''
+        result = handler._extract_json(response)
+        assert result == {"key": "value", "nested": {"inner": 1}}
+
+    @pytest.mark.asyncio
+    async def test_extract_json_with_preamble(self):
+        """Should extract JSON with preamble text."""
+        from src.agent_factory.judges import HFInferenceJudgeHandler
+
+        handler = HFInferenceJudgeHandler.__new__(HFInferenceJudgeHandler)
+
+        response = 'Here is your JSON response:\n{"sufficient": true, "confidence": 0.85}'
+        result = handler._extract_json(response)
+        assert result == {"sufficient": True, "confidence": 0.85}
+
+    @pytest.mark.asyncio
+    async def test_extract_json_nested_braces(self):
+        """Should handle nested braces correctly."""
+        from src.agent_factory.judges import HFInferenceJudgeHandler
+
+        handler = HFInferenceJudgeHandler.__new__(HFInferenceJudgeHandler)
+
+        response = '{"details": {"mechanism_score": 8}, "reasoning": "test"}'
+        result = handler._extract_json(response)
+        assert result["details"]["mechanism_score"] == 8
+
+    @pytest.mark.asyncio
+    async def test_hf_handler_uses_fallback_models(self):
+        """HFInferenceJudgeHandler should have fallback model chain."""
+        from src.agent_factory.judges import HFInferenceJudgeHandler
+
+        # Check class has fallback models defined
+        assert len(HFInferenceJudgeHandler.FALLBACK_MODELS) >= 3
+        assert "zephyr-7b-beta" in HFInferenceJudgeHandler.FALLBACK_MODELS[-1]
+
+    @pytest.mark.asyncio
+    async def test_hf_handler_fallback_on_auth_error(self):
+        """Should fall back to ungated model on auth error."""
+        from src.agent_factory.judges import HFInferenceJudgeHandler
+        from unittest.mock import MagicMock, patch
+
+        with patch("src.agent_factory.judges.InferenceClient") as mock_client_class:
+            # First call raises 403, second succeeds
+            mock_client = MagicMock()
+            mock_client.chat_completion.side_effect = [
+                Exception("403 Forbidden: gated model"),
+                MagicMock(choices=[MagicMock(message=MagicMock(content='{"sufficient": false}'))])
+            ]
+            mock_client_class.return_value = mock_client
+
+            handler = HFInferenceJudgeHandler()
+            # Manually trigger fallback test
+            assert handler._try_fallback_model() is True
+            assert handler.model_id != "meta-llama/Llama-3.1-8B-Instruct"
+
+
 class TestMockJudgeHandler:
-    """Tests for MockJudgeHandler."""
+    """Tests for MockJudgeHandler (UNIT TESTING ONLY)."""
 
     @pytest.mark.asyncio
     async def test_mock_handler_returns_default(self):

@@ -641,9 +1035,15 @@ dependencies = [
     "pydantic-ai>=0.0.16",
     "openai>=1.0.0",
     "anthropic>=0.18.0",
+    "huggingface-hub>=0.20.0",  # For HFInferenceJudgeHandler (FREE LLM)
 ]
 ```
 
+**Note**: `huggingface-hub` is required for the free tier to work. It:
+- Provides `InferenceClient` for API calls
+- Auto-reads `HF_TOKEN` from environment (optional, for gated models)
+- Works without any token for ungated models like `zephyr-7b-beta`
+
 ---
 
 ## 7. Configuration (`src/utils/config.py`)
docs/implementation/04_phase_ui.md
CHANGED

@@ -408,33 +408,65 @@ from typing import AsyncGenerator
 
 from src.orchestrator import Orchestrator
 from src.tools.pubmed import PubMedTool
+from src.tools.clinicaltrials import ClinicalTrialsTool
+from src.tools.biorxiv import BioRxivTool
 from src.tools.search_handler import SearchHandler
+from src.agent_factory.judges import JudgeHandler, HFInferenceJudgeHandler
 from src.utils.models import OrchestratorConfig, AgentEvent
 
 
-def create_orchestrator(use_mock: bool = False) -> Orchestrator:
+def create_orchestrator(
+    user_api_key: str | None = None,
+    api_provider: str = "openai",
+) -> tuple[Orchestrator, str]:
     """
     Create an orchestrator instance.
 
     Args:
+        user_api_key: Optional user-provided API key (BYOK)
+        api_provider: API provider ("openai" or "anthropic")
 
     Returns:
-        Configured Orchestrator instance
+        Tuple of (Configured Orchestrator instance, backend_name)
+
+        Priority:
+        1. User-provided API key → JudgeHandler (OpenAI/Anthropic)
+        2. Environment API key → JudgeHandler (OpenAI/Anthropic)
+        3. No key → HFInferenceJudgeHandler (FREE, automatic fallback chain)
+
+        HF Inference Fallback Chain:
+        1. Llama 3.1 8B (requires HF_TOKEN for gated model)
+        2. Mistral 7B (may require token)
+        3. Zephyr 7B (ungated, always works)
     """
+    import os
+
     # Create search tools
     search_handler = SearchHandler(
+        tools=[PubMedTool(), ClinicalTrialsTool(), BioRxivTool()],
        timeout=30.0,
     )
 
+    # Determine which judge to use
+    has_env_key = bool(os.getenv("OPENAI_API_KEY") or os.getenv("ANTHROPIC_API_KEY"))
+    has_user_key = bool(user_api_key)
+    has_hf_token = bool(os.getenv("HF_TOKEN"))
+
+    if has_user_key:
+        # User provided their own key
+        judge_handler = JudgeHandler(model=None)
+        backend_name = f"your {api_provider.upper()} API key"
+    elif has_env_key:
+        # Environment has API key configured
+        judge_handler = JudgeHandler(model=None)
+        backend_name = "configured API key"
     else:
+        # Use FREE HuggingFace Inference with automatic fallback
+        judge_handler = HFInferenceJudgeHandler()
+        if has_hf_token:
+            backend_name = "HuggingFace Inference (Llama 3.1)"
+        else:
+            backend_name = "HuggingFace Inference (free tier)"
 
     # Create orchestrator
     config = OrchestratorConfig(

@@ -446,12 +478,14 @@ def create_orchestrator(use_mock: bool = False) -> Orchestrator:
         search_handler=search_handler,
         judge_handler=judge_handler,
         config=config,
-    )
+    ), backend_name
 
 
 async def research_agent(
     message: str,
     history: list[dict],
+    api_key: str = "",
+    api_provider: str = "openai",
 ) -> AsyncGenerator[str, None]:
     """
     Gradio chat function that runs the research agent.

@@ -459,6 +493,8 @@ async def research_agent(
     Args:
         message: User's research question
         history: Chat history (Gradio format)
+        api_key: Optional user-provided API key (BYOK)
+        api_provider: API provider ("openai" or "anthropic")
 
     Yields:
         Markdown-formatted responses for streaming

@@ -467,10 +503,31 @@ async def research_agent(
         yield "Please enter a research question."
         return
 
-    # Create orchestrator (use mock if no API key)
     import os
+
+    # Clean user-provided API key
+    user_api_key = api_key.strip() if api_key else None
+
+    # Create orchestrator with appropriate judge
+    orchestrator, backend_name = create_orchestrator(
+        user_api_key=user_api_key,
+        api_provider=api_provider,
+    )
+
+    # Determine icon based on backend
+    has_hf_token = bool(os.getenv("HF_TOKEN"))
+    if "HuggingFace" in backend_name:
+        icon = "🤗"
+        extra_note = (
+            "\n*For premium analysis, enter an OpenAI or Anthropic API key.*"
+            if not has_hf_token else ""
+        )
+    else:
+        icon = "🔑"
+        extra_note = ""
+
+    # Inform user which backend is being used
+    yield f"{icon} **Using {backend_name}**{extra_note}\n\n"
 
     # Run the agent and stream events
     response_parts = []

@@ -952,15 +1009,22 @@ uv run python -m src.app
 import asyncio
 from src.orchestrator import Orchestrator
 from src.tools.pubmed import PubMedTool
+from src.tools.biorxiv import BioRxivTool
+from src.tools.clinicaltrials import ClinicalTrialsTool
 from src.tools.search_handler import SearchHandler
-from src.agent_factory.judges import MockJudgeHandler
+from src.agent_factory.judges import HFInferenceJudgeHandler, MockJudgeHandler
 from src.utils.models import OrchestratorConfig
 
 async def test_full_flow():
     # Create components
+    search_handler = SearchHandler([PubMedTool(), ClinicalTrialsTool(), BioRxivTool()])
+
+    # Option 1: Use FREE HuggingFace Inference (real AI analysis)
+    judge_handler = HFInferenceJudgeHandler()
+
+    # Option 2: Use MockJudgeHandler for UNIT TESTING ONLY
+    # judge_handler = MockJudgeHandler()
+
     config = OrchestratorConfig(max_iterations=3)
 
     # Create orchestrator

@@ -980,6 +1044,8 @@ async def test_full_flow():
     asyncio.run(test_full_flow())
 ```
 
+**Important**: `MockJudgeHandler` is for **unit testing only**. For actual demo/production use, always use `HFInferenceJudgeHandler` (free) or `JudgeHandler` (with API key).
+
 ---
 
 ## 10. Deployment Verification
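As a complement to the **Important** note above, a unit test built on `MockJudgeHandler` might look like this (a sketch based on the default behavior shown in the diff; the test name is made up):

```python
# Sketch: MockJudgeHandler in a unit test (name and assertions illustrative).
import pytest

from src.agent_factory.judges import MockJudgeHandler


@pytest.mark.asyncio
async def test_mock_judge_default_recommendation():
    handler = MockJudgeHandler()
    result = await handler.assess("test question", [])  # no evidence yet
    assert handler.call_count == 1
    assert result.recommendation == "continue"  # fewer than 3 evidence items
```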
pyproject.toml
CHANGED

@@ -16,6 +16,7 @@ dependencies = [
     "httpx>=0.27",  # Async HTTP client (PubMed)
     "beautifulsoup4>=4.12",  # HTML parsing
     "xmltodict>=0.13",  # PubMed XML -> dict
+    "huggingface-hub>=0.20.0",  # Hugging Face Inference API
     # UI
     "gradio[mcp]>=6.0.0",  # Chat interface with MCP server support (6.0 required for css in launch())
     # Utils
src/agent_factory/judges.py
CHANGED

@@ -1,13 +1,17 @@
 """Judge handler for evidence assessment using PydanticAI."""
 
+import asyncio
+import json
+from typing import Any, ClassVar
 
 import structlog
+from huggingface_hub import InferenceClient
 from pydantic_ai import Agent
 from pydantic_ai.models.anthropic import AnthropicModel
 from pydantic_ai.models.openai import OpenAIModel
 from pydantic_ai.providers.anthropic import AnthropicProvider
 from pydantic_ai.providers.openai import OpenAIProvider
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
 
 from src.prompts.judge import (
     SYSTEM_PROMPT,

@@ -146,6 +150,200 @@ class JudgeHandler:
 )
 
 
+class HFInferenceJudgeHandler:
+    """
+    JudgeHandler using HuggingFace Inference API for FREE LLM calls.
+    Defaults to Llama-3.1-8B-Instruct (requires HF_TOKEN) or falls back to public models.
+    """
+
+    FALLBACK_MODELS: ClassVar[list[str]] = [
+        "meta-llama/Llama-3.1-8B-Instruct",  # Primary (Gated)
+        "mistralai/Mistral-7B-Instruct-v0.3",  # Secondary
+        "HuggingFaceH4/zephyr-7b-beta",  # Fallback (Ungated)
+    ]
+
+    def __init__(self, model_id: str | None = None) -> None:
+        """
+        Initialize with HF Inference client.
+
+        Args:
+            model_id: Optional specific model ID. If None, uses FALLBACK_MODELS chain.
+        """
+        self.model_id = model_id
+        # Will automatically use HF_TOKEN from env if available
+        self.client = InferenceClient()
+        self.call_count = 0
+        self.last_question: str | None = None
+        self.last_evidence: list[Evidence] | None = None
+
+    async def assess(
+        self,
+        question: str,
+        evidence: list[Evidence],
+    ) -> JudgeAssessment:
+        """
+        Assess evidence using HuggingFace Inference API.
+        Attempts models in order until one succeeds.
+        """
+        self.call_count += 1
+        self.last_question = question
+        self.last_evidence = evidence
+
+        # Format the user prompt
+        if evidence:
+            user_prompt = format_user_prompt(question, evidence)
+        else:
+            user_prompt = format_empty_evidence_prompt(question)
+
+        models_to_try = [self.model_id] if self.model_id else self.FALLBACK_MODELS
+        last_error = None
+
+        for model in models_to_try:
+            try:
+                return await self._call_with_retry(model, user_prompt, question)
+            except Exception as e:
+                logger.warning(f"Model {model} failed", error=str(e))
+                last_error = e
+                continue
+
+        # All models failed
+        logger.error("All HF models failed", error=str(last_error))
+        return self._create_fallback_assessment(question, str(last_error))
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=1, max=4),
+        retry=retry_if_exception_type(Exception),
+        reraise=True,
+    )
+    async def _call_with_retry(self, model: str, prompt: str, question: str) -> JudgeAssessment:
+        """Make API call with retry logic using chat_completion."""
+        loop = asyncio.get_running_loop()
+
+        # Build messages for chat_completion (model-agnostic)
+        messages = [
+            {
+                "role": "system",
+                "content": f"""{SYSTEM_PROMPT}
+
+IMPORTANT: Respond with ONLY valid JSON matching this schema:
+{{
+    "details": {{
+        "mechanism_score": <int 0-10>,
+        "mechanism_reasoning": "<string>",
+        "clinical_evidence_score": <int 0-10>,
+        "clinical_reasoning": "<string>",
+        "drug_candidates": ["<string>", ...],
+        "key_findings": ["<string>", ...]
+    }},
+    "sufficient": <bool>,
+    "confidence": <float 0-1>,
+    "recommendation": "continue" | "synthesize",
+    "next_search_queries": ["<string>", ...],
+    "reasoning": "<string>"
+}}""",
+            },
+            {"role": "user", "content": prompt},
+        ]
+
+        # Use chat_completion (conversational task - supported by all models)
+        response = await loop.run_in_executor(
+            None,
+            lambda: self.client.chat_completion(
+                messages=messages,
+                model=model,
+                max_tokens=1024,
+                temperature=0.1,
+            ),
+        )
+
+        # Extract content from response
+        content = response.choices[0].message.content
+        if not content:
+            raise ValueError("Empty response from model")
+
+        # Extract and parse JSON
+        json_data = self._extract_json(content)
+        if not json_data:
+            raise ValueError("No valid JSON found in response")
+
+        return JudgeAssessment(**json_data)
+
+    def _extract_json(self, text: str) -> dict[str, Any] | None:
+        """
+        Robust JSON extraction that handles markdown blocks and nested braces.
+        """
+        text = text.strip()
+
+        # Remove markdown code blocks if present
+        if "```json" in text:
+            text = text.split("```json")[1].split("```")[0]
+        elif "```" in text:
+            text = text.split("```")[1].split("```")[0]
+
+        text = text.strip()
+
+        # Find first '{'
+        start_idx = text.find("{")
+        if start_idx == -1:
+            return None
+
+        # Stack-based parsing ignoring chars in strings
+        count = 0
+        in_string = False
+        escape = False
+
+        for i, char in enumerate(text[start_idx:], start=start_idx):
+            if in_string:
+                if escape:
+                    escape = False
+                elif char == "\\":
+                    escape = True
+                elif char == '"':
+                    in_string = False
+            elif char == '"':
+                in_string = True
+            elif char == "{":
+                count += 1
+            elif char == "}":
+                count -= 1
+                if count == 0:
+                    try:
+                        result = json.loads(text[start_idx : i + 1])
+                        if isinstance(result, dict):
+                            return result
+                        return None
+                    except json.JSONDecodeError:
+                        return None
+
+        return None
+
+    def _create_fallback_assessment(
+        self,
+        question: str,
+        error: str,
+    ) -> JudgeAssessment:
+        """Create a fallback assessment when inference fails."""
+        return JudgeAssessment(
+            details=AssessmentDetails(
+                mechanism_score=0,
+                mechanism_reasoning=f"Assessment failed: {error}",
+                clinical_evidence_score=0,
+                clinical_reasoning=f"Assessment failed: {error}",
+                drug_candidates=[],
+                key_findings=[],
+            ),
+            sufficient=False,
+            confidence=0.0,
+            recommendation="continue",
+            next_search_queries=[
+                f"{question} mechanism",
+                f"{question} clinical trials",
+            ],
+            reasoning=f"HF Inference failed: {error}. Recommend configuring OpenAI/Anthropic key.",
+        )
+
+
 class MockJudgeHandler:
     """
     Mock JudgeHandler for demo mode without LLM calls.
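For a quick manual check of the new handler, something like the following should exercise the full fallback chain (a sketch; the question is illustrative, and the empty evidence list takes the `format_empty_evidence_prompt` path, so it makes a real network call):

```python
# Sketch: one-off smoke run of HFInferenceJudgeHandler (illustrative).
import asyncio

from src.agent_factory.judges import HFInferenceJudgeHandler


async def main() -> None:
    handler = HFInferenceJudgeHandler()
    assessment = await handler.assess("Can metformin be repurposed for glioblastoma?", [])
    print(assessment.recommendation, assessment.confidence)


asyncio.run(main())
```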
src/app.py
CHANGED

@@ -10,7 +10,7 @@ from pydantic_ai.models.openai import OpenAIModel
 from pydantic_ai.providers.anthropic import AnthropicProvider
 from pydantic_ai.providers.openai import OpenAIProvider
 
-from src.agent_factory.judges import JudgeHandler, MockJudgeHandler
+from src.agent_factory.judges import HFInferenceJudgeHandler, JudgeHandler, MockJudgeHandler
 from src.mcp_tools import (
     analyze_hypothesis,
     search_all_sources,

@@ -32,7 +32,7 @@ def configure_orchestrator(
     mode: str = "simple",
     user_api_key: str | None = None,
     api_provider: str = "openai",
-) -> Any:
+) -> tuple[Any, str]:
     """
     Create an orchestrator instance.
 

@@ -43,7 +43,7 @@ def configure_orchestrator(
     api_provider: API provider ("openai" or "anthropic")
 
     Returns:
+        Tuple of (Orchestrator instance, backend_name)
     """
     # Create orchestrator config
     config = OrchestratorConfig(

@@ -57,12 +57,21 @@ def configure_orchestrator(
         timeout=config.search_timeout,
     )
 
-    judge_handler: JudgeHandler | MockJudgeHandler
+    # Create judge (mock, real, or free tier)
+    judge_handler: JudgeHandler | MockJudgeHandler | HFInferenceJudgeHandler
+    backend_info = "Unknown"
+
+    # 1. Forced Mock (Unit Testing)
     if use_mock:
         judge_handler = MockJudgeHandler()
+        backend_info = "Mock (Testing)"
+
+    # 2. Paid API Key (User provided or Env)
+    elif (
+        user_api_key
+        or (api_provider == "openai" and os.getenv("OPENAI_API_KEY"))
+        or (api_provider == "anthropic" and os.getenv("ANTHROPIC_API_KEY"))
+    ):
         model: AnthropicModel | OpenAIModel | None = None
         if user_api_key:
             if api_provider == "anthropic":

@@ -71,17 +80,26 @@ def configure_orchestrator(
             elif api_provider == "openai":
                 openai_provider = OpenAIProvider(api_key=user_api_key)
                 model = OpenAIModel(settings.openai_model, provider=openai_provider)
+            backend_info = f"Paid API ({api_provider.upper()})"
+        else:
+            backend_info = "Paid API (Env Config)"
+
         judge_handler = JudgeHandler(model=model)
 
+    # 3. Free Tier (HuggingFace Inference)
+    else:
+        judge_handler = HFInferenceJudgeHandler()
+        backend_info = "Free Tier (Llama 3.1 / Mistral)"
+
+    orchestrator = create_orchestrator(
         search_handler=search_handler,
         judge_handler=judge_handler,
         config=config,
         mode=mode,  # type: ignore
     )
 
+    return orchestrator, backend_info
+
 
 async def research_agent(
     message: str,

@@ -110,54 +128,47 @@ async def research_agent(
     # Clean user-provided API key
     user_api_key = api_key.strip() if api_key else None
 
+    # Check available keys
     has_openai = bool(os.getenv("OPENAI_API_KEY"))
     has_anthropic = bool(os.getenv("ANTHROPIC_API_KEY"))
     has_user_key = bool(user_api_key)
+    has_paid_key = has_openai or has_anthropic or has_user_key
 
-        use_mock = not (has_openai or (has_user_key and api_provider == "openai"))
-    else:
-        # Simple mode can work with either provider
-        use_mock = not (has_openai or has_anthropic or has_user_key)
-
-    # If magentic mode requested but no OpenAI key, fallback/warn
-    if mode == "magentic" and use_mock:
+    # Magentic mode requires OpenAI specifically
+    if mode == "magentic" and not (has_openai or (has_user_key and api_provider == "openai")):
         yield (
-            "⚠️ **Warning**: Magentic mode requires OpenAI API key. "
-            "Falling back to demo mode.\n\n"
+            "⚠️ **Warning**: Magentic mode requires OpenAI API key. Falling back to simple mode.\n\n"
         )
         mode = "simple"
 
     # Inform user about their key being used
-    if has_user_key
+    if has_user_key:
         yield (
             f"🔑 **Using your {api_provider.upper()} API key** - "
             "Your key is used only for this session and is never stored.\n\n"
         )
-
-    if use_mock:
+    elif not has_paid_key:
+        # No paid keys - will use FREE HuggingFace Inference
         yield (
-            "**To unlock full AI analysis:**\n"
-            "- Enter your OpenAI or Anthropic API key below, OR\n"
-            "- Configure secrets in HuggingFace Space settings\n\n"
-            "---\n\n"
+            "🤗 **Free Tier**: Using HuggingFace Inference (Llama 3.1 / Mistral) for AI analysis.\n"
+            "For premium models, enter an OpenAI or Anthropic API key below.\n\n"
         )
 
     # Run the agent and stream events
     response_parts: list[str] = []
 
     try:
+        # use_mock=False - let configure_orchestrator decide based on available keys
+        # It will use: Paid API > HF Inference (free tier)
+        orchestrator, backend_name = configure_orchestrator(
+            use_mock=False,  # Never use mock in production - HF Inference is the free fallback
             mode=mode,
             user_api_key=user_api_key,
             api_provider=api_provider,
         )
+
+        yield f"🧠 **Backend**: {backend_name}\n\n"
+
         async for event in orchestrator.run(message):
             # Format event as markdown
             event_md = event.to_markdown()
src/prompts/report.py
CHANGED

@@ -124,13 +124,13 @@ async def format_report_prompt(
 {hypotheses_summary}
 
 ## Assessment Scores
+- Mechanism Score: {assessment.get("mechanism_score", "N/A")}/10
+- Clinical Evidence Score: {assessment.get("clinical_score", "N/A")}/10
+- Overall Confidence: {assessment.get("confidence", 0):.0%}
 
 ## Metadata
 - Sources Searched: {sources}
+- Search Iterations: {metadata.get("iterations", 0)}
 
 Generate a complete ResearchReport with all sections filled in.
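For reference, the two formatting idioms the fixed prompt leans on behave like this (plain Python, nothing project-specific):

```python
# dict.get() supplies a default when a score is missing; :.0% renders a
# 0-1 float as a whole percentage.
assessment = {"mechanism_score": 8, "confidence": 0.85}
print(f'- Mechanism Score: {assessment.get("mechanism_score", "N/A")}/10')         # -> 8/10
print(f'- Clinical Evidence Score: {assessment.get("clinical_score", "N/A")}/10')  # -> N/A/10
print(f'- Overall Confidence: {assessment.get("confidence", 0):.0%}')              # -> 85%
```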
tests/unit/agent_factory/test_judges_hf.py
ADDED

@@ -0,0 +1,137 @@
+"""Unit tests for HFInferenceJudgeHandler."""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from src.agent_factory.judges import HFInferenceJudgeHandler
+from src.utils.models import Citation, Evidence
+
+
+class TestHFInferenceJudgeHandler:
+    """Tests for HFInferenceJudgeHandler."""
+
+    @pytest.fixture
+    def mock_client(self):
+        """Mock HuggingFace InferenceClient."""
+        with patch("src.agent_factory.judges.InferenceClient") as mock:
+            client_instance = MagicMock()
+            mock.return_value = client_instance
+            yield client_instance
+
+    @pytest.fixture
+    def handler(self, mock_client):
+        """Create a handler instance with mocked client."""
+        return HFInferenceJudgeHandler()
+
+    @pytest.mark.asyncio
+    async def test_assess_success(self, handler, mock_client):
+        """Test successful assessment with primary model."""
+        import json
+
+        # Construct valid JSON payload
+        data = {
+            "details": {
+                "mechanism_score": 8,
+                "mechanism_reasoning": "Good mechanism",
+                "clinical_evidence_score": 7,
+                "clinical_reasoning": "Good clinical",
+                "drug_candidates": ["Drug A"],
+                "key_findings": ["Finding 1"],
+            },
+            "sufficient": True,
+            "confidence": 0.85,
+            "recommendation": "synthesize",
+            "next_search_queries": [],
+            "reasoning": (
+                "Sufficient evidence provided to support the hypothesis with high confidence."
+            ),
+        }
+
+        # Mock chat_completion response structure
+        mock_message = MagicMock()
+        mock_message.content = f"""Here is the analysis:
+```json
+{json.dumps(data)}
+```"""
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+
+        # Setup async mock for run_in_executor
+        with patch("asyncio.get_running_loop") as mock_loop:
+            mock_loop.return_value.run_in_executor = AsyncMock(return_value=mock_response)
+
+            evidence = [
+                Evidence(
+                    content="test", citation=Citation(source="pubmed", title="t", url="u", date="d")
+                )
+            ]
+            result = await handler.assess("test question", evidence)
+
+            assert result.sufficient is True
+            assert result.confidence == 0.85
+            assert result.details.drug_candidates == ["Drug A"]
+
+    @pytest.mark.asyncio
+    async def test_assess_fallback_logic(self, handler, mock_client):
+        """Test fallback to secondary model when primary fails."""
+
+        # Setup async mock to fail first, then succeed
+        with patch("asyncio.get_running_loop"):
+            # We need to mock the _call_with_retry method directly to test the loop in assess
+            # but _call_with_retry is decorated with tenacity,
+            # which makes it harder to mock partial failures easily
+            # without triggering the tenacity retry loop first.
+            # Instead, let's mock run_in_executor to raise exception on first call
+
+            # This is tricky because assess loops over models,
+            # and for each model _call_with_retry retries.
+            # We want to simulate: Model 1 fails (retries exhausted) -> Model 2 succeeds.
+
+            # Let's patch _call_with_retry to avoid waiting for real retries
+            side_effect = [
+                Exception("Model 1 failed"),
+                Exception("Model 2 failed"),
+                Exception("Model 3 failed"),
+            ]
+            with patch.object(handler, "_call_with_retry", side_effect=side_effect) as mock_call:
+                evidence = []
+                result = await handler.assess("test", evidence)
+
+                # Should have tried all 3 fallback models
+                assert mock_call.call_count == 3
+                assert result.sufficient is False  # Fallback assessment
+                error_msg = "All HF models failed"
+                assert error_msg in str(mock_call.side_effect) or "failed" in result.reasoning
+
+    def test_extract_json_robustness(self, handler):
+        """Test JSON extraction with various inputs."""
+
+        # 1. Clean JSON
+        assert handler._extract_json('{"a": 1}') == {"a": 1}
+
+        # 2. Markdown block
+        assert handler._extract_json('```json\n{"a": 1}\n```') == {"a": 1}
+
+        # 3. Text preamble/postamble
+        text = """
+        Sure, here is the JSON:
+        {
+            "a": 1,
+            "b": {
+                "c": 2
+            }
+        }
+        Hope that helps!
+        """
+        assert handler._extract_json(text) == {"a": 1, "b": {"c": 2}}
+
+        # 4. Nested braces
+        nested = '{"a": {"b": "}"}}'
+        assert handler._extract_json(nested) == {"a": {"b": "}"}}
+
+        # 5. Invalid JSON
+        assert handler._extract_json("Not JSON") is None
+        assert handler._extract_json("{Incomplete") is None
uv.lock
CHANGED

@@ -1065,6 +1065,7 @@ dependencies = [
 { name = "beautifulsoup4" },
 { name = "gradio", extra = ["mcp"] },
 { name = "httpx" },
+{ name = "huggingface-hub" },
 { name = "openai" },
 { name = "pydantic" },
 { name = "pydantic-ai" },

@@ -1114,6 +1115,7 @@ requires-dist = [
 { name = "chromadb", marker = "extra == 'modal'", specifier = ">=0.4.0" },
 { name = "gradio", extras = ["mcp"], specifier = ">=6.0.0" },
 { name = "httpx", specifier = ">=0.27" },
+{ name = "huggingface-hub", specifier = ">=0.20.0" },
 { name = "llama-index", marker = "extra == 'modal'", specifier = ">=0.11.0" },
 { name = "llama-index-embeddings-openai", marker = "extra == 'modal'" },
 { name = "llama-index-llms-openai", marker = "extra == 'modal'" },